Advertisement
mayankjoin3

gan_2_no_gpu

Nov 14th, 2024
47
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.08 KB | None | 0 0
  1. from datetime import datetime
  2.  
  3. start_time = datetime.now()
  4.  
  5.  
  6. """# GAN"""
  7.  
  8. import os
  9. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Disable GPU
  10.  
  11. import numpy as np
  12. import pandas as pd
  13. import os
  14. from tensorflow.keras.models import Sequential, Model
  15. from tensorflow.keras.layers import Dense, LeakyReLU, Input, Embedding, Concatenate, Flatten
  16. from tensorflow.keras.optimizers import RMSprop
  17.  
  18.  
  19. # Function to create a basic GAN generator model
  20. def create_standard_gan_generator(input_dim, output_dim):
  21.     model = Sequential()
  22.     model.add(Dense(256, input_dim=input_dim))
  23.     model.add(LeakyReLU(alpha=0.2))
  24.     model.add(Dense(512))
  25.     model.add(LeakyReLU(alpha=0.2))
  26.     model.add(Dense(output_dim, activation='tanh'))
  27.     return model
  28.  
  29. def create_cgan_generator(latent_dim, output_dim, num_classes):
  30.     # Define label input and embedding layer for labels
  31.     label = Input(shape=(1,), name='label_input')
  32.     label_embedding = Embedding(num_classes, latent_dim, input_length=1)(label)  # Embed to match `latent_dim`
  33.     label_embedding = Flatten()(label_embedding)  # Flatten embedding to concatenate
  34.  
  35.     # Define noise input
  36.     noise = Input(shape=(latent_dim,), name='noise_input')
  37.  
  38.     # Concatenate noise and label embedding
  39.     combined_input = Concatenate()([noise, label_embedding])  # This shape is (latent_dim + latent_dim)
  40.  
  41.     # Build generator model with combined input
  42.     x = Dense(256)(combined_input)
  43.     x = LeakyReLU(alpha=0.2)(x)
  44.     x = Dense(512)(x)
  45.     x = LeakyReLU(alpha=0.2)(x)
  46.     generator_output = Dense(output_dim, activation='tanh')(x)
  47.  
  48.     # Create the model
  49.     model = Model([noise, label], generator_output)
  50.     return model
  51.  
  52. # Function to create a Wasserstein GAN (WGAN) generator model
  53. def create_wgan_generator(input_dim, output_dim):
  54.     model = Sequential()
  55.     model.add(Dense(256, input_dim=input_dim))
  56.     model.add(LeakyReLU(alpha=0.2))
  57.     model.add(Dense(512))
  58.     model.add(LeakyReLU(alpha=0.2))
  59.     model.add(Dense(output_dim, activation='tanh'))
  60.     return model
  61.  
  62. def generate_samples(generator, n_samples, latent_dim, gan_type, num_classes=None, cls=None):
  63.     noise = np.random.normal(0, 1, (n_samples, latent_dim))
  64.  
  65.     if gan_type == "cGAN" and cls is not None:
  66.         labels = np.full((n_samples, 1), cls)
  67.         generated_samples = generator.predict([noise, labels])
  68.     else:
  69.         generated_samples = generator.predict(noise)
  70.  
  71.     return generated_samples
  72.  
  73. # def generate_data_with_gans(data, output_dir, base_name, latent_dim=100, samples_per_class=1000):
  74. def generate_data_with_gans(data, output_dir, base_name, latent_dim=100, samples_per_class=1000):
  75.     os.makedirs(output_dir, exist_ok=True)
  76.     classes = np.unique(data['label'])
  77.     num_features = data.shape[1] - 1
  78.     num_classes = len(classes)
  79.  
  80.     for gan_type in ["StandardGAN", "cGAN", "WGAN"]:
  81.         all_generated_data = []
  82.  
  83.         for cls in classes:
  84.             if gan_type == "StandardGAN":
  85.                 generator = create_standard_gan_generator(latent_dim, num_features)
  86.                 generated_samples = generate_samples(generator, samples_per_class, latent_dim, gan_type)
  87.  
  88.             elif gan_type == "cGAN":
  89.                 generator = create_cgan_generator(latent_dim, num_features, num_classes)
  90.                 generated_samples = generate_samples(generator, samples_per_class, latent_dim, gan_type, num_classes, cls)
  91.  
  92.             elif gan_type == "WGAN":
  93.                 generator = create_wgan_generator(latent_dim, num_features)
  94.                 generated_samples = generate_samples(generator, samples_per_class, latent_dim, gan_type)
  95.  
  96.             generated_label = np.full((samples_per_class, 1), cls)
  97.             generated_data = np.hstack((generated_samples, generated_label))
  98.             all_generated_data.append(generated_data)
  99.  
  100.         all_generated_data = np.vstack(all_generated_data)
  101.         df_generated = pd.DataFrame(all_generated_data, columns=[*data.columns[:-1], 'label'])
  102.  
  103.         filename = os.path.join(output_dir, f"{base_name}_{gan_type}.csv")
  104.         df_generated.to_csv(filename, index=False)
  105.         print(f"Data for {gan_type} generated and saved successfully as:", filename)
  106.  
  107. import os
  108. import pandas as pd
  109.  
  110. # Define the input and output directories
  111. input_dir = './Data'
  112. output_dir = './GAN'
  113. latent_dim = 134
  114.  
  115. # Loop through each file in the input directory
  116. for filename in os.listdir(input_dir):
  117.     if filename.endswith('_selected_data.csv'):  # Process only files with the specific suffix
  118.         file_path = os.path.join(input_dir, filename)
  119.  
  120.         # Extract the base file name (remove "_selected_data.csv")
  121.         base_name = filename.replace('_selected_data.csv', '')
  122.  
  123.         # Read the data from the CSV file
  124.         data = pd.read_csv(file_path)
  125.  
  126.         # Generate data with GANs for each file
  127.         print(f"Processing file: {base_name}")
  128.         generate_data_with_gans(data, output_dir, base_name, latent_dim=latent_dim)
  129.         print(f"Finished processing file: {base_name}\n")
  130.        
  131.        
  132. end_time = datetime.now()
  133. print('Duration: {}'.format(end_time - start_time))
  134.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement