Advertisement
max2201111

Petr4 DNN a CNN CM train lepsi

Aug 22nd, 2024
109
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 13.07 KB | Science | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import seaborn as sns
  4. from sklearn.preprocessing import StandardScaler
  5. from sklearn.metrics import confusion_matrix, classification_report
  6. from sklearn.utils.class_weight import compute_class_weight
  7. from tensorflow.keras.models import Sequential
  8. from tensorflow.keras.layers import Dense, Dropout, Input, BatchNormalization, Conv1D, GlobalAveragePooling1D
  9. from tensorflow.keras.optimizers import Adam
  10. from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
  11. from tensorflow.keras.regularizers import l2
  12. from scipy.ndimage import gaussian_filter1d
  13.  
  14. # Nastavitelný práh
  15. THRESHOLD = 0.5
  16.  
  17. THRESHOLD = 0.5
  18.  
  19. train_data = np.array([
  20.     [1141, 1050, 1], [1499, 1050, 0], [1451, 1077, 1], [1519, 1077, 1], [1191, 1093, 0],
  21.     [1590, 1093, 1], [1777, 1093, 0], [1141, 1124, 1], [1499, 1124, 0], [1601, 1124, 0],
  22.     [1606, 1141, 0], [1608, 1141, 0], [1870, 1141, 0], [1794, 1191, 0], [1329, 1211, 1],
  23.     [1687, 1211, 0], [1918, 1211, 0], [1608, 1212, 0], [1050, 1228, 0], [2107, 1228, 0],
  24.     [2202, 1266, 0], [1551, 1327, 0], [1608, 1327, 1], [1660, 1327, 0], [1329, 1332, 1],
  25.     [1093, 1346, 0], [1546, 1419, 1], [1327, 1435, 0], [1774, 1435, 0], [1794, 1451, 0],
  26.     [1077, 1458, 1], [1093, 1458, 1], [1731, 1458, 0], [1777, 1491, 0], [1212, 1499, 1],
  27.     [1211, 1519, 1], [1608, 1519, 1], [1918, 1519, 0], [1458, 1538, 1], [1918, 1538, 0],
  28.     [1794, 1545, 0], [1903, 1545, 0], [1435, 1546, 1], [1758, 1546, 0], [2076, 1546, 0],
  29.     [1077, 1551, 1], [1690, 1551, 0], [1050, 1590, 1], [1093, 1601, 1], [1327, 1601, 0],
  30.     [1050, 1606, 1], [1491, 1606, 1], [1519, 1608, 0], [1266, 1660, 1], [1839, 1660, 0],
  31.     [1332, 1687, 0], [1519, 1687, 0], [1538, 1690, 1], [1870, 1690, 0], [1903, 1731, 1],
  32.     [1918, 1731, 0], [1419, 1758, 0], [1839, 1758, 0], [1077, 1774, 1], [1519, 1774, 1],
  33.     [2202, 1774, 0], [1538, 1777, 0], [1903, 1777, 1], [2107, 1777, 0], [1660, 1794, 0],
  34.     [2107, 1794, 0], [1124, 1839, 1], [1519, 1839, 1], [1546, 1839, 1], [1870, 1839, 1],
  35.     [2202, 1839, 0], [1419, 1870, 1], [2107, 1870, 0], [2202, 1870, 0], [1191, 1903, 1],
  36.     [1601, 1903, 1], [1606, 1903, 1], [1660, 1903, 1], [1491, 1918, 1], [1212, 2076, 1],
  37.     [1690, 2076, 1], [1546, 2107, 1], [1903, 2107, 1], [2183, 2107, 0], [1229, 2183, 1],
  38.     [1731, 2183, 1], [1758, 2183, 0], [1918, 2183, 1], [2076, 2183, 0], [1538, 2202, 1],
  39.     [1601, 2202, 1], [2076, 2202, 0], [1660, 2258, 1], [1777, 2258, 0], [2202, 2258, 0]
  40. ])
  41.  
  42. # Testovací data (2D)
  43. test_data = np.array([
  44.     [1451, 1050, 0], [1758, 1050, 0], [1346, 1211, 1], [1546, 1332, 1], [1608, 1451, 1],
  45.     [1839, 1458, 0], [1435, 1538, 1], [1077, 1546, 1], [2183, 1551, 0], [1458, 1590, 0],
  46.     [1538, 1606, 0], [1077, 1608, 1], [2258, 1608, 0], [1419, 1690, 1], [1545, 1731, 1],
  47.     [1774, 1758, 0], [1545, 1774, 1], [2183, 1777, 0], [1228, 1794, 1], [1774, 1794, 0],
  48.     [2258, 1870, 1], [1546, 1903, 1], [1774, 1918, 0], [2076, 1918, 0], [1758, 2076, 1],
  49.     [1839, 2076, 0], [2107, 2076, 1], [2258, 2107, 1], [1731, 2202, 1], [1327, 2258, 1]
  50. ])
  51.  
  52. # Funkce pro vytvoření nových features
  53. def create_features(X):
  54.     elo_diff = X[:, 0] - X[:, 1]
  55.     return np.column_stack((X, elo_diff))
  56.  
  57. # Normalizace dat pomocí Z-score
  58. scaler = StandardScaler()
  59.  
  60. # Funkce pro vytvoření DNN modelu
  61. def create_dnn_model(input_shape):
  62.     model = Sequential([
  63.         Input(shape=input_shape),
  64.         Dense(64, activation='relu', kernel_regularizer=l2(0.01)),
  65.         BatchNormalization(),
  66.         Dropout(0.3),
  67.         Dense(32, activation='relu', kernel_regularizer=l2(0.01)),
  68.         BatchNormalization(),
  69.         Dropout(0.3),
  70.         Dense(16, activation='relu', kernel_regularizer=l2(0.01)),
  71.         BatchNormalization(),
  72.         Dropout(0.3),
  73.         Dense(1, activation='sigmoid')
  74.     ])
  75.     model.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
  76.     return model
  77.  
  78. # Funkce pro vytvoření CNN modelu
  79. def create_cnn_model(input_shape):
  80.     model = Sequential([
  81.         Input(shape=input_shape),
  82.         Conv1D(64, 2, activation='relu', padding='same', kernel_regularizer=l2(0.01)),
  83.         BatchNormalization(),
  84.         Conv1D(128, 2, activation='relu', padding='same', kernel_regularizer=l2(0.01)),
  85.         BatchNormalization(),
  86.         GlobalAveragePooling1D(),
  87.         Dense(64, activation='relu', kernel_regularizer=l2(0.01)),
  88.         BatchNormalization(),
  89.         Dropout(0.3),
  90.         Dense(32, activation='relu', kernel_regularizer=l2(0.01)),
  91.         BatchNormalization(),
  92.         Dropout(0.3),
  93.         Dense(1, activation='sigmoid')
  94.     ])
  95.     model.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
  96.     return model
  97.  
  98. # Funkce pro vykreslení grafů s barevnými čtverečky a křivkou trendu
  99. def plot_colored_squares_with_trend(X, y_true, y_pred, y_prob, title):
  100.     plt.figure(figsize=(10, 8))
  101.    
  102.     # True Positive: zelená
  103.     tp = ((y_true == 1) & (y_pred == 1))
  104.     plt.scatter(X[tp, 2], y_prob[tp], c='green', marker='s', s=50, label='True Positive', alpha=0.7)
  105.    
  106.     # True Negative: červená
  107.     tn = ((y_true == 0) & (y_pred == 0))
  108.     plt.scatter(X[tn, 2], y_prob[tn], c='red', marker='s', s=50, label='True Negative', alpha=0.7)
  109.    
  110.     # False Positive: modrá
  111.     fp = ((y_true == 0) & (y_pred == 1))
  112.     plt.scatter(X[fp, 2], y_prob[fp], c='blue', marker='s', s=50, label='False Positive', alpha=0.7)
  113.    
  114.     # False Negative: zlatá
  115.     fn = ((y_true == 1) & (y_pred == 0))
  116.     plt.scatter(X[fn, 2], y_prob[fn], c='gold', marker='s', s=50, label='False Negative', alpha=0.7)
  117.    
  118.     # Vytvoření křivky trendu
  119.     X_diff = X[:, 2]
  120.     num_bins = 50
  121.     bin_means, bin_edges, _ = binned_statistic(X_diff, y_prob, statistic='mean', bins=num_bins)
  122.     bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
  123.    
  124.     # Odstranění NaN hodnot
  125.     valid_indices = ~np.isnan(bin_means)
  126.     bin_centers = bin_centers[valid_indices]
  127.     bin_means = bin_means[valid_indices]
  128.    
  129.     # Seřazení bodů podle X pro správné vykreslení křivky
  130.     sort_indices = np.argsort(bin_centers)
  131.     bin_centers = bin_centers[sort_indices]
  132.     bin_means = bin_means[sort_indices]
  133.    
  134.     # Aplikace Gaussova filtru pro vyhlazení
  135.     smoothed_means = gaussian_filter1d(bin_means, sigma=1)
  136.    
  137.     # Vykreslení křivky trendu
  138.     plt.plot(bin_centers, smoothed_means, color='black', label='Trend', linewidth=2)
  139.    
  140.     plt.axhline(y=THRESHOLD, color='r', linestyle='--', label='Práh')
  141.    
  142.     plt.title(title)
  143.     plt.xlabel('Rozdíl ELO')
  144.     plt.ylabel('Pravděpodobnost výhry')
  145.     plt.legend()
  146.     plt.grid(True)
  147.     plt.show()
  148.  
  149. # Funkce pro vykreslení průběhu trénování
  150. def plot_training_history(dnn_history, cnn_history):
  151.     plt.figure(figsize=(15, 5))
  152.    
  153.     plt.subplot(1, 2, 1)
  154.     plt.plot(dnn_history.history['accuracy'], label='DNN Trénovací')
  155.     plt.plot(dnn_history.history['val_accuracy'], label='DNN Validační')
  156.     plt.plot(cnn_history.history['accuracy'], label='CNN Trénovací')
  157.     plt.plot(cnn_history.history['val_accuracy'], label='CNN Validační')
  158.     plt.title('Průběh trénování (Přesnost)')
  159.     plt.xlabel('Epocha')
  160.     plt.ylabel('Přesnost')
  161.     plt.legend()
  162.  
  163.     plt.subplot(1, 2, 2)
  164.     plt.plot(dnn_history.history['loss'], label='DNN Trénovací')
  165.     plt.plot(dnn_history.history['val_loss'], label='DNN Validační')
  166.     plt.plot(cnn_history.history['loss'], label='CNN Trénovací')
  167.     plt.plot(cnn_history.history['val_loss'], label='CNN Validační')
  168.     plt.title('Průběh trénování (Loss)')
  169.     plt.xlabel('Epocha')
  170.     plt.ylabel('Loss')
  171.     plt.legend()
  172.  
  173.     plt.tight_layout()
  174.     plt.show()
  175.  
  176. # Hlavní funkce pro trénování a vyhodnocení modelů
  177. def train_and_evaluate_models(X_train, y_train, X_test, y_test):
  178.     # Příprava dat
  179.     X_train = create_features(X_train)
  180.     X_test = create_features(X_test)
  181.    
  182.     X_train_scaled = scaler.fit_transform(X_train)
  183.     X_test_scaled = scaler.transform(X_test)
  184.  
  185.     # Příprava dat pro CNN
  186.     X_train_cnn = X_train_scaled.reshape((X_train_scaled.shape[0], X_train_scaled.shape[1], 1))
  187.     X_test_cnn = X_test_scaled.reshape((X_test_scaled.shape[0], X_test_scaled.shape[1], 1))
  188.  
  189.     # Výpočet váhy tříd
  190.     class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
  191.     class_weight_dict = dict(enumerate(class_weights))
  192.  
  193.     # Vytvoření a trénování modelů
  194.     dnn_model = create_dnn_model((X_train_scaled.shape[1],))
  195.     cnn_model = create_cnn_model((X_train_cnn.shape[1], 1))
  196.  
  197.     early_stopping = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
  198.     reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=10, min_lr=0.00001)
  199.  
  200.     dnn_history = dnn_model.fit(
  201.         X_train_scaled, y_train,
  202.         epochs=200,
  203.         batch_size=32,
  204.         validation_split=0.2,
  205.         class_weight=class_weight_dict,
  206.         callbacks=[early_stopping, reduce_lr],
  207.         verbose=0
  208.     )
  209.  
  210.     cnn_history = cnn_model.fit(
  211.         X_train_cnn, y_train,
  212.         epochs=200,
  213.         batch_size=32,
  214.         validation_split=0.2,
  215.         class_weight=class_weight_dict,
  216.         callbacks=[early_stopping, reduce_lr],
  217.         verbose=0
  218.     )
  219.  
  220.     # Predikce a vyhodnocení modelů
  221.     dnn_train_probs = dnn_model.predict(X_train_scaled).flatten()
  222.     dnn_test_probs = dnn_model.predict(X_test_scaled).flatten()
  223.     cnn_train_probs = cnn_model.predict(X_train_cnn).flatten()
  224.     cnn_test_probs = cnn_model.predict(X_test_cnn).flatten()
  225.  
  226.     dnn_train_pred = (dnn_train_probs >= THRESHOLD).astype("int32")
  227.     dnn_test_pred = (dnn_test_probs >= THRESHOLD).astype("int32")
  228.     cnn_train_pred = (cnn_train_probs >= THRESHOLD).astype("int32")
  229.     cnn_test_pred = (cnn_test_probs >= THRESHOLD).astype("int32")
  230.  
  231.     print(f"Použitý práh: {THRESHOLD}")
  232.  
  233.     # Vykreslení 6 požadovaných grafů
  234.     print("Vykreslování 6 požadovaných grafů:")
  235.    
  236.     print("1. DNN - Trénovací data")
  237.     plot_colored_squares_with_trend(X_train, y_train, dnn_train_pred, dnn_train_probs, 'DNN - Trénovací data')
  238.    
  239.     print("2. DNN - Testovací data")
  240.     plot_colored_squares_with_trend(X_test, y_test, dnn_test_pred, dnn_test_probs, 'DNN - Testovací data')
  241.    
  242.     print("3. CNN - Trénovací data")
  243.     plot_colored_squares_with_trend(X_train, y_train, cnn_train_pred, cnn_train_probs, 'CNN - Trénovací data')
  244.    
  245.     print("4. CNN - Testovací data")
  246.     plot_colored_squares_with_trend(X_test, y_test, cnn_test_pred, cnn_test_probs, 'CNN - Testovací data')
  247.    
  248.     print("5-6. Průběh accuracy a loss pro oba modely")
  249.     plot_training_history(dnn_history, cnn_history)
  250.  
  251.     # Výpis shrnutí výsledků
  252.     print("\nShrnutí výsledků:")
  253.     print("DNN model:")
  254.     print(f"Trénovací přesnost: {dnn_history.history['accuracy'][-1]:.4f}")
  255.     print(f"Validační přesnost: {dnn_history.history['val_accuracy'][-1]:.4f}")
  256.     print(f"Testovací přesnost: {np.mean(dnn_test_pred == y_test):.4f}")
  257.     print("\nCNN model:")
  258.     print(f"Trénovací přesnost: {cnn_history.history['accuracy'][-1]:.4f}")
  259.     print(f"Validační přesnost: {cnn_history.history['val_accuracy'][-1]:.4f}")
  260.     print(f"Testovací přesnost: {np.mean(cnn_test_pred == y_test):.4f}")
  261.  
  262.     # Výpis confusion matrix a klasifikační zprávy pro oba modely
  263.     print("\nDNN Confusion Matrix - Testovací data:")
  264.     print(confusion_matrix(y_test, dnn_test_pred))
  265.     print("\nDNN Classification Report - Testovací data:")
  266.     print(classification_report(y_test, dnn_test_pred))
  267.  
  268.     print("\nCNN Confusion Matrix - Testovací data:")
  269.     print(confusion_matrix(y_test, cnn_test_pred))
  270.     print("\nCNN Classification Report - Testovací data:")
  271.     print(classification_report(y_test, cnn_test_pred))
  272.  
  273. # Zde by následovalo volání funkce train_and_evaluate_models s vašimi trénovacími a testovacími daty
  274. # train_and_evaluate_models(X_train, y_train, X_test, y_test)
  275.  
  276.  
  277. np.random.seed(42)
  278. X_train = np.random.randint(1000, 2500, size=(1000, 2))
  279. y_train = (X_train[:, 0] > X_train[:, 1]).astype(int)
  280. X_test = np.random.randint(1000, 2500, size=(200, 2))
  281. y_test = (X_test[:, 0] > X_test[:, 1]).astype(int)
  282.  
  283. # Volání hlavní funkce s ukázkovými daty
  284. train_and_evaluate_models(X_train, y_train, X_test, y_test)
  285.  
  286. print("\nZávěr:")
  287. print("Tento skript demonstruje kompletní proces trénování a vyhodnocení modelů pro predikci šachových výsledků.")
  288. print("Pro spuštění analýzy je třeba poskytnout vlastní trénovací a testovací data.")
  289. print("Skript vykreslí 6 grafů: 4 grafy s barevnými čtverečky pro DNN a CNN na trénovacích a testovacích datech,")
  290. print("plus 2 grafy zobrazující průběh accuracy a loss během trénování.")
  291. print("Všechny čtyři kategorie predikcí (TP, TN, FP, FN) jsou nyní zobrazeny v grafech.")
  292. print("Křivka trendu je monotónní klesající a hladká, a správně zachází s duplicitními hodnotami rozdílu ELO.")
  293. print("Dále poskytne shrnutí výsledků, confusion matrix a klasifikační zprávu pro oba modely.")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement