Advertisement
mirosh111000

DBSCAN

Feb 15th, 2024
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.01 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import seaborn as sns
  4. import pandas as pd
  5. from sklearn.datasets import make_blobs
  6. from sklearn.decomposition import PCA
  7. from sklearn.model_selection import train_test_split
  8. from sklearn.metrics import silhouette_score, accuracy_score, adjusted_rand_score
  9. from sklearn.neighbors import NearestNeighbors
  10. from sklearn.cluster import DBSCAN as sk_DBSCAN
  11.  
  12.  
  13.  
  14. def best_k_distance_graph(X, k=9):
  15.  
  16.     plt.figure(figsize=(10, 6))
  17.     plt.xlabel('Points sorted by distance')
  18.     plt.ylabel(f'k-distance')
  19.     plt.title(f'K-Distance Graph')
  20.     plt.grid(True)
  21.    
  22.     neigh = NearestNeighbors(n_neighbors=k)
  23.     neigh.fit(X)
  24.     distances, _ = neigh.kneighbors(X)
  25.    
  26.     for i in range(1, k):
  27.        
  28.         k_distances = distances[:, i]
  29.         k_distances_sorted = np.sort(k_distances)
  30.         plt.plot(np.arange(len(X)), k_distances_sorted, marker='.', label=f'k={i+1}')
  31.        
  32.     plt.legend()
  33.     plt.show()
  34.  
  35.  
  36. class DBSCAN:
  37.     def __init__(self, eps=0.5, min_samples=5):
  38.         self.eps = eps
  39.         self.min_samples = min_samples
  40.         self.labels_ = None
  41.        
  42.     def fit(self, X):
  43.         X = np.copy(X)
  44.         self.labels_ = np.zeros(len(X), dtype=int)
  45.         cluster_label = 0
  46.        
  47.         for i, point in enumerate(X):
  48.             if self.labels_[i] != 0:  
  49.                 continue
  50.                
  51.             neighbors = self._get_neighbors(X, i)
  52.             if len(neighbors) < self.min_samples:
  53.                 self.labels_[i] = -1
  54.                 continue
  55.                
  56.             cluster_label += 1
  57.             self._expand_cluster(X, i, neighbors, cluster_label)
  58.            
  59.     def predict(self, X):
  60.         return self.labels_
  61.    
  62.     def _expand_cluster(self, X, point_index, neighbors, cluster_label):
  63.         self.labels_[point_index] = cluster_label
  64.        
  65.         i = 0
  66.         while i < len(neighbors):
  67.             neighbor = neighbors[i]
  68.             if self.labels_[neighbor] == -1:
  69.                 self.labels_[neighbor] = cluster_label
  70.             elif self.labels_[neighbor] == 0:
  71.                 self.labels_[neighbor] = cluster_label
  72.                 new_neighbors = self._get_neighbors(X, neighbor)
  73.                 if len(new_neighbors) >= self.min_samples:
  74.                     neighbors.extend(new_neighbors)
  75.             i += 1
  76.        
  77.     def _get_neighbors(self, X, point_index):
  78.         neighbors = []
  79.         for i, point in enumerate(X):
  80.             if np.linalg.norm(point - X[point_index]) < self.eps:
  81.                 neighbors.append(i)
  82.         return neighbors
  83.  
  84.  
  85.    
  86.    
  87.    
  88. def visualize_clusters(X, labels, title='DBSCAN Clustering'):
  89.     X = np.copy(X)
  90.     labels = np.copy(labels)
  91.     plt.figure(figsize=(10, 6))
  92.     plt.grid()
  93.      
  94.     for i, label in enumerate(np.unique(labels)):
  95.         if label == -1:
  96.             plt.scatter(X[labels == label][:, 0], X[labels == label][:, 1], color='k', edgecolors='black', label='Noise')
  97.         else:
  98.             plt.scatter(X[labels == label][:, 0], X[labels == label][:, 1], edgecolors='black', label=f'Cluster {label}')
  99.     plt.title(title)
  100.     plt.xlabel('Feature 1')
  101.     plt.ylabel('Feature 2')
  102.     plt.legend()
  103.     plt.show()
  104.    
  105.    
  106.  
  107.  
  108. X, y = make_blobs(
  109.     n_samples=250,
  110.     n_features=5,
  111.     centers=6,
  112.     cluster_std=2,
  113.     random_state=42
  114. )
  115.  
  116.  
  117. pca = PCA(n_components=2)
  118. X = pca.fit_transform(X)
  119. data = pd.DataFrame(X)
  120. data['target'] = y
  121. sns.pairplot(data, hue='target', palette='dark')
  122. plt.show()
  123.  
  124. best_k_distance_graph(np.copy(X))
  125. dbscan = DBSCAN(eps=2.3, min_samples=5)
  126. dbscan.fit(X)
  127. labels = dbscan.predict(X)
  128. visualize_clusters(X, labels)
  129. print(f"Adjusted Rand Index: {(round(adjusted_rand_score(data.target, labels)*100, 2))}%")
  130.  
  131.  
  132. dbscan = sk_DBSCAN(eps=2.3, min_samples=5)
  133. dbscan.fit(X)
  134. labels = dbscan.labels_
  135. visualize_clusters(X, labels, title='Sklearn DBSCAN')
  136. print(f"Sklearn Adjusted Rand Index: {(round(adjusted_rand_score(data.target, labels)*100, 2))}%")
  137.  
  138.  
  139.  
  140. data = pd.read_csv('moons.csv')
  141. sns.pairplot(data)
  142. plt.show()
  143.  
  144. X = data.copy()
  145. best_k_distance_graph(np.copy(X))
  146. dbscan = DBSCAN(eps=0.35, min_samples=3)
  147. dbscan.fit(X)
  148. labels = dbscan.predict(X)
  149. visualize_clusters(X, labels)
  150.  
  151. dbscan = sk_DBSCAN(eps=0.35, min_samples=3)
  152. dbscan.fit(X)
  153. labels = dbscan.labels_
  154. visualize_clusters(X, labels, title='Sklearn DBSCAN')
  155.  
  156.  
  157.  
  158.  
  159. data = pd.read_csv('sklearn_moons.csv')
  160. data.columns = ['x', 'y', 'target']
  161. sns.pairplot(data, hue='target', palette='dark')
  162. plt.show()
  163.  
  164. X = data.copy()
  165. best_k_distance_graph(np.copy(X))
  166. dbscan = DBSCAN(eps=0.125, min_samples=3)
  167. dbscan.fit(X)
  168. labels = dbscan.predict(X)
  169. visualize_clusters(X, labels)
  170. print(f"Adjusted Rand Index: {(round(adjusted_rand_score(data.target, labels)*100, 2))}%")
  171.  
  172.  
  173.  
  174. dbscan = sk_DBSCAN(eps=0.125, min_samples=3)
  175. dbscan.fit(X)
  176. labels = dbscan.labels_
  177. visualize_clusters(X, labels, title='Sklearn DBSCAN')
  178. print(f"Sklearn Adjusted Rand Index: {(round(adjusted_rand_score(data.target, labels)*100, 2))}%")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement