Advertisement
UF6

Stellar Learning

UF6
Apr 3rd, 2024
807
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.93 KB | Source Code | 0 0
  1. import numpy as np
  2. from astropy.io import fits
  3. from astropy.utils.data import download_file
  4. import matplotlib.pyplot as plt
  5. from keras.models import Model
  6. from sklearn import metrics
  7. from keras.layers import Input, Flatten, Dense, Dropout, Conv1D, MaxPooling1D
  8.  
  9. np.random.seed(42)
  10.  
  11. file_url = 'https://archive.stsci.edu/hlsps/hellouniverse/hellouniverse_stella_500.fits'
  12. hdu = fits.open(download_file(file_url, cache=True))
  13.  
  14. train_data = hdu[1].data['train_data']
  15. train_labels = hdu[1].data['train_labels']
  16.  
  17. test_data = hdu[2].data['test_data  bhu8']
  18. test_labels = hdu[2].data['test_labels']
  19.  
  20. val_data = hdu[3].data['val_data']
  21. val_labels = hdu[3].data['val_labels']
  22.  
  23. example_ids = np.random.choice(len(train_labels), 16)
  24. example_lightcurves = [train_data[j] for j in example_ids]
  25. example_labels = [train_labels[j] for j in example_ids]
  26.  
  27. fig = plt.figure(figsize=(10, 10))
  28. colors = {1: 'r', 0: 'k'}
  29. titles = {1: 'Flare', 0: 'Non-flare'}
  30. for i in range(len(example_ids)):
  31.     plt.subplot(4, 4, i + 1)
  32.     plt.plot(example_lightcurves[i], color=colors[example_labels[i]])
  33.     plt.title(titles[example_labels[i]])
  34.     plt.xlabel('Cadences')
  35.  
  36. plt.tight_layout()
  37. plt.show()
  38.  
  39. seed = 2
  40. np.random.seed(seed)
  41.  
  42. filter1 = 16
  43. filter2 = 64
  44. dense = 32
  45. dropout = 0.1
  46.  
  47. data_shape = np.shape(train_data)
  48. input_shape = (np.shape(train_data)[1], 1)
  49.  
  50. x_in = Input(shape=input_shape)
  51. c0 = Conv1D(7, filter1, activation='relu', padding='same', input_shape=input_shape)(x_in)
  52. b0 = MaxPooling1D(pool_size=2)(c0)
  53. d0 = Dropout(dropout)(b0)
  54.  
  55. c1 = Conv1D(3, filter2, activation='relu', padding='same')(d0)
  56. b1 = MaxPooling1D(pool_size=2)(c1)
  57. d1 = Dropout(dropout)(b1)
  58.  
  59. f = Flatten()(d1)
  60. z0 = Dense(dense, activation='relu')(f)
  61. d2 = Dropout(dropout)(z0)
  62. y_out = Dense(1, activation='sigmoid')(d2)
  63.  
  64. cnn = Model(inputs=x_in, outputs=y_out)
  65.  
  66. optimizer = 'adam'
  67. fit_metrics = ['accuracy']
  68. loss = 'binary_crossentropy'
  69. cnn.compile(loss=loss, optimizer=optimizer, metrics=fit_metrics)
  70.  
  71. nb_epoch = 20
  72. batch_size = 64
  73. shuffle = True
  74.  
  75. history = cnn.fit(train_data, train_labels,
  76.                   batch_size=batch_size,
  77.                   epochs=nb_epoch,
  78.                   validation_data=(val_data, val_labels),
  79.                   shuffle=shuffle,
  80.                   verbose=True)
  81.  
  82. cnn_file = 'flare_model.h5'
  83. cnn.save(cnn_file)
  84.  
  85. def plot_confusion_matrix(cnn, input_data, input_labels):
  86.    
  87.     # Compute flare predictions for the test dataset
  88.     predictions = cnn.predict(input_data)
  89.  
  90.     # Convert to binary classification
  91.     predictions = (predictions > 0.5).astype('int32')
  92.    
  93.     # Compute the confusion matrix by comparing the test labels (ds.test_labels) with the test predictions
  94.     cm = metrics.confusion_matrix(input_labels, predictions, labels=[0, 1])
  95.     cm = cm.astype('float')
  96.  
  97.     # Normalize the confusion matrix results.
  98.     cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]
  99.    
  100.     # Plotting
  101.     fig = plt.figure()
  102.     ax = fig.add_subplot(111)
  103.  
  104.     ax.matshow(cm_norm, cmap='binary_r')
  105.  
  106.     plt.title('Confusion matrix', y=1.08)
  107.    
  108.     ax.set_xticks([0, 1])
  109.     ax.set_xticklabels(['Flare', 'No Flare'])
  110.    
  111.     ax.set_yticks([0, 1])
  112.     ax.set_yticklabels(['Flare', 'No Flare'])
  113.  
  114.     plt.xlabel('Predicted')
  115.     plt.ylabel('True')
  116.  
  117.     fmt = '.2f'
  118.     thresh = cm_norm.max() / 2.
  119.     for i in range(cm_norm.shape[0]):
  120.         for j in range(cm_norm.shape[1]):
  121.             ax.text(j, i, format(cm_norm[i, j], fmt),
  122.                     ha="center", va="center", color="white" if cm_norm[i, j] < thresh else "black")
  123.     plt.show()
  124.    
  125. def plot_confusion_matrix(cnn, input_data, input_labels):
  126.    
  127.     # Compute flare predictions for the test dataset
  128.     predictions = cnn.predict(input_data)
  129.  
  130.     # Convert to binary classification
  131.     predictions = (predictions > 0.5).astype('int32')
  132.    
  133.     # Compute the confusion matrix by comparing the test labels with the test predictions
  134.     cm = metrics.confusion_matrix(input_labels, predictions, labels=[0, 1])
  135.     cm = cm.astype('float')
  136.  
  137.     # Normalize the confusion matrix results.
  138.     cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]
  139.    
  140.     # Plotting
  141.     fig = plt.figure()
  142.     ax = fig.add_subplot(111)
  143.  
  144.     ax.matshow(cm_norm, cmap='binary_r')
  145.  
  146.     plt.title('Confusion matrix', y=1.08)
  147.    
  148.     ax.set_xticks([0, 1])
  149.     ax.set_xticklabels(['Flare', 'No Flare'])
  150.    
  151.     ax.set_yticks([0, 1])
  152.     ax.set_yticklabels(['Flare', 'No Flare'])
  153.  
  154.     plt.xlabel('Predicted')
  155.     plt.ylabel('True')
  156.  
  157.     fmt = '.2f'
  158.     thresh = cm_norm.max() / 2.
  159.     for i in range(cm_norm.shape[0]):
  160.         for j in range(cm_norm.shape[1]):
  161.             ax.text(j, i, format(cm_norm[i, j], fmt),
  162.                     ha="center", va="center", color="white" if cm_norm[i, j] < thresh else "black")
  163.     plt.show()
  164.  
  165. # Assuming `cnn` is your trained model and `test_data`, `test_labels` are your test dataset
  166. plot_confusion_matrix(cnn, test_data, test_labels)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement