#! /usr/bin/env python3
#
import numpy as np


def normalize_data(g):
    """Map data linearly onto the interval [0, 1].

    Args:
        g: array-like of data values.

    Returns:
        (gn, gmin, gmax) where gn = (g - gmin) / (gmax - gmin) and
        gmin, gmax are the minimum and maximum of g.

    Raises:
        ValueError: if g is empty, or if all of its entries are equal
            (the normalization would divide by zero).
    """
    g = np.asarray(g, dtype=float)
    if g.size == 0:
        raise ValueError('normalize_data(): G is empty.')
    gmin = np.min(g)
    gmax = np.max(g)
    if gmax == gmin:
        raise ValueError(
            'normalize_data(): all G values are equal; cannot normalize.')
    gn = (g - gmin) / (gmax - gmin)
    return gn, gmin, gmax


def denormalize_weights(wn, gmin, gmax):
    """Convert logistic weights fitted on normalized data to raw-data units.

    A model fitted on normalized data computes
        z = wn[0] + wn[1] * (g - gmin) / (gmax - gmin).
    Expanding in terms of the original variable g gives z = w[0] + w[1] * g,
    where
        w[0] = wn[0] - wn[1] * gmin / (gmax - gmin)
        w[1] = wn[1] / (gmax - gmin).

    Args:
        wn: length-2 sequence of weights for the normalized data.
        gmin, gmax: the normalization bounds (gmax != gmin).

    Returns:
        numpy array [w0, w1] of weights for the original data.
    """
    scale = gmax - gmin
    return np.array([wn[0] - wn[1] * gmin / scale, wn[1] / scale])


def classification_cutoff(w):
    """Return the g at which 1 / (1 + exp(-(w[0] + w[1] * g))) equals 0.5.

    That happens where w[0] + w[1] * g = 0, i.e. g = -w[0] / w[1].

    Returns:
        The cutoff value, or None when w[1] == 0 (the probability is then
        constant in g, so no cutoff exists).
    """
    if w[1] == 0.0:
        return None
    return -w[0] / w[1]


def apple():
    """Classify fruit as oranges (0) or apples (1) by weight.

    Reads 'apple_data.txt' (column 0: weight in grams; column 1: 0 for
    orange, 1 for apple). Normalizes the weights, fits logistic weights
    with logistic_regression(), reports the weights and the classification
    cutoff, and plots the data and fitted curve to 'apple.png'.

    Licensing:
        This code is distributed under the GNU LGPL license.

    Modified:
        13 February 2022

    Author:
        John Burkardt
    """
    # Imported here so the pure helpers above can be used (and tested)
    # without matplotlib or the project-local logistic_regression module.
    import matplotlib.pyplot as plt
    from logistic_regression import logistic_regression

    print('apple():')
    #
    # Read the data file.  ndmin=2 keeps a one-row file two-dimensional so
    # the column slicing below still works.
    #
    data = np.loadtxt('apple_data.txt', ndmin=2)
    #
    # First column is weight in grams.
    # Second column is 0 for orange, 1 for apple.
    #
    g = data[:, 0]
    y = data[:, 1]
    #
    # Normalize the G data (raises ValueError if all weights are equal).
    #
    gn, gmin, gmax = normalize_data(g)
    #
    # Create a data array X = [ 1 | gn ].
    #
    X = np.column_stack((np.ones_like(gn), gn))
    #
    # Set the learning rate ALPHA.
    #
    alpha = 0.02
    #
    # Set the number of iterations KMAX.
    #
    kmax = 100000
    #
    # Compute logistic weights for the normalized data, then convert them
    # back to the units of the original data.
    #
    wn = logistic_regression(X, y, alpha, kmax)
    print('')
    print(' Weights for normalized data Wn = (%g,%g)' % (wn[0], wn[1]))

    w = denormalize_weights(wn, gmin, gmax)
    print('')
    print(' Weights for original data W = (%g,%g)' % (w[0], w[1]))

    cutoff = classification_cutoff(w)
    if cutoff is None:
        print(' No cutoff: the fitted weight coefficient W[1] is zero.')
    else:
        print(' Cutoff value is ', cutoff)
    #
    # Display the data.
    #
    gplot = np.linspace(gmin, gmax, 101)
    yplot = 1.0 / (1.0 + np.exp(-(w[0] + w[1] * gplot)))
    plt.plot(g[y == 0], y[y == 0], 'mo')
    plt.plot(g[y == 1], y[y == 1], 'go')
    plt.plot(gplot, yplot, linewidth=3)
    if cutoff is not None:
        plt.plot([cutoff, cutoff], [0, 1], '--', color='r')
    plt.grid(True)
    plt.xlabel('<-- G: Weight in grams -->')
    plt.ylabel('<-- Y(G): Orange = 0, Apple = 1 -->')
    plt.title('Orange/Apple classified by weight')
    filename = 'apple.png'
    print(' Graphics saved as "%s"' % (filename))
    plt.savefig(filename)
    plt.show()
    plt.close()
    #
    # Terminate.
    #
    print('')
    print('apple():')
    print(' Normal end of execution.')


if __name__ == '__main__':
    apple()