#! /usr/bin/env python3
#
def moon_classify_forest ( ):

#*****************************************************************************80
#
## moon_classify_forest() uses a random forest to classify the moon data.
#
#  Discussion:
#
#    A 100-sample, two-class "moons" dataset is generated and split
#    (stratified) into training and test sets.  A 5-tree random forest is
#    fit to the training data, and a 2x3 grid of plots is saved to
#    'moon_classify_forest.png': one decision-region panel for each of the
#    5 trees, plus a final panel showing the combined forest's decision
#    regions with the training points overlaid.
#
#  Licensing:
#
#    This code is distributed under the MIT license.
#
#  Modified:
#
#    19 July 2023
#
#  Author:
#
#    Andreas Mueller, Sarah Guido.
#    Modifications by John Burkardt.
#
#  Reference:
#
#    Andreas Mueller, Sarah Guido,
#    Introduction to Machine Learning with Python,
#    O'Reilly, 2017,
#    ISBN: 978-1-449-36941-5
#
  from sklearn.datasets import make_moons
  from sklearn.ensemble import RandomForestClassifier
  from sklearn.model_selection import train_test_split
  import matplotlib.pyplot as plt
  import mglearn
  import numpy as np
  import platform
  import sklearn

  print ( '' )
  print ( 'moon_classify_forest():' )
  print ( ' Python version: ' + platform.python_version ( ) )
  print ( ' scikit-learn version: '+ sklearn.__version__ )
  print ( ' Generate 100 samples of the artificial moon dataset.' )
  print ( ' Classify the data using a random forest.' )
  print ( '' )

  X, y = make_moons ( n_samples = 100, noise = 0.25, random_state = 3 )
#
#  The test split is held out but not evaluated in this demonstration;
#  only the training data is fit and plotted.
#
  X_train, X_test, y_train, y_test = train_test_split ( \
    X, y, stratify = y, random_state = 42 )

  forest = RandomForestClassifier ( n_estimators = 5, random_state = 2 )
  forest.fit ( X_train, y_train )
#
#  Create the figure directly.  A preceding plt.clf() would only create and
#  clear an unrelated empty figure that is never used or closed.
#
  fig, axes = plt.subplots ( 2, 3, figsize = ( 20, 10 ) )
#
#  One panel per tree.  zip() stops at the shorter sequence, so with
#  5 trees only the first 5 of the 6 panels are filled here.
#
  for i, ( ax, tree ) in enumerate ( zip ( axes.ravel ( ), forest.estimators_ ) ):
    ax.set_title ( "Tree {}".format ( i ) )
    mglearn.plots.plot_tree_partition ( X_train, y_train, tree, ax = ax )
#
#  The last panel shows the combined forest's decision regions.
#  Pass the axes explicitly rather than relying on pyplot's "current axes".
#
  forest_ax = axes[-1,-1]
  mglearn.plots.plot_2d_separator ( forest, X_train, fill = True, \
    ax = forest_ax, alpha = 0.4 )
  forest_ax.set_title ( "Random Forest" )
  mglearn.discrete_scatter ( X_train[:,0], X_train[:,1], y_train, ax = forest_ax )

  filename = 'moon_classify_forest.png'
  fig.savefig ( filename )
#
#  Release the figure so repeated calls do not accumulate open figures.
#
  plt.close ( fig )
  print ( ' Graphics saved as "' + filename + '"' )
#
#  Terminate.
#
  print ( '' )
  print ( 'moon_classify_forest():' )
  print ( ' Normal end of execution.' )

  return

def timestamp ( ):

#*****************************************************************************80
#
## timestamp() prints the date as a timestamp.
#
#  Discussion:
#
#    Prints the current local time in the format produced by time.ctime().
#
#  Licensing:
#
#    This code is distributed under the MIT license.
#
#  Modified:
#
#    21 August 2019
#
#  Author:
#
#    John Burkardt
#
  import time

  t = time.time ( )
  print ( time.ctime ( t ) )

  return

if ( __name__ == '__main__' ):
  timestamp ( )
  moon_classify_forest ( )
  timestamp ( )