Advertisement
max2201111

blob

Sep 13th, 2023
770
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.50 KB | Science | 0 0
  1. # Author: Romain Tavenard
  2. # License: BSD 3 clause
  3.  
  4. import numpy
  5. from sklearn.metrics import accuracy_score
  6.  
  7. from tslearn.generators import random_walk_blobs
  8. from tslearn.preprocessing import TimeSeriesScalerMinMax, \
  9.     TimeSeriesScalerMeanVariance
  10. from tslearn.neighbors import KNeighborsTimeSeriesClassifier, \
  11.     KNeighborsTimeSeries
  12.  
  13. numpy.random.seed(0)
  14. n_ts_per_blob, sz, d, n_blobs = 20, 100, 1, 2
  15.  
  16. # Prepare data
  17. X, y = random_walk_blobs(n_ts_per_blob=n_ts_per_blob,
  18.                          sz=sz,
  19.                          d=d,
  20.                          n_blobs=n_blobs)
  21. scaler = TimeSeriesScalerMinMax(value_range=(0., 0.0011))  # Rescale time series
  22. X_scaled = scaler.fit_transform(X)
  23.  
  24. indices_shuffle = numpy.random.permutation(n_ts_per_blob * n_blobs)
  25. X_shuffle = X_scaled[indices_shuffle]
  26. y_shuffle = y[indices_shuffle]
  27.  
  28. X_train = X_shuffle[:n_ts_per_blob * n_blobs // 2]
  29. X_test = X_shuffle[n_ts_per_blob * n_blobs // 2:]
  30. y_train = y_shuffle[:n_ts_per_blob * n_blobs // 2]
  31. y_test = y_shuffle[n_ts_per_blob * n_blobs // 2:]
  32.  
  33. # Nearest neighbor search
  34. knn = KNeighborsTimeSeries(n_neighbors=3, metric="dtw")
  35. knn.fit(X_train, y_train)
  36. dists, ind = knn.kneighbors(X_test)
  37. print("1. Nearest neighbour search")
  38. print("Computed nearest neighbor indices (wrt DTW)\n", ind)
  39. print("First nearest neighbor class:", y_test[ind[:, 0]])
  40.  
  41. # Nearest neighbor classification
  42. knn_clf = KNeighborsTimeSeriesClassifier(n_neighbors=3, metric="dtw")
  43. knn_clf.fit(X_train, y_train)
  44. predicted_labels = knn_clf.predict(X_test)
  45. print("\n2. Nearest neighbor classification using DTW")
  46. print("Correct classification rate:", accuracy_score(y_test, predicted_labels))
  47.  
  48. # Nearest neighbor classification with a different metric (Euclidean distance)
  49. knn_clf = KNeighborsTimeSeriesClassifier(n_neighbors=3, metric="euclidean")
  50. knn_clf.fit(X_train, y_train)
  51. predicted_labels = knn_clf.predict(X_test)
  52. print("\n3. Nearest neighbor classification using L2")
  53. print("Correct classification rate:", accuracy_score(y_test, predicted_labels))
  54.  
  55. # Nearest neighbor classification based on SAX representation
  56. metric_params = {'n_segments': 10, 'alphabet_size_avg': 5}
  57. knn_clf = KNeighborsTimeSeriesClassifier(n_neighbors=3, metric="sax",
  58.                                          metric_params=metric_params)
  59. knn_clf.fit(X_train, y_train)
  60. predicted_labels = knn_clf.predict(X_test)
  61. print("\n4. Nearest neighbor classification using SAX+MINDIST")
  62. print("Correct classification rate:", accuracy_score(y_test, predicted_labels))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement