Advertisement
exotic666

Proj

Nov 29th, 2024
17
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.05 KB | Source Code | 0 0
  1. # Mount Google Drive (Optional, for saving model)
  2. from google.colab import drive
  3. drive.mount('/content/drive')
  4.  
  5. # Import Libraries
  6. import os
  7. import zipfile
  8. import cv2
  9. import numpy as np
  10. from tensorflow.keras.models import Sequential
  11. from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
  12. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  13. from tensorflow.keras.utils import to_categorical
  14. from sklearn.model_selection import train_test_split
  15. from sklearn.preprocessing import LabelEncoder
  16. from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
  17. import matplotlib.pyplot as plt
  18.  
  19. # Dataset Parameters
  20. IMG_HEIGHT, IMG_WIDTH = 100, 100  # Resize images to 100x100
  21. DATASET_ZIP = '/content/lfw-deepfunneled.zip'  # Path to dataset zip file
  22. EXTRACTED_FOLDER = '/content/lfw-deepfunneled'
  23.  
  24. # Step 1: Load and Preprocess Dataset
  25. def load_lfw_data(data_dir):
  26.     X, y = [], []
  27.     for person in os.listdir(data_dir):
  28.         person_dir = os.path.join(data_dir, person)
  29.         if os.path.isdir(person_dir):
  30.             for img_name in os.listdir(person_dir):
  31.                 img_path = os.path.join(person_dir, img_name)
  32.                 img = cv2.imread(img_path)
  33.                 if img is not None:
  34.                     img = cv2.resize(img, (IMG_HEIGHT, IMG_WIDTH))  # Resize image
  35.                     img = img / 255.0  # Normalize pixel values to [0, 1]
  36.                     X.append(img)
  37.                     y.append(person)
  38.     return np.array(X), np.array(y)
  39.  
  40. # Unzip Dataset if Necessary
  41. if not os.path.exists(EXTRACTED_FOLDER):
  42.     print("Extracting dataset...")
  43.     with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:
  44.         zip_ref.extractall('/content/')
  45. print("Dataset extracted!")
  46.  
  47. # Load Data
  48. print("Loading data...")
  49. X, y = load_lfw_data(EXTRACTED_FOLDER)
  50.  
  51. # Encode Labels
  52. print("Encoding labels...")
  53. le = LabelEncoder()
  54. y_encoded = le.fit_transform(y)  # Convert string labels to integers
  55. y_categorical = to_categorical(y_encoded)  # Convert to one-hot encoding
  56.  
  57. # Split Data into Train/Test Sets
  58. print("Splitting data...")
  59. X_train, X_test, y_train, y_test = train_test_split(X, y_categorical, test_size=0.2, random_state=42)
  60.  
  61. # Step 2: Data Augmentation
  62. print("Applying data augmentation...")
  63. datagen = ImageDataGenerator(
  64.     rotation_range=10,       # Random rotation
  65.     width_shift_range=0.1,   # Horizontal shift
  66.     height_shift_range=0.1,  # Vertical shift
  67.     horizontal_flip=True     # Random horizontal flips
  68. )
  69. datagen.fit(X_train)
  70.  
  71. # Step 3: Build the Model
  72. print("Building model...")
  73. model = Sequential([
  74.     Conv2D(32, (3, 3), activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
  75.     BatchNormalization(),
  76.     MaxPooling2D((2, 2)),
  77.     Dropout(0.25),
  78.  
  79.     Conv2D(64, (3, 3), activation='relu'),
  80.     BatchNormalization(),
  81.     MaxPooling2D((2, 2)),
  82.     Dropout(0.25),
  83.  
  84.     Conv2D(128, (3, 3), activation='relu'),
  85.     BatchNormalization(),
  86.     MaxPooling2D((2, 2)),
  87.  
  88.     Conv2D(256, (3, 3), activation='relu'),
  89.     BatchNormalization(),
  90.     MaxPooling2D((2, 2)),
  91.  
  92.     Flatten(),
  93.     Dense(256, activation='relu'),
  94.     Dropout(0.5),
  95.     Dense(y_categorical.shape[1], activation='softmax')  # Output layer
  96. ])
  97.  
  98. # Compile the Model
  99. model.compile(optimizer='adam',
  100.               loss='categorical_crossentropy',
  101.               metrics=['accuracy'])
  102.  
  103. # Model Summary
  104. model.summary()
  105.  
  106. # Step 4: Train the Model
  107. print("Training model...")
  108. history = model.fit(
  109.     datagen.flow(X_train, y_train, batch_size=32),
  110.     epochs=20,
  111.     validation_data=(X_test, y_test),
  112.     verbose=2
  113. )
  114.  
  115. # Step 5: Evaluate the Model
  116. print("Evaluating model...")
  117. test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=2)
  118. print(f"Test Loss: {test_loss}")
  119. print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
  120.  
  121. # Step 6: Analyze Predictions
  122. y_pred = model.predict(X_test)
  123. y_pred_classes = np.argmax(y_pred, axis=1)
  124. y_test_classes = np.argmax(y_test, axis=1)
  125.  
  126. # Classification Report
  127. print("Classification Report:")
  128. print(classification_report(y_test_classes, y_pred_classes, target_names=le.classes_))
  129.  
  130. # Confusion Matrix
  131. print("Confusion Matrix:")
  132. cm = confusion_matrix(y_test_classes, y_pred_classes)
  133. plt.figure(figsize=(10, 8))
  134. plt.title("Confusion Matrix")
  135. plt.imshow(cm, cmap="viridis")
  136. plt.colorbar()
  137. plt.show()
  138.  
  139. # Step 7: Save the Model
  140. model.save("lfw_funneled_5layer_model.h5")
  141. print("Model saved as 'lfw_funneled_5layer_model.h5'")
  142.  
  143. # Optional: Visualize Accuracy and Loss
  144. plt.figure(figsize=(12, 6))
  145.  
  146. # Plot Accuracy
  147. plt.subplot(1, 2, 1)
  148. plt.plot(history.history['accuracy'], label='Train Accuracy')
  149. plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
  150. plt.title('Accuracy Over Epochs')
  151. plt.xlabel('Epochs')
  152. plt.ylabel('Accuracy')
  153. plt.legend()
  154.  
  155. # Plot Loss
  156. plt.subplot(1, 2, 2)
  157. plt.plot(history.history['loss'], label='Train Loss')
  158. plt.plot(history.history['val_loss'], label='Validation Loss')
  159. plt.title('Loss Over Epochs')
  160. plt.xlabel('Epochs')
  161. plt.ylabel('Loss')
  162. plt.legend()
  163.  
  164. plt.show()
  165.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement