Advertisement
max2201111

optuna relu tanh

Aug 21st, 2024
113
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 13.37 KB | Science | 0 0
  1. import numpy as np
  2. from tensorflow.keras.models import Sequential
  3. from tensorflow.keras.layers import Dense, Flatten, Input, Activation
  4. from tensorflow.keras.layers import LeakyReLU
  5. from sklearn.preprocessing import StandardScaler
  6. from sklearn.metrics import confusion_matrix
  7. from sklearn.utils.class_weight import compute_class_weight
  8. import matplotlib.pyplot as plt
  9. import seaborn as sns
  10. import optuna
  11. from sklearn.preprocessing import PolynomialFeatures
  12. from sklearn.linear_model import LinearRegression
  13.  
  14. # Trénovací data (2D)
  15. # Trénovací data (2D)
  16. train_data = np.array([
  17.     [1141, 1050, 1], [1499, 1050, 0], [1451, 1077, 1], [1519, 1077, 1], [1191, 1093, 0],
  18.     [1590, 1093, 1], [1777, 1093, 0], [1141, 1124, 1], [1499, 1124, 0], [1601, 1124, 0],
  19.     [1606, 1141, 0], [1608, 1141, 0], [1870, 1141, 0], [1794, 1191, 0], [1329, 1211, 1],
  20.     [1687, 1211, 0], [1918, 1211, 0], [1608, 1212, 0], [1050, 1228, 0], [2107, 1228, 0],
  21.     [2202, 1266, 0], [1551, 1327, 0], [1608, 1327, 1], [1660, 1327, 0], [1329, 1332, 1],
  22.     [1093, 1346, 0], [1546, 1419, 1], [1327, 1435, 0], [1774, 1435, 0], [1794, 1451, 0],
  23.     [1077, 1458, 1], [1093, 1458, 1], [1731, 1458, 0], [1777, 1491, 0], [1212, 1499, 1],
  24.     [1211, 1519, 1], [1608, 1519, 1], [1918, 1519, 0], [1458, 1538, 1], [1918, 1538, 0],
  25.     [1794, 1545, 0], [1903, 1545, 0], [1435, 1546, 1], [1758, 1546, 0], [2076, 1546, 0],
  26.     [1077, 1551, 1], [1690, 1551, 0], [1050, 1590, 1], [1093, 1601, 1], [1327, 1601, 0],
  27.     [1050, 1606, 1], [1491, 1606, 1], [1519, 1608, 0], [1266, 1660, 1], [1839, 1660, 0],
  28.     [1332, 1687, 0], [1519, 1687, 0], [1538, 1690, 1], [1870, 1690, 0], [1903, 1731, 1],
  29.     [1918, 1731, 0], [1419, 1758, 0], [1839, 1758, 0], [1077, 1774, 1], [1519, 1774, 1],
  30.     [2202, 1774, 0], [1538, 1777, 0], [1903, 1777, 1], [2107, 1777, 0], [1660, 1794, 0],
  31.     [2107, 1794, 0], [1124, 1839, 1], [1519, 1839, 1], [1546, 1839, 1], [1870, 1839, 1],
  32.     [2202, 1839, 0], [1419, 1870, 1], [2107, 1870, 0], [2202, 1870, 0], [1191, 1903, 1],
  33.     [1601, 1903, 1], [1606, 1903, 1], [1660, 1903, 1], [1491, 1918, 1], [1212, 2076, 1],
  34.     [1690, 2076, 1], [1546, 2107, 1], [1903, 2107, 1], [2183, 2107, 0], [1229, 2183, 1],
  35.     [1731, 2183, 1], [1758, 2183, 0], [1918, 2183, 1], [2076, 2183, 0], [1538, 2202, 1],
  36.     [1601, 2202, 1], [2076, 2202, 0], [1660, 2258, 1], [1777, 2258, 0], [2202, 2258, 0],
  37.     [1477, 1141, 0], [1519, 1141, 0], [1519, 1211, 1], [1549, 1211, 0], [1731, 1211, 0],
  38.     [1848, 1211, 0], [1458, 1229, 1], [1528, 1229, 1], [1354, 1327, 0], [1448, 1327, 0],
  39.     [1458, 1327, 1], [1499, 1327, 0], [1758, 1327, 0], [1405, 1332, 0], [1519, 1332, 0],
  40.     [1549, 1332, 0], [1605, 1332, 0], [1606, 1332, 0], [1606, 1332, 1], [1794, 1332, 1],
  41.     [1376, 1346, 0], [1608, 1346, 0], [1551, 1354, 1], [1731, 1354, 0], [1849, 1354, 0],
  42.     [1918, 1354, 0], [1332, 1376, 0], [1690, 1376, 0], [1606, 1405, 1], [1608, 1405, 0],
  43.     [1687, 1405, 0], [1774, 1448, 0], [1870, 1448, 0], [1477, 1450, 0], [1551, 1450, 0],
  44.     [1731, 1450, 1], [2076, 1450, 0], [1687, 1458, 0], [1448, 1477, 0], [1794, 1477, 0],
  45.     [2183, 1477, 0], [1211, 1499, 1], [1774, 1499, 0], [1782, 1499, 0], [1848, 1499, 0],
  46.     [1211, 1519, 0], [1229, 1519, 1], [1332, 1519, 0], [2258, 1519, 0], [1332, 1528, 0],
  47.     [1327, 1546, 0], [1332, 1546, 1], [1606, 1546, 0], [1918, 1546, 1], [1499, 1549, 0],
  48.     [2258, 1549, 0], [1327, 1551, 0], [1870, 1551, 0], [1211, 1590, 1], [1229, 1590, 1],
  49.     [1601, 1590, 1], [1211, 1601, 0], [1332, 1601, 1], [1354, 1601, 1], [1405, 1601, 0],
  50.     [1229, 1605, 0], [1332, 1605, 1], [1448, 1605, 0], [1354, 1606, 0], [1927, 1606, 0],
  51.     [1450, 1608, 0], [1519, 1608, 1], [1849, 1608, 0], [1477, 1687, 1], [1848, 1687, 0],
  52.     [1141, 1690, 1], [1327, 1690, 1], [1549, 1731, 1], [1590, 1731, 0], [2250, 1731, 0],
  53.     [1332, 1747, 1], [1927, 1747, 0], [1450, 1758, 1], [2076, 1758, 0], [1747, 1774, 0],
  54.     [1794, 1774, 0], [1332, 1782, 1], [1687, 1782, 0], [1747, 1782, 0], [1758, 1782, 1],
  55.     [2076, 1782, 1], [1458, 1794, 1], [1590, 1794, 1], [2153, 1794, 0], [1747, 1848, 0],
  56.     [1758, 1848, 0], [1870, 1848, 0], [2183, 1848, 0], [1332, 1849, 1], [1731, 1849, 1],
  57.     [1870, 1849, 1], [2183, 1849, 0], [2250, 1849, 0], [1758, 1870, 0], [1918, 1870, 1],
  58.     [1549, 1918, 0], [1448, 1927, 1], [1758, 1927, 1], [2183, 1927, 1], [2250, 1927, 0],
  59.     [2258, 1927, 0], [1601, 2076, 1], [1605, 2076, 1], [1849, 2076, 1], [2258, 2076, 0],
  60.     [1458, 2153, 1], [1927, 2153, 0], [2076, 2153, 0], [1690, 2183, 1], [2076, 2183, 1],
  61.     [2153, 2183, 0], [1499, 2250, 1], [1690, 2250, 1], [1747, 2250, 0], [1918, 2250, 1],
  62.     [2183, 2250, 0], [1747, 2258, 1]
  63. ])
  64.  
  65. # Testovací data (2D)
  66. test_data = np.array([
  67.     [1451, 1050, 0], [1758, 1050, 0], [1346, 1211, 1], [1546, 1332, 1], [1608, 1451, 1],
  68.     [1839, 1458, 0], [1435, 1538, 1], [1077, 1546, 1], [2183, 1551, 0], [1458, 1590, 0],
  69.     [1538, 1606, 0], [1077, 1608, 1], [2258, 1608, 0], [1419, 1690, 1], [1545, 1731, 1],
  70.     [1774, 1758, 0], [1545, 1774, 1], [2183, 1777, 0], [1228, 1794, 1], [1774, 1794, 0],
  71.     [2258, 1870, 1], [1546, 1903, 1], [1774, 1918, 0], [2076, 1918, 0], [1758, 2076, 1],
  72.     [1839, 2076, 0], [2107, 2076, 1], [2258, 2107, 1], [1731, 2202, 1], [1327, 2258, 1],
  73.     [1354, 1229, 0], [1774, 1229, 0], [1546, 1376, 0], [1918, 1405, 0], [1605, 1450, 0],
  74.     [1605, 1477, 0], [1448, 1519, 0], [1450, 1519, 1], [1141, 1528, 0], [1346, 1551, 1],
  75.     [1608, 1606, 0], [1376, 1608, 1], [2153, 1687, 0], [1458, 1690, 0], [1590, 1690, 1],
  76.     [1774, 1731, 0], [1229, 1747, 1], [2250, 1758, 0], [1346, 1848, 1], [1376, 1870, 1],
  77.     [2258, 1870, 0], [1590, 1918, 1], [1849, 2153, 0], [1782, 2183, 1], [2153, 2258, 0]
  78. ])
  79.  
  80. # Rozdělení na vstupy (X) a výstupy (y)
  81. X_train = train_data[:, 0:2]
  82. y_train = train_data[:, 2]
  83. X_test = test_data[:, 0:2]
  84. y_test = test_data[:, 2]
  85.  
  86. # Normalizace dat pomocí Z-score
  87. scaler = StandardScaler()
  88. X_train = scaler.fit_transform(X_train)
  89. X_test = scaler.transform(X_test)
  90.  
  91. # Výpočet váhy tříd
  92. class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
  93. class_weight_dict = dict(enumerate(class_weights))
  94.  
  95. # Funkce pro vytvoření modelu s různými aktivačními funkcemi
  96. def create_model(activation_fns, model_type='dnn'):
  97.     if model_type == 'dnn':
  98.         model = Sequential([
  99.             Input(shape=(2,)),  # Vstupní vrstva pro 2D data
  100.             Dense(128),  # První skrytá vrstva
  101.             Activation(activation_fns[0]),  # Aktivace první vrstvy
  102.             Dense(64),
  103.             Activation(activation_fns[1]),  # Aktivace druhé vrstvy
  104.             Dense(64),
  105.             Activation(activation_fns[2]),  # Aktivace třetí vrstvy
  106.             Dense(64),
  107.             Activation(activation_fns[3]),  # Aktivace čtvrté vrstvy
  108.             Dense(64),
  109.             Activation(activation_fns[4]),  # Aktivace páté vrstvy
  110.             Dense(32),
  111.             Activation(activation_fns[5]),  # Aktivace šesté vrstvy
  112.             Dense(32),
  113.             Activation(activation_fns[6]),  # Aktivace sedmé vrstvy
  114.             Dense(32),
  115.             Activation(activation_fns[7]),  # Aktivace osmé vrstvy
  116.             Dense(32),
  117.             Activation(activation_fns[8]),  # Aktivace deváté vrstvy
  118.             Dense(16),
  119.             Activation(activation_fns[9]),  # Aktivace desáté vrstvy
  120.             Dense(16),
  121.             Activation(activation_fns[10]),  # Aktivace jedenácté vrstvy
  122.             Dense(16),
  123.             Activation(activation_fns[11]),  # Aktivace dvanácté vrstvy
  124.             Dense(8),
  125.             Activation(activation_fns[12]),  # Aktivace třinácté vrstvy
  126.             Dense(1, activation='sigmoid')  # Výstupní vrstva
  127.         ])
  128.     elif model_type == 'cnn':
  129.         model = Sequential([
  130.             Input(shape=(2, 1)),  # Vstupní vrstva pro 2D data
  131.             Flatten(),  # Plochý vstup
  132.             Dense(128),
  133.             Activation(activation_fns[0]),  # Aktivace první vrstvy
  134.             Dense(64),
  135.             Activation(activation_fns[1]),  # Aktivace druhé vrstvy
  136.             Dense(64),
  137.             Activation(activation_fns[2]),  # Aktivace třetí vrstvy
  138.             Dense(64),
  139.             Activation(activation_fns[3]),  # Aktivace čtvrté vrstvy
  140.             Dense(64),
  141.             Activation(activation_fns[4]),  # Aktivace páté vrstvy
  142.             Dense(32),
  143.             Activation(activation_fns[5]),  # Aktivace šesté vrstvy
  144.             Dense(32),
  145.             Activation(activation_fns[6]),  # Aktivace sedmé vrstvy
  146.             Dense(32),
  147.             Activation(activation_fns[7]),  # Aktivace osmé vrstvy
  148.             Dense(32),
  149.             Activation(activation_fns[8]),  # Aktivace deváté vrstvy
  150.             Dense(16),
  151.             Activation(activation_fns[9]),  # Aktivace desáté vrstvy
  152.             Dense(16),
  153.             Activation(activation_fns[10]),  # Aktivace jedenácté vrstvy
  154.             Dense(16),
  155.             Activation(activation_fns[11]),  # Aktivace dvanácté vrstvy
  156.             Dense(8),
  157.             Activation(activation_fns[12]),  # Aktivace třinácté vrstvy
  158.             Dense(1, activation='sigmoid')  # Výstupní vrstva
  159.         ])
  160.     return model
  161.  
  162. # Funkce pro optimalizaci pomocí Optuna
  163. def objective(trial, model_type='dnn'):
  164.     # Náhodný výběr aktivačních funkcí pro každou vrstvu
  165.     activation_fns = []
  166.     for i in range(13):
  167.         activation_fn = trial.suggest_categorical(f'activation_fn_{i}', ['sigmoid', 'tanh', 'relu', 'leaky_relu'])
  168.         activation_fns.append(activation_fn)
  169.  
  170.     # Vytvoření modelu s náhodně vybranými aktivačními funkcemi
  171.     model = create_model(activation_fns, model_type)
  172.     model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  173.     model.fit(X_train, y_train, epochs=200, batch_size=10, verbose=0, class_weight=class_weight_dict)
  174.    
  175.     # Predikce pravděpodobností pro testovací sadu
  176.     if model_type == 'cnn':
  177.         X_test_input = X_test.reshape(-1, 2, 1)
  178.     else:
  179.         X_test_input = X_test
  180.    
  181.     y_probs = model.predict(X_test_input)
  182.    
  183.     # Optimalizace threshold pro minimalizaci FP + FN
  184.     threshold = trial.suggest_float('threshold', 0.0, 1.0)
  185.     y_pred_optimal = (y_probs > threshold).astype("int32")
  186.    
  187.     # Výpočet confusion matrix
  188.     conf_matrix = confusion_matrix(y_test, y_pred_optimal)
  189.     FP = conf_matrix[0, 1]
  190.     FN = conf_matrix[1, 0]
  191.  
  192.     # Cílem je minimalizovat FP + FN
  193.     return FP + FN
  194.  
  195. # Optimalizace pro DNN model
  196. study_dnn = optuna.create_study(direction='minimize')
  197. study_dnn.optimize(lambda trial: objective(trial, model_type='dnn'), n_trials=50)
  198.  
  199. # Optimalizace pro CNN model
  200. study_cnn = optuna.create_study(direction='minimize')
  201. study_cnn.optimize(lambda trial: objective(trial, model_type='cnn'), n_trials=50)
  202.  
  203. # Výsledky optimalizace pro DNN model
  204. optimal_threshold_dnn = study_dnn.best_params["threshold"]
  205. best_activation_fns_dnn = [study_dnn.best_params[f'activation_fn_{i}'] for i in range(13)]
  206. print(f'DNN - Optimální threshold: {optimal_threshold_dnn}')
  207. print(f'DNN - Nejlepší kombinace aktivačních funkcí: {best_activation_fns_dnn}')
  208.  
  209. # Výsledky optimalizace pro CNN model
  210. optimal_threshold_cnn = study_cnn.best_params["threshold"]
  211. best_activation_fns_cnn = [study_cnn.best_params[f'activation_fn_{i}'] for i in range(13)]
  212. print(f'CNN - Optimální threshold: {optimal_threshold_cnn}')
  213. print(f'CNN - Nejlepší kombinace aktivačních funkcí: {best_activation_fns_cnn}')
  214.  
  215. # Vytvoření nejlepšího modelu s optimalizovanými aktivačními funkcemi pro DNN
  216. best_dnn_model = create_model(best_activation_fns_dnn, model_type='dnn')
  217. best_dnn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  218. best_dnn_model.fit(X_train, y_train, epochs=200, batch_size=10, verbose=0, class_weight=class_weight_dict)
  219.  
  220. # Predikce s nejlepšími aktivačními funkcemi a optimálním threshold pro DNN
  221. y_probs_dnn_best = best_dnn_model.predict(X_test)
  222. y_pred_dnn_best = (y_probs_dnn_best > optimal_threshold_dnn).astype("int32")
  223.  
  224. # Vyhodnocení výsledného DNN modelu pomocí confusion matrix
  225. conf_matrix_dnn_best = confusion_matrix(y_test, y_pred_dnn_best)
  226. plt.figure(figsize=(10, 7))
  227. sns.heatmap(conf_matrix_dnn_best, annot=True, fmt='d', cmap='Blues')
  228. plt.title(f'Konfuzní Matice (DNN) - Testovací data, Optimalizovaný Threshold: {optimal_threshold_dnn:.2f}')
  229. plt.ylabel('Skutečný Štítek')
  230. plt.xlabel('Predikovaný Štítek')
  231. plt.show()
  232.  
  233. # Vytvoření nejlepšího modelu s optimalizovanými aktivačními funkcemi pro CNN
  234. best_cnn_model = create_model(best_activation_fns_cnn, model_type='cnn')
  235. best_cnn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  236. best_cnn_model.fit(X_train.reshape(-1, 2, 1), y_train, epochs=200, batch_size=10, verbose=0, class_weight=class_weight_dict)
  237.  
  238. # Predikce s nejlepšími aktivačními funkcemi a optimálním threshold pro CNN
  239. y_probs_cnn_best = best_cnn_model.predict(X_test.reshape(-1, 2, 1))
  240. y_pred_cnn_best = (y_probs_cnn_best > optimal_threshold_cnn).astype("int32")
  241.  
  242. # Vyhodnocení výsledného CNN modelu pomocí confusion matrix
  243. conf_matrix_cnn_best = confusion_matrix(y_test, y_pred_cnn_best)
  244. plt.figure(figsize=(10, 7))
  245. sns.heatmap(conf_matrix_cnn_best, annot=True, fmt='d', cmap='Blues')
  246. plt.title(f'Konfuzní Matice (CNN) - Testovací data, Optimalizovaný Threshold: {optimal_threshold_cnn:.2f}')
  247. plt.ylabel('Skutečný Štítek')
  248. plt.xlabel('Predikovaný Štítek')
  249. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement