#! /usr/bin/env python3
#


def six_points():
    """Fit and plot a logistic regression on an artificial six-point example.

    The one-dimensional inputs x = 1, 2, 3 are labeled 0 and x = 10, 11, 12
    are labeled 1.  The function fits sklearn's ``LogisticRegression`` with
    default settings and prints the fitted coefficient, intercept, and the
    decision cutoff.  It then plots the data, the fitted logistic curve, and
    the cutoff as a dashed red vertical line.  The figure is saved to
    "six_points.png" and then shown with ``plt.show()``.

    Licensing:
        This code is distributed under the MIT license.

    Modified:
        27 January 2020

    Author:
        John Burkardt
    """
    import platform

    import matplotlib.pyplot as plt
    import numpy as np
    import sklearn
    from sklearn import linear_model

    print('')
    print('six_points:')
    print('  python version: %s' % (platform.python_version()))
    print('  scikit-learn version is %s' % (sklearn.__version__))
    print('  Use sklearn LogisticRegression() on a six data point example.')
    print('')

    x = [1, 2, 3, 10, 11, 12]
    y = [0, 0, 0, 1, 1, 1]

    # sklearn expects a 2-D feature matrix: 6 samples, 1 feature each.
    xx = np.reshape(x, [6, 1])

    lr = linear_model.LogisticRegression()
    lr.fit(xx, y)

    lr_coef = lr.coef_.ravel()
    lr_intercept = lr.intercept_.ravel()
    print(lr_coef, lr_intercept)

    # With a single feature these arrays hold one value each.  Extract plain
    # floats so downstream code (axvline, printing) gets true scalars rather
    # than 1-element arrays.
    coef = float(lr_coef[0])
    intercept = float(lr_intercept[0])

    # The predicted probability is logistic(coef * x + intercept).  It equals
    # 0.5 where the argument is zero, i.e. at x = -intercept / coef.
    cutoff = -intercept / coef
    print("cutoff", cutoff)

    plt.scatter(x, y)
    x_eval = np.linspace(-10, 20, 100)
    y_eval = logistic_fun(x_eval * coef + intercept)
    plt.plot(x_eval, y_eval)
    plt.axvline(x=cutoff, linestyle='--', color='red')
    plt.title('six_points logistic regression example')
    plt.grid(True)
    filename = 'six_points.png'
    plt.savefig(filename)
    print('')
    print('  Graphics saved as "%s"' % (filename))
    plt.show()

    print('')
    print('six_points')
    print('  Normal end of execution.')
    return


def logistic_fun(x):
    """Evaluate the logistic function 1 / (1 + exp(-x)).

    Input:
        x: real scalar or numpy array, the arguments.

    Output:
        value: the logistic function evaluated at x, elementwise for arrays.

    Licensing:
        This code is distributed under the MIT license.

    Modified:
        08 January 2020

    Author:
        John Burkardt
    """
    import numpy as np

    # For large negative x, exp(-x) overflows to inf.  The quotient then
    # correctly evaluates to 0.0, so the overflow warning is spurious and is
    # silenced only for this expression.
    with np.errstate(over='ignore'):
        value = 1.0 / (1.0 + np.exp(-x))

    return value


if __name__ == '__main__':
    six_points()