Advertisement
max2201111

spatne

Aug 21st, 2024
154
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.88 KB | Science | 0 0
  1. import numpy as np
  2. from tensorflow.keras.models import Sequential
  3. from tensorflow.keras.layers import Dense, Flatten, Input, Activation, LeakyReLU
  4. from sklearn.preprocessing import StandardScaler
  5. from sklearn.metrics import confusion_matrix
  6. from sklearn.utils.class_weight import compute_class_weight
  7. import matplotlib.pyplot as plt
  8. import seaborn as sns
  9.  
  10. # Trénovací data (2D)
  11. train_data = np.array([
  12.     [1141, 1050, 1], [1499, 1050, 0], [1451, 1077, 1], [1519, 1077, 1], [1191, 1093, 0],
  13.     [1590, 1093, 1], [1777, 1093, 0], [1141, 1124, 1], [1499, 1124, 0], [1601, 1124, 0],
  14.     [1606, 1141, 0], [1608, 1141, 0], [1870, 1141, 0], [1794, 1191, 0], [1329, 1211, 1],
  15.     [1687, 1211, 0], [1918, 1211, 0], [1608, 1212, 0], [1050, 1228, 0], [2107, 1228, 0],
  16.     [2202, 1266, 0], [1551, 1327, 0], [1608, 1327, 1], [1660, 1327, 0], [1329, 1332, 1],
  17.     [1093, 1346, 0], [1546, 1419, 1], [1327, 1435, 0], [1774, 1435, 0], [1794, 1451, 0],
  18.     [1077, 1458, 1], [1093, 1458, 1], [1731, 1458, 0], [1777, 1491, 0], [1212, 1499, 1],
  19.     [1211, 1519, 1], [1608, 1519, 1], [1918, 1519, 0], [1458, 1538, 1], [1918, 1538, 0],
  20.     [1794, 1545, 0], [1903, 1545, 0], [1435, 1546, 1], [1758, 1546, 0], [2076, 1546, 0],
  21.     [1077, 1551, 1], [1690, 1551, 0], [1050, 1590, 1], [1093, 1601, 1], [1327, 1601, 0],
  22.     [1050, 1606, 1], [1491, 1606, 1], [1519, 1608, 0], [1266, 1660, 1], [1839, 1660, 0],
  23.     [1332, 1687, 0], [1519, 1687, 0], [1538, 1690, 1], [1870, 1690, 0], [1903, 1731, 1],
  24.     [1918, 1731, 0], [1419, 1758, 0], [1839, 1758, 0], [1077, 1774, 1], [1519, 1774, 1],
  25.     [2202, 1774, 0], [1538, 1777, 0], [1903, 1777, 1], [2107, 1777, 0], [1660, 1794, 0],
  26.     [2107, 1794, 0], [1124, 1839, 1], [1519, 1839, 1], [1546, 1839, 1], [1870, 1839, 1],
  27.     [2202, 1839, 0], [1419, 1870, 1], [2107, 1870, 0], [2202, 1870, 0], [1191, 1903, 1],
  28.     [1601, 1903, 1], [1606, 1903, 1], [1660, 1903, 1], [1491, 1918, 1], [1212, 2076, 1],
  29.     [1690, 2076, 1], [1546, 2107, 1], [1903, 2107, 1], [2183, 2107, 0], [1229, 2183, 1],
  30.     [1731, 2183, 1], [1758, 2183, 0], [1918, 2183, 1], [2076, 2183, 0], [1538, 2202, 1],
  31.     [1601, 2202, 1], [2076, 2202, 0], [1660, 2258, 1], [1777, 2258, 0], [2202, 2258, 0],
  32.     [1477, 1141, 0], [1519, 1141, 0], [1519, 1211, 1], [1549, 1211, 0], [1731, 1211, 0],
  33.     [1848, 1211, 0], [1458, 1229, 1], [1528, 1229, 1], [1354, 1327, 0], [1448, 1327, 0],
  34.     [1458, 1327, 1], [1499, 1327, 0], [1758, 1327, 0], [1405, 1332, 0], [1519, 1332, 0],
  35.     [1549, 1332, 0], [1605, 1332, 0], [1606, 1332, 0], [1606, 1332, 1], [1794, 1332, 1],
  36.     [1376, 1346, 0], [1608, 1346, 0], [1551, 1354, 1], [1731, 1354, 0], [1849, 1354, 0],
  37.     [1918, 1354, 0], [1332, 1376, 0], [1690, 1376, 0], [1606, 1405, 1], [1608, 1405, 0],
  38.     [1687, 1405, 0], [1774, 1448, 0], [1870, 1448, 0], [1477, 1450, 0], [1551, 1450, 0],
  39.     [1731, 1450, 1], [2076, 1450, 0], [1687, 1458, 0], [1448, 1477, 0], [1794, 1477, 0],
  40.     [2183, 1477, 0], [1211, 1499, 1], [1774, 1499, 0], [1782, 1499, 0], [1848, 1499, 0],
  41.     [1211, 1519, 0], [1229, 1519, 1], [1332, 1519, 0], [2258, 1519, 0], [1332, 1528, 0],
  42.     [1327, 1546, 0], [1332, 1546, 1], [1606, 1546, 0], [1918, 1546, 1], [1499, 1549, 0],
  43.     [2258, 1549, 0], [1327, 1551, 0], [1870, 1551, 0], [1211, 1590, 1], [1229, 1590, 1],
  44.     [1601, 1590, 1], [1211, 1601, 0], [1332, 1601, 1], [1354, 1601, 1], [1405, 1601, 0],
  45.     [1229, 1605, 0], [1332, 1605, 1], [1448, 1605, 0], [1354, 1606, 0], [1927, 1606, 0],
  46.     [1450, 1608, 0], [1519, 1608, 1], [1849, 1608, 0], [1477, 1687, 1], [1848, 1687, 0],
  47.     [1141, 1690, 1], [1327, 1690, 1], [1549, 1731, 1], [1590, 1731, 0], [2250, 1731, 0],
  48.     [1332, 1747, 1], [1927, 1747, 0], [1450, 1758, 1], [2076, 1758, 0], [1747, 1774, 0],
  49.     [1794, 1774, 0], [1332, 1782, 1], [1687, 1782, 0], [1747, 1782, 0], [1758, 1782, 1],
  50.     [2076, 1782, 1], [1458, 1794, 1], [1590, 1794, 1], [2153, 1794, 0], [1747, 1848, 0],
  51.     [1758, 1848, 0], [1870, 1848, 0], [2183, 1848, 0], [1332, 1849, 1], [1731, 1849, 1],
  52.     [1870, 1849, 1], [2183, 1849, 0], [2250, 1849, 0], [1758, 1870, 0], [1918, 1870, 1],
  53.     [1549, 1918, 0], [1448, 1927, 1], [1758, 1927, 1], [2183, 1927, 1], [2250, 1927, 0],
  54.     [2258, 1927, 0], [1601, 2076, 1], [1605, 2076, 1], [1849, 2076, 1], [2258, 2076, 0],
  55.     [1458, 2153, 1], [1927, 2153, 0], [2076, 2153, 0], [1690, 2183, 1], [2076, 2183, 1],
  56.     [2153, 2183, 0], [1499, 2250, 1], [1690, 2250, 1], [1747, 2250, 0], [1918, 2250, 1],
  57.     [2183, 2250, 0], [1747, 2258, 1]
  58. ])
  59.  
  60. # Testovací data (2D)
  61. test_data = np.array([
  62.     [1451, 1050, 0], [1758, 1050, 0], [1346, 1211, 1], [1546, 1332, 1], [1608, 1451, 1],
  63.     [1839, 1458, 0], [1435, 1538, 1], [1077, 1546, 1], [2183, 1551, 0], [1458, 1590, 0],
  64.     [1538, 1606, 0], [1077, 1608, 1], [2258, 1608, 0], [1419, 1690, 1], [1545, 1731, 1],
  65.     [1774, 1758, 0], [1545, 1774, 1], [2183, 1777, 0], [1228, 1794, 1], [1774, 1794, 0],
  66.     [2258, 1870, 1], [1546, 1903, 1], [1774, 1918, 0], [2076, 1918, 0], [1758, 2076, 1],
  67.     [1839, 2076, 0], [2107, 2076, 1], [2258, 2107, 1], [1731, 2202, 1], [1327, 2258, 1],
  68.     [1354, 1229, 0], [1774, 1229, 0], [1546, 1376, 0], [1918, 1405, 0], [1605, 1450, 0],
  69.     [1605, 1477, 0], [1448, 1519, 0], [1450, 1519, 1], [1141, 1528, 0], [1346, 1551, 1],
  70.     [1608, 1606, 0], [1376, 1608, 1], [2153, 1687, 0], [1458, 1690, 0], [1590, 1690, 1],
  71.     [1774, 1731, 0], [1229, 1747, 1], [2250, 1758, 0], [1346, 1848, 1], [1376, 1870, 1],
  72.     [2258, 1870, 0], [1590, 1918, 1], [1849, 2153, 0], [1782, 2183, 1], [2153, 2258, 0]
  73. ])
  74.  
  75. # Rozdělení na vstupy (X) a výstupy (y)
  76. X_train = train_data[:, 0:2]
  77. y_train = train_data[:, 2]
  78. X_test = test_data[:, 0:2]
  79. y_test = test_data[:, 2]
  80.  
  81. # Normalizace dat pomocí Z-score
  82. scaler = StandardScaler()
  83. X_train = scaler.fit_transform(X_train)
  84. X_test = scaler.transform(X_test)
  85.  
  86. # Výpočet váhy tříd
  87. class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
  88. class_weight_dict = dict(enumerate(class_weights))
  89.  
  90. # DNN model s pevně nastavenými aktivačními funkcemi a threshold
  91. optimal_threshold_dnn = 0.6470964401352306
  92. optimal_threshold_dnn = 0.5
  93.  
  94. dnn_model = Sequential([
  95.     Input(shape=(2,)),  # Vstupní vrstva pro 2D data
  96.     Dense(128),  # První skrytá vrstva
  97.     LeakyReLU(),  # Aktivace první vrstvy
  98.     Dense(64),
  99.     Activation('relu'),  # Aktivace druhé vrstvy
  100.     Dense(64),
  101.     Activation('tanh'),  # Aktivace třetí vrstvy
  102.     Dense(64),
  103.     Activation('tanh'),  # Aktivace čtvrté vrstvy
  104.     Dense(64),
  105.     Activation('tanh'),  # Aktivace páté vrstvy
  106.     Dense(32),
  107.     Activation('tanh'),  # Aktivace šesté vrstvy
  108.     Dense(32),
  109.     Activation('tanh'),  # Aktivace sedmé vrstvy
  110.     Dense(32),
  111.     Activation('sigmoid'),  # Aktivace osmé vrstvy
  112.     Dense(32),
  113.     Activation('tanh'),  # Aktivace deváté vrstvy
  114.     Dense(16),
  115.     Activation('sigmoid'),  # Aktivace desáté vrstvy
  116.     Dense(16),
  117.     Activation('relu'),  # Aktivace jedenácté vrstvy
  118.     Dense(16),
  119.     LeakyReLU(),  # Aktivace dvanácté vrstvy
  120.     Dense(8),
  121.     Activation('relu'),  # Aktivace třinácté vrstvy
  122.     Dense(1, activation='sigmoid')  # Výstupní vrstva
  123. ])
  124.  
  125. dnn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  126.  
  127. # Trénování DNN modelu
  128. dnn_model.fit(X_train, y_train, epochs=200, batch_size=10, verbose=0, class_weight=class_weight_dict)
  129.  
  130. # Predikce s pevně nastavenými aktivačními funkcemi a optimálním threshold pro DNN
  131. y_probs_dnn_best = dnn_model.predict(X_test)
  132. y_pred_dnn_best = (y_probs_dnn_best > optimal_threshold_dnn).astype("int32")
  133.  
  134. # Vyhodnocení výsledného DNN modelu pomocí confusion matrix
  135. conf_matrix_dnn_best = confusion_matrix(y_test, y_pred_dnn_best)
  136. plt.figure(figsize=(10, 7))
  137. sns.heatmap(conf_matrix_dnn_best, annot=True, fmt='d', cmap='Blues')
  138. plt.title(f'Konfuzní Matice (DNN) - Testovací data, Optimalizovaný Threshold: {optimal_threshold_dnn:.2f}')
  139. plt.ylabel('Skutečný Štítek')
  140. plt.xlabel('Predikovaný Štítek')
  141. plt.show()
  142.  
  143. # CNN model s pevně nastavenými aktivačními funkcemi a threshold
  144. optimal_threshold_cnn = 0.3527088104310437
  145. optimal_threshold_cnn = 0.5
  146.  
  147. cnn_model = Sequential([
  148.     Input(shape=(2, 1)),  # Vstupní vrstva pro 2D data
  149.     Flatten(),  # Plochý vstup
  150.     Dense(128),
  151.     Activation('tanh'),  # Aktivace první vrstvy
  152.     Dense(64),
  153.     Activation('tanh'),  # Aktivace druhé vrstvy
  154.     Dense(64),
  155.     Activation('relu'),  # Aktivace třetí vrstvy
  156.     Dense(64),
  157.     Activation('sigmoid'),  # Aktivace čtvrté vrstvy
  158.     Dense(64),
  159.     Activation('sigmoid'),  # Aktivace páté vrstvy
  160.     Dense(32),
  161.     Activation('sigmoid'),  # Aktivace šesté vrstvy
  162.     Dense(32),
  163.     Activation('sigmoid'),  # Aktivace sedmé vrstvy
  164.     Dense(32),
  165.     Activation('tanh'),  # Aktivace osmé vrstvy
  166.     Dense(32),
  167.     LeakyReLU(),  # Aktivace deváté vrstvy
  168.     Dense(16),
  169.     Activation('relu'),  # Aktivace desáté vrstvy
  170.     Dense(16),
  171.     Activation('sigmoid'),  # Aktivace jedenácté vrstvy
  172.     Dense(16),
  173.     Activation('relu'),  # Aktivace dvanácté vrstvy
  174.     Dense(8),
  175.     Activation('relu'),  # Aktivace třinácté vrstvy
  176.     Dense(1, activation='sigmoid')  # Výstupní vrstva
  177. ])
  178.  
  179. cnn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  180.  
  181. # Trénování CNN modelu
  182. cnn_model.fit(X_train.reshape(-1, 2, 1), y_train, epochs=200, batch_size=10, verbose=0, class_weight=class_weight_dict)
  183.  
  184. # Predikce s pevně nastavenými aktivačními funkcemi a optimálním threshold pro CNN
  185. y_probs_cnn_best = cnn_model.predict(X_test.reshape(-1, 2, 1))
  186. y_pred_cnn_best = (y_probs_cnn_best > optimal_threshold_cnn).astype("int32")
  187.  
  188. # Vyhodnocení výsledného CNN modelu pomocí confusion matrix
  189. conf_matrix_cnn_best = confusion_matrix(y_test, y_pred_cnn_best)
  190. plt.figure(figsize=(10, 7))
  191. sns.heatmap(conf_matrix_cnn_best, annot=True, fmt='d', cmap='Blues')
  192. plt.title(f'Konfuzní Matice (CNN) - Testovací data, Optimalizovaný Threshold: {optimal_threshold_cnn:.2f}')
  193. plt.ylabel('Skutečný Štítek')
  194. plt.xlabel('Predikovaný Štítek')
  195. plt.show()
  196.  
  197. # Denormalizace rozdílu ELO pro grafy
  198. X_train_denorm = scaler.inverse_transform(X_train)
  199. X_test_denorm = scaler.inverse_transform(X_test)
  200.  
  201. # Funkce pro vykreslení grafu s barvami podle výsledků klasifikace
  202. def plot_classification_results(X, y_true, y_probs, y_pred, threshold, title):
  203.     fig, ax = plt.subplots(figsize=(10, 5))
  204.     ax.scatter([], [], color='yellow', s=100, label='TP')  # Přidání TP do legendy
  205.     ax.scatter([], [], color='green', s=100, label='TN')   # Přidání TN do legendy
  206.     ax.scatter([], [], color='red', s=100, label='FP')     # Přidání FP do legendy
  207.     ax.scatter([], [], color='blue', s=100, label='FN')    # Přidání FN do legendy
  208.  
  209.     for i in range(len(y_true)):
  210.         if y_true[i] == 1 and y_pred[i] == 1:
  211.             ax.scatter(X[i, 0], y_probs[i], color='yellow', s=100, marker='o')
  212.         elif y_true[i] == 0 and y_pred[i] == 0:
  213.             ax.scatter(X[i, 0], y_probs[i], color='green', s=100, marker='o')
  214.         elif y_true[i] == 0 and y_pred[i] == 1:
  215.             ax.scatter(X[i, 0], y_probs[i], color='red', s=100, marker='o')
  216.         elif y_true[i] == 1 and y_pred[i] == 0:
  217.             ax.scatter(X[i, 0], y_probs[i], color='blue', s=100, marker='o')
  218.  
  219.     ax.set_title(title)
  220.     ax.set_xlabel('Rozdíl ELO (Denormalizováno)')
  221.     ax.set_ylabel('Predikovaná Pravděpodobnost')
  222.     ax.legend(loc='upper left')
  223.     ax.grid(True)
  224.     plt.show()
  225.  
  226. # Vykreslení grafu pro DNN - Testovací data
  227. plot_classification_results(X_test_denorm, y_test, y_probs_dnn_best, y_pred_dnn_best, optimal_threshold_dnn,
  228.                             'Rozdíl ELO vs. Pravděpodobnost Výhry (DNN) - Testovací data')
  229.  
  230. # Vykreslení grafu pro CNN - Testovací data
  231. plot_classification_results(X_test_denorm, y_test, y_probs_cnn_best, y_pred_cnn_best, optimal_threshold_cnn,
  232.                             'Rozdíl ELO vs. Pravděpodobnost Výhry (CNN) - Testovací data')
  233.  
  234. # Predikce s pevně nastavenými aktivačními funkcemi a optimálním threshold pro DNN na trénovacích datech
  235. y_probs_dnn_train = dnn_model.predict(X_train)
  236. y_pred_dnn_train = (y_probs_dnn_train > optimal_threshold_dnn).astype("int32")
  237.  
  238. # Predikce s pevně nastavenými aktivačními funkcemi a optimálním threshold pro CNN na trénovacích datech
  239. y_probs_cnn_train = cnn_model.predict(X_train.reshape(-1, 2, 1))
  240. y_pred_cnn_train = (y_probs_cnn_train > optimal_threshold_cnn).astype("int32")
  241.  
  242. # Vykreslení grafu pro DNN - Trénovací data
  243. plot_classification_results(X_train_denorm, y_train, y_probs_dnn_train, y_pred_dnn_train, optimal_threshold_dnn,
  244.                             'Rozdíl ELO vs. Pravděpodobnost Výhry (DNN) - Trénovací data')
  245.  
  246. # Vykreslení grafu pro CNN - Trénovací data
  247. plot_classification_results(X_train_denorm, y_train, y_probs_cnn_train, y_pred_cnn_train, optimal_threshold_cnn,
  248.                             'Rozdíl ELO vs. Pravděpodobnost Výhry (CNN) - Trénovací data')
  249.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement