Advertisement
mayankjoin3

GAN oversampling

Oct 10th, 2024
46
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.79 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import pandas as pd
  5. from sklearn.preprocessing import LabelEncoder
  6. from sklearn.model_selection import train_test_split
  7. from sklearn.preprocessing import MinMaxScaler
  8.  
  9. # Generator Model
  10. class Generator(nn.Module):
  11.     def __init__(self, input_dim, output_dim):
  12.         super(Generator, self).__init__()
  13.         self.model = nn.Sequential(
  14.             nn.Linear(input_dim, 128),
  15.             nn.ReLU(),
  16.             nn.Linear(128, 256),
  17.             nn.ReLU(),
  18.             nn.Linear(256, output_dim),
  19.             nn.Tanh()  # Use Tanh because output values are in the range of [-1, 1]
  20.         )
  21.  
  22.     def forward(self, x):
  23.         return self.model(x)
  24.  
  25. # Discriminator Model
  26. class Discriminator(nn.Module):
  27.     def __init__(self, input_dim):
  28.         super(Discriminator, self).__init__()
  29.         self.model = nn.Sequential(
  30.             nn.Linear(input_dim, 256),
  31.             nn.LeakyReLU(0.2),
  32.             nn.Linear(256, 128),
  33.             nn.LeakyReLU(0.2),
  34.             nn.Linear(128, 1),
  35.             nn.Sigmoid()  # Outputs a probability
  36.         )
  37.  
  38.     def forward(self, x):
  39.         return self.model(x)
  40.  
  41. # GAN-based resampling method
  42. def gan_resample(X_train, y_train, num_epochs=5000, batch_size=64):
  43.     # Find majority and minority classes
  44.     majority_class = np.argmax(np.bincount(y_train))
  45.     minority_class = np.argmin(np.bincount(y_train))
  46.  
  47.     X_minority = X_train[y_train == minority_class]
  48.     X_majority = X_train[y_train == majority_class]
  49.  
  50.     # GAN dimensions
  51.     input_dim = X_minority.shape[1]
  52.     latent_dim = 32  # Size of the random noise vector
  53.     output_dim = input_dim
  54.  
  55.     # Initialize Generator and Discriminator
  56.     generator = Generator(latent_dim, output_dim)
  57.     discriminator = Discriminator(input_dim)
  58.  
  59.     # Loss and optimizers
  60.     criterion = nn.BCELoss()
  61.     g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
  62.     d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
  63.  
  64.     # Training loop for GAN
  65.     for epoch in range(num_epochs):
  66.         # Sample noise as generator input
  67.         z = torch.randn(batch_size, latent_dim)
  68.         real_samples = torch.Tensor(X_minority[np.random.randint(0, X_minority.shape[0], batch_size)])
  69.  
  70.         # Train Discriminator
  71.         d_optimizer.zero_grad()
  72.         real_labels = torch.ones(batch_size, 1)
  73.         fake_labels = torch.zeros(batch_size, 1)
  74.  
  75.         # Compute loss with real samples
  76.         real_loss = criterion(discriminator(real_samples), real_labels)
  77.         # Generate fake samples
  78.         fake_samples = generator(z).detach()
  79.         fake_loss = criterion(discriminator(fake_samples), fake_labels)
  80.  
  81.         # Total discriminator loss
  82.         d_loss = real_loss + fake_loss
  83.         d_loss.backward()
  84.         d_optimizer.step()
  85.  
  86.         # Train Generator
  87.         g_optimizer.zero_grad()
  88.         # Generate fake samples
  89.         z = torch.randn(batch_size, latent_dim)
  90.         generated_samples = generator(z)
  91.         g_loss = criterion(discriminator(generated_samples), real_labels)
  92.         g_loss.backward()
  93.         g_optimizer.step()
  94.  
  95.         if epoch % 1000 == 0:
  96.             print(f"Epoch {epoch}/{num_epochs}, d_loss: {d_loss.item()}, g_loss: {g_loss.item()}")
  97.  
  98.     # Generate new synthetic samples from the trained Generator
  99.     num_samples_needed = X_majority.shape[0] - X_minority.shape[0]
  100.     z = torch.randn(num_samples_needed, latent_dim)
  101.     synthetic_samples = generator(z).detach().numpy()
  102.  
  103.     # Concatenate original and synthetic samples
  104.     X_balanced = np.vstack([X_train, synthetic_samples])
  105.     y_balanced = np.hstack([y_train, np.full(num_samples_needed, minority_class)])
  106.  
  107.     return X_balanced, y_balanced
  108.  
  109. # Preprocessing function modified for GAN resampling
  110. def preprocess_data_with_gan(input_file, dataset_percent):
  111.     data = pd.read_csv(input_file)
  112.  
  113.     # Handle missing values (e.g., replacing with median)
  114.     data.fillna(data.median(), inplace=True)
  115.  
  116.     # Convert categorical data
  117.     for col in data.select_dtypes(include=['object']).columns:
  118.         data[col] = LabelEncoder().fit_transform(data[col])
  119.  
  120.     # Separate features and target
  121.     X = data.iloc[:, :-1].values
  122.     y = data.iloc[:, -1].values
  123.  
  124.     # MinMax scaling
  125.     scaler = MinMaxScaler()
  126.     X = scaler.fit_transform(X)
  127.  
  128.     # Split into train and test
  129.     X_train, _, y_train, _ = train_test_split(X, y, train_size=dataset_percent / 100, stratify=y)
  130.  
  131.     # Balance the dataset using the GAN-based resampling method
  132.     X_balanced, y_balanced = gan_resample(X_train, y_train)
  133.  
  134.     return X_balanced, y_balanced
  135.  
  136. # Usage example
  137. input_file = 'input.csv'  # Input dataset
  138. dataset_percent = 10  # Use 10% of the dataset
  139. X_balanced, y_balanced = preprocess_data_with_gan(input_file, dataset_percent)
  140.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement