Advertisement
max2201111

zelene zlute OK krivka

Aug 21st, 2024
111
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 16.39 KB | Science | 0 0
  1. import numpy as np
  2. from tensorflow.keras.models import Sequential
  3. from tensorflow.keras.layers import Dense, Flatten, Input
  4. from tensorflow.keras.layers import LeakyReLU
  5. from sklearn.preprocessing import StandardScaler
  6. from sklearn.metrics import confusion_matrix
  7. from sklearn.utils.class_weight import compute_class_weight
  8. import matplotlib.pyplot as plt
  9. import seaborn as sns
  10. import optuna
  11. from sklearn.preprocessing import PolynomialFeatures
  12. from sklearn.linear_model import LinearRegression
  13.  
  14. # Trénovací data (2D)
  15. train_data = np.array([
  16.     [1141, 1050, 1], [1499, 1050, 0], [1451, 1077, 1], [1519, 1077, 1], [1191, 1093, 0],
  17.     [1590, 1093, 1], [1777, 1093, 0], [1141, 1124, 1], [1499, 1124, 0], [1601, 1124, 0],
  18.     [1606, 1141, 0], [1608, 1141, 0], [1870, 1141, 0], [1794, 1191, 0], [1329, 1211, 1],
  19.     [1687, 1211, 0], [1918, 1211, 0], [1608, 1212, 0], [1050, 1228, 0], [2107, 1228, 0],
  20.     [2202, 1266, 0], [1551, 1327, 0], [1608, 1327, 1], [1660, 1327, 0], [1329, 1332, 1],
  21.     [1093, 1346, 0], [1546, 1419, 1], [1327, 1435, 0], [1774, 1435, 0], [1794, 1451, 0],
  22.     [1077, 1458, 1], [1093, 1458, 1], [1731, 1458, 0], [1777, 1491, 0], [1212, 1499, 1],
  23.     [1211, 1519, 1], [1608, 1519, 1], [1918, 1519, 0], [1458, 1538, 1], [1918, 1538, 0],
  24.     [1794, 1545, 0], [1903, 1545, 0], [1435, 1546, 1], [1758, 1546, 0], [2076, 1546, 0],
  25.     [1077, 1551, 1], [1690, 1551, 0], [1050, 1590, 1], [1093, 1601, 1], [1327, 1601, 0],
  26.     [1050, 1606, 1], [1491, 1606, 1], [1519, 1608, 0], [1266, 1660, 1], [1839, 1660, 0],
  27.     [1332, 1687, 0], [1519, 1687, 0], [1538, 1690, 1], [1870, 1690, 0], [1903, 1731, 1],
  28.     [1918, 1731, 0], [1419, 1758, 0], [1839, 1758, 0], [1077, 1774, 1], [1519, 1774, 1],
  29.     [2202, 1774, 0], [1538, 1777, 0], [1903, 1777, 1], [2107, 1777, 0], [1660, 1794, 0],
  30.     [2107, 1794, 0], [1124, 1839, 1], [1519, 1839, 1], [1546, 1839, 1], [1870, 1839, 1],
  31.     [2202, 1839, 0], [1419, 1870, 1], [2107, 1870, 0], [2202, 1870, 0], [1191, 1903, 1],
  32.     [1601, 1903, 1], [1606, 1903, 1], [1660, 1903, 1], [1491, 1918, 1], [1212, 2076, 1],
  33.     [1690, 2076, 1], [1546, 2107, 1], [1903, 2107, 1], [2183, 2107, 0], [1229, 2183, 1],
  34.     [1731, 2183, 1], [1758, 2183, 0], [1918, 2183, 1], [2076, 2183, 0], [1538, 2202, 1],
  35.     [1601, 2202, 1], [2076, 2202, 0], [1660, 2258, 1], [1777, 2258, 0], [2202, 2258, 0],
  36.     [1477, 1141, 0], [1519, 1141, 0], [1519, 1211, 1], [1549, 1211, 0], [1731, 1211, 0],
  37.     [1848, 1211, 0], [1458, 1229, 1], [1528, 1229, 1], [1354, 1327, 0], [1448, 1327, 0],
  38.     [1458, 1327, 1], [1499, 1327, 0], [1758, 1327, 0], [1405, 1332, 0], [1519, 1332, 0],
  39.     [1549, 1332, 0], [1605, 1332, 0], [1606, 1332, 0], [1606, 1332, 1], [1794, 1332, 1],
  40.     [1376, 1346, 0], [1608, 1346, 0], [1551, 1354, 1], [1731, 1354, 0], [1849, 1354, 0],
  41.     [1918, 1354, 0], [1332, 1376, 0], [1690, 1376, 0], [1606, 1405, 1], [1608, 1405, 0],
  42.     [1687, 1405, 0], [1774, 1448, 0], [1870, 1448, 0], [1477, 1450, 0], [1551, 1450, 0],
  43.     [1731, 1450, 1], [2076, 1450, 0], [1687, 1458, 0], [1448, 1477, 0], [1794, 1477, 0],
  44.     [2183, 1477, 0], [1211, 1499, 1], [1774, 1499, 0], [1782, 1499, 0], [1848, 1499, 0],
  45.     [1211, 1519, 0], [1229, 1519, 1], [1332, 1519, 0], [2258, 1519, 0], [1332, 1528, 0],
  46.     [1327, 1546, 0], [1332, 1546, 1], [1606, 1546, 0], [1918, 1546, 1], [1499, 1549, 0],
  47.     [2258, 1549, 0], [1327, 1551, 0], [1870, 1551, 0], [1211, 1590, 1], [1229, 1590, 1],
  48.     [1601, 1590, 1], [1211, 1601, 0], [1332, 1601, 1], [1354, 1601, 1], [1405, 1601, 0],
  49.     [1229, 1605, 0], [1332, 1605, 1], [1448, 1605, 0], [1354, 1606, 0], [1927, 1606, 0],
  50.     [1450, 1608, 0], [1519, 1608, 1], [1849, 1608, 0], [1477, 1687, 1], [1848, 1687, 0],
  51.     [1141, 1690, 1], [1327, 1690, 1], [1549, 1731, 1], [1590, 1731, 0], [2250, 1731, 0],
  52.     [1332, 1747, 1], [1927, 1747, 0], [1450, 1758, 1], [2076, 1758, 0], [1747, 1774, 0],
  53.     [1794, 1774, 0], [1332, 1782, 1], [1687, 1782, 0], [1747, 1782, 0], [1758, 1782, 1],
  54.     [2076, 1782, 1], [1458, 1794, 1], [1590, 1794, 1], [2153, 1794, 0], [1747, 1848, 0],
  55.     [1758, 1848, 0], [1870, 1848, 0], [2183, 1848, 0], [1332, 1849, 1], [1731, 1849, 1],
  56.     [1870, 1849, 1], [2183, 1849, 0], [2250, 1849, 0], [1758, 1870, 0], [1918, 1870, 1],
  57.     [1549, 1918, 0], [1448, 1927, 1], [1758, 1927, 1], [2183, 1927, 1], [2250, 1927, 0],
  58.     [2258, 1927, 0], [1601, 2076, 1], [1605, 2076, 1], [1849, 2076, 1], [2258, 2076, 0],
  59.     [1458, 2153, 1], [1927, 2153, 0], [2076, 2153, 0], [1690, 2183, 1], [2076, 2183, 1],
  60.     [2153, 2183, 0], [1499, 2250, 1], [1690, 2250, 1], [1747, 2250, 0], [1918, 2250, 1],
  61.     [2183, 2250, 0], [1747, 2258, 1]
  62. ])
  63.  
  64. # Testovací data (2D)
  65. test_data = np.array([
  66.     [1451, 1050, 0], [1758, 1050, 0], [1346, 1211, 1], [1546, 1332, 1], [1608, 1451, 1],
  67.     [1839, 1458, 0], [1435, 1538, 1], [1077, 1546, 1], [2183, 1551, 0], [1458, 1590, 0],
  68.     [1538, 1606, 0], [1077, 1608, 1], [2258, 1608, 0], [1419, 1690, 1], [1545, 1731, 1],
  69.     [1774, 1758, 0], [1545, 1774, 1], [2183, 1777, 0], [1228, 1794, 1], [1774, 1794, 0],
  70.     [2258, 1870, 1], [1546, 1903, 1], [1774, 1918, 0], [2076, 1918, 0], [1758, 2076, 1],
  71.     [1839, 2076, 0], [2107, 2076, 1], [2258, 2107, 1], [1731, 2202, 1], [1327, 2258, 1],
  72.     [1354, 1229, 0], [1774, 1229, 0], [1546, 1376, 0], [1918, 1405, 0], [1605, 1450, 0],
  73.     [1605, 1477, 0], [1448, 1519, 0], [1450, 1519, 1], [1141, 1528, 0], [1346, 1551, 1],
  74.     [1608, 1606, 0], [1376, 1608, 1], [2153, 1687, 0], [1458, 1690, 0], [1590, 1690, 1],
  75.     [1774, 1731, 0], [1229, 1747, 1], [2250, 1758, 0], [1346, 1848, 1], [1376, 1870, 1],
  76.     [2258, 1870, 0], [1590, 1918, 1], [1849, 2153, 0], [1782, 2183, 1], [2153, 2258, 0]
  77. ])
  78.  
  79. # Rozdělení na vstupy (X) a výstupy (y)
  80. X_train = train_data[:, 0:2]
  81. y_train = train_data[:, 2]
  82. X_test = test_data[:, 0:2]
  83. y_test = test_data[:, 2]
  84.  
  85. # Normalizace dat pomocí Z-score
  86. scaler = StandardScaler()
  87. X_train = scaler.fit_transform(X_train)
  88. X_test = scaler.transform(X_test)
  89.  
  90. # Výpočet váhy tříd
  91. class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
  92. class_weight_dict = dict(enumerate(class_weights))
  93.  
  94. # Definování DNN modelu
  95. dnn_model = Sequential([
  96.     Input(shape=(2,)),  # Vstupní vrstva pro 2D data
  97.     Dense(128, activation='relu'),  # První skrytá vrstva
  98.     Dense(64, activation='relu'),   # Druhá skrytá vrstva
  99.     Dense(64, activation='relu'),   # Třetí skrytá vrstva
  100.     Dense(64, activation='relu'),   # Čtvrtá skrytá vrstva
  101.     Dense(64, activation='relu'),   # Pátá skrytá vrstva
  102.     Dense(32, activation='relu'),   # Šestá skrytá vrstva
  103.     Dense(32, activation='relu'),   # Sedmá skrytá vrstva
  104.     Dense(32, activation='relu'),   # Osmá skrytá vrstva
  105.     Dense(32, activation='relu'),   # Devátá skrytá vrstva
  106.     Dense(16, activation='relu'),   # Desátá skrytá vrstva
  107.     Dense(16, activation='relu'),   # Jedenáctá skrytá vrstva
  108.     Dense(16, activation='relu'),   # Dvanáctá skrytá vrstva
  109.     Dense(8, activation='relu'),    # Třináctá skrytá vrstva
  110.     Dense(1, activation='sigmoid')  # Výstupní vrstva
  111. ])
  112.  
  113. # Kompilace DNN modelu s použitím class_weight
  114. dnn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  115.  
  116. # Trénování DNN modelu s použitím class_weight
  117. dnn_model.fit(X_train, y_train, epochs=200, batch_size=10, verbose=0, class_weight=class_weight_dict)
  118.  
  119. # Predikce pravděpodobností pro testovací sadu pomocí DNN
  120. y_probs_dnn = dnn_model.predict(X_test)
  121.  
  122. # Definování funkce pro optimalizaci pomocí Optuna
  123. def objective(trial):
  124.     # Návrh rozsahu pro threshold
  125.     threshold = trial.suggest_float('threshold', 0.0, 1.0)
  126.  
  127.     # Aplikace prahu na predikované pravděpodobnosti z DNN modelu
  128.     y_pred_dnn_optuna = (y_probs_dnn > threshold).astype("int32")
  129.  
  130.     # Výpočet confusion matrix
  131.     conf_matrix = confusion_matrix(y_test, y_pred_dnn_optuna)
  132.  
  133.     # Extrakce FP a FN
  134.     FP = conf_matrix[0, 1]
  135.     FN = conf_matrix[1, 0]
  136.  
  137.     # Návratová hodnota: minimalizujeme FP + FN
  138.     return FP + FN
  139.  
  140. # Inicializace studie Optuna a provedení optimalizace
  141. study = optuna.create_study(direction='minimize')
  142. study.optimize(objective, n_trials=50)
  143.  
  144. # Výsledky optimalizace
  145. optimal_threshold = study.best_params["threshold"]
  146. print(f'Optimální threshold: {optimal_threshold}')
  147. print(f'Nejlepší dosažená hodnota FP + FN: {study.best_value}')
  148.  
  149. # Použití optimálního threshold na DNN model
  150. y_pred_dnn_optimal = (y_probs_dnn > optimal_threshold).astype("int32")
  151.  
  152. # Vyhodnocení modelu s optimálním threshold
  153. conf_matrix_dnn_optimal = confusion_matrix(y_test, y_pred_dnn_optimal)
  154. plt.figure(figsize=(10, 7))
  155. sns.heatmap(conf_matrix_dnn_optimal, annot=True, fmt='d', cmap='Blues')
  156. plt.title(f'Konfuzní Matice (DNN) - Testovací data, Optimalizovaný Threshold: {optimal_threshold:.2f}')
  157. plt.ylabel('Skutečný Štítek')
  158. plt.xlabel('Predikovaný Štítek')
  159. plt.show()
  160.  
  161. # Denormalizace rozdílu ELO pro grafy
  162. X_train_denorm = scaler.inverse_transform(X_train)
  163. X_test_denorm = scaler.inverse_transform(X_test)
  164.  
  165. # Přidání rozměru pro CNN
  166. X_train_cnn = X_train.reshape(-1, 2, 1)
  167. X_test_cnn = X_test.reshape(-1, 2, 1)
  168.  
  169. # Definování CNN modelu s 13 skrytými vrstvami
  170. cnn_model = Sequential([
  171.     Input(shape=(2, 1)),  # Vstupní vrstva pro 2D data
  172.     Flatten(),  # Plochý vstup
  173.     Dense(128, activation='relu'),  # První hustá vrstva
  174.     Dense(64, activation='relu'),   # Druhá hustá vrstva
  175.     Dense(64, activation='relu'),   # Třetí hustá vrstva
  176.     Dense(64, activation='relu'),   # Čtvrtá hustá vrstva
  177.     Dense(64, activation='relu'),   # Pátá hustá vrstva
  178.     Dense(32, activation='relu'),   # Šestá hustá vrstva
  179.     Dense(32, activation='relu'),   # Sedmá hustá vrstva
  180.     Dense(32, activation='relu'),   # Osmá hustá vrstva
  181.     Dense(32, activation='relu'),   # Devátá hustá vrstva
  182.     Dense(16, activation='relu'),   # Desátá hustá vrstva
  183.     Dense(16, activation='relu'),   # Jedenáctá skrytá vrstva
  184.     Dense(16, activation='relu'),   # Dvanáctá hustá vrstva
  185.     Dense(8, activation='relu'),    # Třináctá hustá vrstva
  186.     Dense(1, activation='sigmoid')  # Výstupní vrstva
  187. ])
  188.  
  189. # Kompilace CNN modelu
  190. cnn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  191.  
  192. # Trénování CNN modelu
  193. cnn_model.fit(X_train_cnn, y_train, epochs=200, verbose=0, class_weight=class_weight_dict)
  194.  
  195. # Predikce pravděpodobností pro testovací sadu pomocí CNN
  196. y_pred_probs_cnn = cnn_model.predict(X_test_cnn).flatten()
  197. y_pred_cnn = (y_pred_probs_cnn > optimal_threshold).astype("int32")
  198.  
  199. # Predikce pravděpodobností pro trénovací sadu pomocí CNN
  200. y_pred_probs_train_cnn = cnn_model.predict(X_train_cnn).flatten()
  201. y_pred_train_cnn = (y_pred_probs_train_cnn > optimal_threshold).astype("int32")
  202.  
  203. # Vyhodnocení modelu pomocí confusion matrix (trénovací data)
  204. conf_matrix_train_cnn = confusion_matrix(y_train, y_pred_train_cnn)
  205. plt.figure(figsize=(10, 7))
  206. sns.heatmap(conf_matrix_train_cnn, annot=True, fmt='d', cmap='Greens')
  207. plt.title('Konfuzní Matice (CNN) - Trénovací data, Optimalizovaný Threshold')
  208. plt.ylabel('Skutečný Štítek')
  209. plt.xlabel('Predikovaný Štítek')
  210. plt.show()
  211.  
  212. # Funkce pro proložení hladké polynomiální křivky
  213. def plot_polynomial_curve(x, y, ax, degree=3):
  214.     # Seřadit podle hodnoty x
  215.     sorted_indices = np.argsort(x.flatten())
  216.     x_sorted = x.flatten()[sorted_indices]
  217.     y_sorted = y.flatten()[sorted_indices]
  218.    
  219.     # Polynomiální fit
  220.     poly = PolynomialFeatures(degree=degree)
  221.     x_poly = poly.fit_transform(x_sorted.reshape(-1, 1))
  222.    
  223.     model = LinearRegression()
  224.     model.fit(x_poly, y_sorted)
  225.    
  226.     # Predikce pomocí polynomiálního modelu
  227.     y_poly_pred = model.predict(x_poly)
  228.    
  229.     # Vykreslení křivky
  230.     ax.plot(x_sorted, y_poly_pred, '-.', color='green', linewidth=1)
  231.  
  232. # Vykreslení grafu Rozdíl ELO vs. Pravděpodobnost Výhry (DNN) s obarvením a legendou pro testovací data
  233. fig, ax = plt.subplots(figsize=(10, 5))
  234. ax.scatter([], [], color='yellow', s=100, label='TP')  # Přidání TP do legendy
  235. ax.scatter([], [], color='green', s=100, label='TN')   # Přidání TN do legendy
  236. ax.scatter([], [], color='red', s=100, label='FP')     # Přidání FP do legendy
  237. ax.scatter([], [], color='blue', s=100, label='FN')    # Přidání FN do legendy
  238.  
  239. for i in range(len(y_test)):
  240.     if y_test[i] == 1 and y_pred_dnn_optimal[i] == 1:
  241.         ax.scatter(X_test_denorm[i, 0], y_probs_dnn[i], color='yellow', s=100, marker='o')
  242.     elif y_test[i] == 0 and y_pred_dnn_optimal[i] == 0:
  243.         ax.scatter(X_test_denorm[i, 0], y_probs_dnn[i], color='green', s=100, marker='o')
  244.     elif y_test[i] == 0 and y_pred_dnn_optimal[i] == 1:
  245.         ax.scatter(X_test_denorm[i, 0], y_probs_dnn[i], color='red', s=100, marker='o')
  246.     elif y_test[i] == 1 and y_pred_dnn_optimal[i] == 0:
  247.         ax.scatter(X_test_denorm[i, 0], y_probs_dnn[i], color='blue', s=100, marker='o')
  248.  
  249. plot_polynomial_curve(X_test_denorm[:, 0], y_probs_dnn, ax)
  250.  
  251. ax.set_title(f'Rozdíl ELO vs. Pravděpodobnost Výhry (DNN) - Testovací data, Threshold: {optimal_threshold:.2f}')
  252. ax.set_xlabel('Rozdíl ELO (Denormalizováno)')
  253. ax.set_ylabel('Predikovaná Pravděpodobnost')
  254. ax.legend(loc='upper left')
  255. ax.grid(True)
  256. plt.show()
  257.  
  258. # Vykreslení grafu Rozdíl ELO vs. Pravděpodobnost Výhry (CNN) s obarvením a legendou pro testovací data
  259. fig, ax = plt.subplots(figsize=(10, 5))
  260. ax.scatter([], [], color='yellow', s=100, label='TP')  # Přidání TP do legendy
  261. ax.scatter([], [], color='green', s=100, label='TN')   # Přidání TN do legendy
  262. ax.scatter([], [], color='red', s=100, label='FP')     # Přidání FP do legendy
  263. ax.scatter([], [], color='blue', s=100, label='FN')    # Přidání FN do legendy
  264.  
  265. for i in range(len(y_test)):
  266.     if y_test[i] == 1 and y_pred_cnn[i] == 1:
  267.         ax.scatter(X_test_denorm[i, 0], y_pred_probs_cnn[i], color='yellow', s=100, marker='o')
  268.     elif y_test[i] == 0 and y_pred_cnn[i] == 0:
  269.         ax.scatter(X_test_denorm[i, 0], y_pred_probs_cnn[i], color='green', s=100, marker='o')
  270.     elif y_test[i] == 0 and y_pred_cnn[i] == 1:
  271.         ax.scatter(X_test_denorm[i, 0], y_pred_probs_cnn[i], color='red', s=100, marker='o')
  272.     elif y_test[i] == 1 and y_pred_cnn[i] == 0:
  273.         ax.scatter(X_test_denorm[i, 0], y_pred_probs_cnn[i], color='blue', s=100, marker='o')
  274.  
  275. plot_polynomial_curve(X_test_denorm[:, 0], y_pred_probs_cnn, ax)
  276.  
  277. ax.set_title('Rozdíl ELO vs. Pravděpodobnost Výhry (CNN) - Testovací data')
  278. ax.set_xlabel('Rozdíl ELO (Denormalizováno)')
  279. ax.set_ylabel('Predikovaná Pravděpodobnost')
  280. ax.legend(loc='upper left')
  281. ax.grid(True)
  282. plt.show()
  283.  
  284. # Vykreslení grafu Rozdíl ELO vs. Pravděpodobnost Výhry (DNN) s obarvením a legendou pro trénovací data (podle labelů)
  285. fig, ax = plt.subplots(figsize=(10, 5))
  286. ax.scatter([], [], color='yellow', s=100, label='Label 1')  # Přidání labelu 1 do legendy
  287. ax.scatter([], [], color='green', s=100, label='Label 0')   # Přidání labelu 0 do legendy
  288.  
  289. for i in range(len(y_train)):
  290.     if y_train[i] == 1:
  291.         ax.scatter(X_train_denorm[i, 0], y_probs_train_dnn[i], color='yellow', s=100, marker='o')
  292.     elif y_train[i] == 0:
  293.         ax.scatter(X_train_denorm[i, 0], y_probs_train_dnn[i], color='green', s=100, marker='o')
  294.  
  295. plot_polynomial_curve(X_train_denorm[:, 0], y_probs_train_dnn, ax)
  296.  
  297. ax.set_title('Rozdíl ELO vs. Pravděpodobnost Výhry (DNN) - Trénovací data')
  298. ax.set_xlabel('Rozdíl ELO (Denormalizováno)')
  299. ax.set_ylabel('Predikovaná Pravděpodobnost')
  300. ax.legend(loc='upper left')
  301. ax.grid(True)
  302. plt.show()
  303.  
  304. # Vykreslení grafu Rozdíl ELO vs. Pravděpodobnost Výhry (CNN) s obarvením a legendou pro trénovací data (podle labelů)
  305. fig, ax = plt.subplots(figsize=(10, 5))
  306. ax.scatter([], [], color='yellow', s=100, label='Label 1')  # Přidání labelu 1 do legendy
  307. ax.scatter([], [], color='green', s=100, label='Label 0')   # Přidání labelu 0 do legendy
  308.  
  309. for i in range(len(y_train)):
  310.     if y_train[i] == 1:
  311.         ax.scatter(X_train_denorm[i, 0], y_pred_probs_train_cnn[i], color='yellow', s=100, marker='o')
  312.     elif y_train[i] == 0:
  313.         ax.scatter(X_train_denorm[i, 0], y_pred_probs_train_cnn[i], color='green', s=100, marker='o')
  314.  
  315. plot_polynomial_curve(X_train_denorm[:, 0], y_pred_probs_train_cnn, ax)
  316.  
  317. ax.set_title('Rozdíl ELO vs. Pravděpodobnost Výhry (CNN) - Trénovací data')
  318. ax.set_xlabel('Rozdíl ELO (Denormalizováno)')
  319. ax.set_ylabel('Predikovaná Pravděpodobnost')
  320. ax.legend(loc='upper left')
  321. ax.grid(True)
  322. plt.show()
  323.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement