#! /usr/bin/env python3
#
def cancer_classify_decision ( ):

#*****************************************************************************80
#
## cancer_classify_decision() uses a decision tree to classify cancer data.
#
#  Discussion:
#
#    The breast cancer dataset is split into stratified training and test
#    sets (random_state = 42).  Two decision trees are fit: one of
#    unlimited depth, and one limited to max_depth = 4.  For each, the
#    training and test accuracy is printed and the tree is rendered with
#    graphviz.  The feature importances of the depth-limited tree are then
#    printed and plotted.
#
#  Licensing:
#
#    This code is distributed under the MIT license.
#
#  Modified:
#
#    18 July 2023
#
#  Author:
#
#    Andreas Mueller, Sarah Guido.
#    Modifications by John Burkardt.
#
#  Reference:
#
#    Andreas Mueller, Sarah Guido,
#    Introduction to Machine Learning with Python,
#    O'Reilly, 2017,
#    ISBN: 978-1-449-36941-5
#
  from sklearn.datasets import load_breast_cancer
  from sklearn.model_selection import train_test_split
  from sklearn.tree import DecisionTreeClassifier
  import platform
  import sklearn

  print ( '' )
  print ( 'cancer_classify_decision():' )
  print ( ' Python version: ' + platform.python_version ( ) )
  print ( ' scikit-learn version: '+ sklearn.__version__ )
#
#  Retrieve the dataset.
#
  print ( '' )
  print ( ' Retrieve the cancer dataset, (X, y).' )

  cancer = load_breast_cancer ( )
#
#  Stratify so both splits keep the malignant/benign class proportions.
#
  X_train, X_test, y_train, y_test = train_test_split ( 
    cancer.data, cancer.target, stratify = cancer.target, random_state = 42 )
#
#  Create a tree with unlimited depth.
#  Such a tree keeps splitting until its leaves are pure, so its
#  training accuracy is expected to exceed its test accuracy.
#
  tree = DecisionTreeClassifier ( random_state = 0 )
  tree.fit ( X_train, y_train )

  print ( '' )
  print ( ' Create an unpruned decision tree:' )
  print ( ' Accuracy on training set = ', tree.score ( X_train, y_train ) )
  print ( ' Accuracy on test set = ', tree.score ( X_test, y_test ) )
#
#  Visualize the tree.
#
  _render_tree ( tree, cancer.feature_names, 'unpruned_tree' )
#
#  Create a tree with limited depth.
#
  tree = DecisionTreeClassifier ( max_depth = 4, random_state = 0 )
  tree.fit ( X_train, y_train )

  print ( '' )
  print ( ' Create a pruned decision tree (max depth = 4 ):' )
  print ( ' Accuracy on training set = ', tree.score ( X_train, y_train ) )
  print ( ' Accuracy on test set = ', tree.score ( X_test, y_test ) )
#
#  Visualize the tree.
#
  _render_tree ( tree, cancer.feature_names, 'pruned_tree' )
#
#  Report feature importance (of the depth-limited tree).
#
  print ( '' )
  print ( ' Feature importance:' )
  print ( tree.feature_importances_ )

  plot_feature_importances_cancer ( cancer, tree )
#
#  Terminate.
#
  print ( '' )
  print ( 'cancer_classify_decision():' )
  print ( ' Normal end of execution.' )

  return

def _render_tree ( tree, feature_names, basename ):

#*****************************************************************************80
#
## _render_tree() draws a fitted decision tree with graphviz.
#
#  Discussion:
#
#    The tree is exported in DOT format to the file "tree.dot", which is
#    overwritten on every call.  graphviz then renders it to
#    basename + ".png" and opens the system viewer via Source.view().
#
#  Input:
#
#    DecisionTreeClassifier tree: a fitted tree.
#
#    string feature_names[]: names of the features, in column order.
#
#    string basename: output file name, without the ".png" extension.
#
  from sklearn.tree import export_graphviz
  import graphviz

  export_graphviz ( tree, out_file = 'tree.dot', 
    class_names = [ 'malignant', 'benign' ], 
    feature_names = feature_names, impurity = False, filled = True )
#
#  sklearn writes the DOT file as UTF-8; read it back the same way
#  rather than relying on the platform default encoding.
#
  with open ( 'tree.dot', encoding = 'utf-8' ) as f:
    dot_graph = f.read ( )

  s = graphviz.Source ( dot_graph, filename = basename, format = 'png' )
  s.view ( )
  print ( ' Graphics saved as "' + basename + '.png"' )

  return

def plot_feature_importances_cancer ( cancer, model, 
  filename = 'feature_importance.png' ):

#*****************************************************************************80
#
## plot_feature_importances_cancer() plots the feature importances of a model.
#
#  Discussion:
#
#    A horizontal bar chart is drawn with one bar per feature, labeled by
#    feature name.  It is saved to a file, and the figure is then closed.
#
#  Input:
#
#    cancer: the dataset, as returned by load_breast_cancer(); only its
#    "data" and "feature_names" attributes are used.
#
#    model: a fitted estimator with a "feature_importances_" attribute.
#
#    string filename: the image file to write.
#    The default is "feature_importance.png".
#
  import matplotlib.pyplot as plt
  import numpy as np

  n_features = cancer.data.shape[1]
  positions = np.arange ( n_features )
#
#  Draw on a fresh, tall figure so that every feature label has room,
#  and so that nothing already on the current figure is mixed in.
#
  plt.figure ( figsize = ( 8, 8 ) )
  plt.barh ( positions, model.feature_importances_, align = 'center' )
  plt.yticks ( positions, cancer.feature_names )
  plt.xlabel ( 'Feature importance' )
  plt.ylabel ( 'Feature' )
  plt.ylim ( -1, n_features )
#
#  Keep the long feature names from being clipped at the left edge.
#
  plt.tight_layout ( )
  plt.savefig ( filename )
  plt.close ( )
  print ( ' Graphics saved as "' + filename + '"' )

  return

def timestamp ( ):

#*****************************************************************************80
#
## timestamp() prints the date as a timestamp.
#
#  Licensing:
#
#    This code is distributed under the MIT license.
#
#  Modified:
#
#    21 August 2019
#
#  Author:
#
#    John Burkardt
#
  import time

  t = time.time ( )
  print ( time.ctime ( t ) )

  return

if ( __name__ == '__main__' ):
  timestamp ( )
  cancer_classify_decision ( )
  timestamp ( )