Advertisement
makispaiktis

ML - Lab 8 - Hierarchical Clustering: Dendrograms and Silhouette

Oct 23rd, 2022 (edited)
1,142
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.28 KB | None | 0 0
  1. import numpy as np
  2. import pandas as pd
  3. import matplotlib.pyplot as plt
  4. import sklearn
  5.  
  6.  
  7. # Read Data
  8. europe = pd.read_csv("./europe.txt")
  9. print("Data = ")
  10. print(europe)
  11. print()
  12. print("Data Summary = ")
  13. print(europe.describe())
  14. print()
  15.  
  16.  
  17. # Preprocessing
  18. from sklearn.preprocessing import StandardScaler
  19. scaler = StandardScaler()
  20. scaler = scaler.fit(europe)
  21. europe = pd.DataFrame(scaler.transform(europe), columns=europe.columns, index=europe.index)
  22. print("Data processed = ")
  23. print(europe)
  24. print()
  25.  
  26.  
  27. # Hierarchical Clustering with complete link - Dendrogram
  28. from sklearn.cluster import AgglomerativeClustering
  29. from scipy.cluster.hierarchy import dendrogram
  30. clustering = AgglomerativeClustering(n_clusters=None, linkage="complete", distance_threshold=0).fit(europe)
  31. linkage_matrix = np.column_stack([clustering.children_, clustering.distances_, np.ones(len(europe.index)-1)]).astype(float)
  32. dendrogram(linkage_matrix, labels=europe.index)
  33. plt.title("Complete link (No clusters)")
  34. plt.show()
  35.  
  36.  
  37. # Silhouette score for the whole cluster
  38. from sklearn.metrics import silhouette_score
  39. slc = []
  40. for i in range(2, 21):
  41.     clustering = AgglomerativeClustering(n_clusters=i, linkage="complete").fit(europe)
  42.     SILHOUETTE = silhouette_score(europe, clustering.labels_)
  43.     slc.append(SILHOUETTE)
  44.  
  45. plt.plot(range(2, 21), slc)
  46. plt.xticks(range(2, 21), range(2, 21))
  47. plt.title("Silhouette score with complete link")
  48. plt.xlabel("# of clusters")
  49. plt.ylabel("Silhouette Score")
  50. plt.show()
  51.  
  52.  
  53. # Max silhouette at n = 7
  54. n_clusters = 7
  55. clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage="complete").fit(europe)
  56. # 3-D plot of clustering
  57. fig = plt.figure()
  58. ax = fig.add_subplot(projection='3d')
  59. ax.scatter(europe.GDP, europe.Inflation, europe.Unemployment, c=clustering.labels_, cmap="bwr")
  60. for i in range(len(europe.index)):
  61.     ax.text(europe.loc[europe.index[i], "GDP"], europe.loc[europe.index[i], "Inflation"], europe.loc[europe.index[i], "Unemployment"], '%s' % (str(europe.index[i])), size=5, zorder=1)
  62. ax.set_xlabel('GDP')
  63. ax.set_ylabel('Inflation')
  64. ax.set_zlabel('Unemployment')
  65. plt.title("Clustering in " + str(n_clusters) + " clusters")
  66. plt.show()
  67. print("n = " + str(n_clusters))
  68. print("Silhouette score = " + str(silhouette_score(europe, clustering.labels_)))
  69.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement