#! /usr/bin/env python3
#
def exercise3():
  """Cluster the Ruspini dataset with K-Means and plot the results.

  Discussion:

    The data is read with np.loadtxt() from the text file
    'ruspini_data.txt' in the current working directory (one point per
    row, values separated by whitespace).  The function then

      * prints the array shape, the first five rows and per-column
        statistics;
      * standardizes every column to zero mean and unit standard deviation;
      * saves a scatter plot of the standardized data, with three
        reference circles, to 'exercise3.png';
      * runs kmeans2() for k = 1, ..., 10, prints the cluster energy E(k)
        and saves the curve to 'exercise3_inertia.png';
      * clusters the data with k = 4 and saves the colored clusters,
        their centers and rings to 'exercise3_clusters.png'.

    Every figure is also displayed with plt.show().

    kmeans2() is called with its default initialization and no seed, so
    the printed energies and the final clustering may differ between runs.

  Licensing:

    This code is distributed under the MIT license.

  Modified:

    24 March 2025

  Author:

    John Burkardt

  Returns:

    None.
  """
  from scipy.cluster.vq import kmeans2
  import matplotlib.pyplot as plt
  import numpy as np
#
#  Points on the unit circle, shared by every ring that is drawn below.
#
  tc = np.linspace ( 0, 2.0 * np.pi, 51 )
  unit_x = np.cos ( tc )
  unit_y = np.sin ( tc )

  def draw_rings ( cx, cy, radius, *fmt ):
    # Draw circles of radius 1, 2 and 3 times "radius" centered at (cx,cy).
    # "fmt" is an optional matplotlib format string such as 'r-'.
    xc = radius * unit_x
    yc = radius * unit_y
    for scale in ( 1.0, 2.0, 3.0 ):
      plt.plot ( scale * xc + cx, scale * yc + cy, *fmt, linewidth = 2 )

  def save_and_show ( plotfile ):
    # Write the current figure to "plotfile", report it, display and close.
    plt.savefig ( plotfile )
    print ( '  Graphics saved as "' + plotfile + '"' )
    plt.show ( )
    plt.close ( )

  print ( "exercise3():" )
  print ( "  Process the Ruspini data." )
#
#  Read the data.
#
  datafile = 'ruspini_data.txt'
  data = np.loadtxt ( datafile )
  rows, cols = np.shape ( data )
  print ( '' )
  print ( '  "' + datafile + '" contains', rows, 'rows and', cols, 'columns.' )
#
#  Print the first five lines.
#
  print ( '' )
  print ( '  First five lines of data:' )
  print ( '' )
  print ( data[0:5,:] )
#
#  Statistics, computed column by column.
#
  print ( '' )
  print ( '  Statistics for data:' )
  print ( '' )
  print ( '    data.shape:     ', data.shape )
  print ( '    np.min(data,axis=0):  ', np.min ( data, axis = 0 ) )
  print ( '    np.mean(data,axis=0): ', np.mean ( data, axis = 0 ) )
  print ( '    np.max(data,axis=0):  ', np.max ( data, axis = 0 ) )
  print ( '    np.std(data,axis=0):  ', np.std ( data, axis = 0 ) )
  print ( '    np.var(data,axis=0):  ', np.var ( data, axis = 0 ) )
#
#  Standardize: every column gets mean 0 and standard deviation 1.
#
  data = ( data - np.mean ( data, axis = 0 ) ) / np.std ( data, axis = 0 )
#
#  Scatter plot of the first two columns.
#  The circles have radius r, 2r, 3r, where r is the root-mean-square
#  distance of the points from their mean (the origin, after standardizing).
#
  plt.clf ( )
  plt.scatter ( data[:,0], data[:,1] )
  r = np.sqrt ( np.sum ( np.var ( data, axis = 0 ) ) )
  draw_rings ( 0.0, 0.0, r, 'r-' )
  plt.xlabel ( 'X' )
  plt.ylabel ( 'Y' )
  plt.title ( 'Ruspini dataset' )
  plt.grid ( True )
  plt.axis ( 'equal' )
  save_and_show ( 'exercise3.png' )
#
#  Try K-Means, for k = 1 to 10.
#
#  kmeans2() returns the centers Z and, for each point, the index C of
#  the center it was assigned to.  The energy is the sum of the squared
#  distances between each point and its assigned center, taken over ALL
#  columns so that it matches the space in which kmeans2() clustered.
#
  print ( '' )
  print ( '  k  Energy' )
  print ( '' )
  kmax = 10
  E = np.zeros ( kmax )
  for k in range ( 1, kmax + 1 ):
    Z, C = kmeans2 ( data, k )
    E[k-1] = np.sum ( ( data - Z[C] ) ** 2 )
    print ( '  %d  %g' % ( k, E[k-1] ) )
#
#  Plot the inertia.
#
  plt.clf ( )
  plt.plot ( np.arange ( 1, kmax + 1 ), E, 'bo-', linewidth = 3 )
  plt.grid ( True )
  plt.xlabel ( 'K: Number of clusters' )
  plt.ylabel ( 'E(K): Cluster energy' )
  plt.title ( 'Ruspini energy E(k) with increasing number of clusters' )
  save_and_show ( 'exercise3_inertia.png' )
#
#  Use the chosen value of K to cluster the data.
#
  k = 4
  Z, C = kmeans2 ( data, k )
#
#  Plot the clusters using different colors.
#
#  For each cluster, "r" is the root-mean-square distance between the
#  cluster's points and their mean.  We use it in order to draw rings
#  of radius r, 2r, 3r around each cluster center.
#
  plt.clf ( )
  for i in range ( 0, k ):
    members = ( C == i )
    plt.scatter ( data[members,0], data[members,1], marker = '.' )
    plt.scatter ( Z[i,0], Z[i,1], c = 'black', s = 250, marker = '*' )
#
#  kmeans2() may leave a cluster empty.  Its variance is undefined (NaN),
#  so there are no meaningful rings to draw for it.
#
    if not np.any ( members ):
      continue
    r = np.sqrt ( np.sum ( np.var ( data[members,:], axis = 0 ) ) )
    draw_rings ( Z[i,0], Z[i,1], r )
  plt.title ( 'Ruspini data' )
  plt.grid ( True )
  plt.axis ( 'equal' )
  save_and_show ( 'exercise3_clusters.png' )
#
#  Terminate.
#
  print ( "" )
  print ( "exercise3():" )
  print ( "  Normal end of execution." )

  return

# Run the exercise only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
  exercise3()