- import numpy as np
- from astropy.io import fits
- from astropy.utils.data import download_file
- import matplotlib.pyplot as plt
- from keras.models import Model
- from sklearn import metrics
- from keras.layers import Input, Flatten, Dense, Dropout, Conv1D, MaxPooling1D
- np.random.seed(42)
- file_url = 'https://archive.stsci.edu/hlsps/hellouniverse/hellouniverse_stella_500.fits'
- hdu = fits.open(download_file(file_url, cache=True))
- train_data = hdu[1].data['train_data']
- train_labels = hdu[1].data['train_labels']
- test_data = hdu[2].data['test_data']
- test_labels = hdu[2].data['test_labels']
- val_data = hdu[3].data['val_data']
- val_labels = hdu[3].data['val_labels']
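- # Sanity check on the array shapes (assumption: each FITS column stores a 2D
- # array of shape (n_samples, n_cadences)); the Conv1D layers below expect a
- # trailing channel axis, so add one if the light curves are still 2D.
- print(train_data.shape, val_data.shape, test_data.shape)
- if train_data.ndim == 2:
-     train_data = train_data[..., np.newaxis]
-     val_data = val_data[..., np.newaxis]
-     test_data = test_data[..., np.newaxis]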
- example_ids = np.random.choice(len(train_labels), 16)
- example_lightcurves = [train_data[j] for j in example_ids]
- example_labels = [train_labels[j] for j in example_ids]
- fig = plt.figure(figsize=(10, 10))
- colors = {1: 'r', 0: 'k'}
- titles = {1: 'Flare', 0: 'Non-flare'}
- for i in range(len(example_ids)):
-     plt.subplot(4, 4, i + 1)
-     plt.plot(example_lightcurves[i], color=colors[example_labels[i]])
-     plt.title(titles[example_labels[i]])
-     plt.xlabel('Cadences')
- plt.tight_layout()
- plt.show()
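- # Quick look at the class balance before training (assumption: the labels are
- # 0/1 integers, with 1 marking a flare as in the `titles` dictionary above).
- print('Training class counts (non-flare, flare):', np.bincount(train_labels.astype(int)))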
- seed = 2
- np.random.seed(seed)
- filter1 = 16
- filter2 = 64
- dense = 32
- dropout = 0.1
- data_shape = np.shape(train_data)
- input_shape = (np.shape(train_data)[1], 1)
- x_in = Input(shape=input_shape)
- # Conv1D takes (filters, kernel_size): filter1 (=16) filters with kernel size 7
- c0 = Conv1D(filter1, 7, activation='relu', padding='same', input_shape=input_shape)(x_in)
- b0 = MaxPooling1D(pool_size=2)(c0)
- d0 = Dropout(dropout)(b0)
- # Second convolution: filter2 (=64) filters with kernel size 3
- c1 = Conv1D(filter2, 3, activation='relu', padding='same')(d0)
- b1 = MaxPooling1D(pool_size=2)(c1)
- d1 = Dropout(dropout)(b1)
- f = Flatten()(d1)
- z0 = Dense(dense, activation='relu')(f)
- d2 = Dropout(dropout)(z0)
- y_out = Dense(1, activation='sigmoid')(d2)
- cnn = Model(inputs=x_in, outputs=y_out)
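- # Optional: print a layer-by-layer summary to check output shapes and
- # parameter counts before compiling.
- cnn.summary()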
- optimizer = 'adam'
- fit_metrics = ['accuracy']
- loss = 'binary_crossentropy'
- cnn.compile(loss=loss, optimizer=optimizer, metrics=fit_metrics)
- nb_epoch = 20
- batch_size = 64
- shuffle = True
- history = cnn.fit(train_data, train_labels,
-                   batch_size=batch_size,
-                   epochs=nb_epoch,
-                   validation_data=(val_data, val_labels),
-                   shuffle=shuffle,
-                   verbose=True)
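- # Plot the training history returned by fit() to check for over- or
- # under-fitting; the key names assume a recent Keras version ('accuracy' /
- # 'val_accuracy'), while older releases used 'acc' / 'val_acc'.
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
- ax1.plot(history.history['loss'], label='train')
- ax1.plot(history.history['val_loss'], label='validation')
- ax1.set_xlabel('Epoch')
- ax1.set_ylabel('Binary cross-entropy loss')
- ax1.legend()
- ax2.plot(history.history['accuracy'], label='train')
- ax2.plot(history.history['val_accuracy'], label='validation')
- ax2.set_xlabel('Epoch')
- ax2.set_ylabel('Accuracy')
- ax2.legend()
- plt.tight_layout()
- plt.show()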
- cnn_file = 'flare_model.h5'
- cnn.save(cnn_file)
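- # The saved HDF5 file can be reloaded later without retraining; a quick
- # round-trip check using the standard Keras loader:
- from keras.models import load_model
- reloaded_cnn = load_model(cnn_file)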
- def plot_confusion_matrix(cnn, input_data, input_labels):
-     # Compute flare predictions for the input dataset
-     predictions = cnn.predict(input_data)
-     # Convert the sigmoid outputs to a binary classification
-     predictions = (predictions > 0.5).astype('int32')
-     # Compute the confusion matrix by comparing the true labels (input_labels) with the predictions
-     cm = metrics.confusion_matrix(input_labels, predictions, labels=[0, 1])
-     cm = cm.astype('float')
-     # Normalize the confusion matrix so that each true-label row sums to 1
-     cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]
-     # Plotting
-     fig = plt.figure()
-     ax = fig.add_subplot(111)
-     ax.matshow(cm_norm, cmap='binary_r')
-     plt.title('Confusion matrix', y=1.08)
-     # Tick labels must follow the labels=[0, 1] order used above: 0 = no flare, 1 = flare
-     ax.set_xticks([0, 1])
-     ax.set_xticklabels(['No Flare', 'Flare'])
-     ax.set_yticks([0, 1])
-     ax.set_yticklabels(['No Flare', 'Flare'])
-     plt.xlabel('Predicted')
-     plt.ylabel('True')
-     fmt = '.2f'
-     thresh = cm_norm.max() / 2.
-     for i in range(cm_norm.shape[0]):
-         for j in range(cm_norm.shape[1]):
-             ax.text(j, i, format(cm_norm[i, j], fmt),
-                     ha="center", va="center", color="white" if cm_norm[i, j] < thresh else "black")
-     plt.show()
- # Assuming `cnn` is your trained model and `test_data`, `test_labels` are your test dataset
- plot_confusion_matrix(cnn, test_data, test_labels)
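- # Beyond the confusion matrix, summarize test performance with a few standard
- # sklearn metrics (using the same 0.5 threshold as the plot above).
- test_predictions = (cnn.predict(test_data).ravel() > 0.5).astype('int32')
- print('Accuracy: ', metrics.accuracy_score(test_labels, test_predictions))
- print('Precision:', metrics.precision_score(test_labels, test_predictions))
- print('Recall:   ', metrics.recall_score(test_labels, test_predictions))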