Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
"""Nearest-neighbour search and classification on synthetic time series.

The script builds a small random-walk dataset, rescales it, splits it in
half (train / test) and then runs:

1. a plain 3-nearest-neighbour search under DTW,
2. 3-NN classification under DTW,
3. 3-NN classification under the Euclidean distance,
4. 3-NN classification on a SAX representation.
"""
# Author: Romain Tavenard
# License: BSD 3 clause
import numpy
from sklearn.metrics import accuracy_score
from tslearn.generators import random_walk_blobs
from tslearn.preprocessing import TimeSeriesScalerMinMax, \
    TimeSeriesScalerMeanVariance
from tslearn.neighbors import KNeighborsTimeSeriesClassifier, \
    KNeighborsTimeSeries

# Fixed seed so that data generation and the shuffle are reproducible.
numpy.random.seed(0)
n_ts_per_blob, sz, d, n_blobs = 20, 100, 1, 2

# Prepare data
X, y = random_walk_blobs(n_ts_per_blob=n_ts_per_blob,
                         sz=sz,
                         d=d,
                         n_blobs=n_blobs)
# Rescale time series.
# NOTE(review): the upper bound 0.0011 is unusually small; it is kept as
# found, but confirm it is intentional (in particular for the SAX-based
# classifier below, whose behaviour may depend on the value scale).
scaler = TimeSeriesScalerMinMax(value_range=(0., 0.0011))
X_scaled = scaler.fit_transform(X)

# Shuffle series and labels with the same permutation, then split 50/50.
indices_shuffle = numpy.random.permutation(n_ts_per_blob * n_blobs)
X_shuffle = X_scaled[indices_shuffle]
y_shuffle = y[indices_shuffle]

X_train = X_shuffle[:n_ts_per_blob * n_blobs // 2]
X_test = X_shuffle[n_ts_per_blob * n_blobs // 2:]
y_train = y_shuffle[:n_ts_per_blob * n_blobs // 2]
y_test = y_shuffle[n_ts_per_blob * n_blobs // 2:]

# Nearest neighbor search
knn = KNeighborsTimeSeries(n_neighbors=3, metric="dtw")
knn.fit(X_train, y_train)
dists, ind = knn.kneighbors(X_test)
print("1. Nearest neighbour search")
print("Computed nearest neighbor indices (wrt DTW)\n", ind)
# `ind` indexes into the fitted (training) set, so the neighbours' labels
# must be looked up in y_train, not y_test.
print("First nearest neighbor class:", y_train[ind[:, 0]])

# Nearest neighbor classification
knn_clf = KNeighborsTimeSeriesClassifier(n_neighbors=3, metric="dtw")
knn_clf.fit(X_train, y_train)
predicted_labels = knn_clf.predict(X_test)
print("\n2. Nearest neighbor classification using DTW")
print("Correct classification rate:", accuracy_score(y_test, predicted_labels))

# Nearest neighbor classification with a different metric (Euclidean distance)
knn_clf = KNeighborsTimeSeriesClassifier(n_neighbors=3, metric="euclidean")
knn_clf.fit(X_train, y_train)
predicted_labels = knn_clf.predict(X_test)
print("\n3. Nearest neighbor classification using L2")
print("Correct classification rate:", accuracy_score(y_test, predicted_labels))

# Nearest neighbor classification based on SAX representation
metric_params = {'n_segments': 10, 'alphabet_size_avg': 5}
knn_clf = KNeighborsTimeSeriesClassifier(n_neighbors=3, metric="sax",
                                         metric_params=metric_params)
knn_clf.fit(X_train, y_train)
predicted_labels = knn_clf.predict(X_test)
print("\n4. Nearest neighbor classification using SAX+MINDIST")
print("Correct classification rate:", accuracy_score(y_test, predicted_labels))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement