Advertisement
STANAANDREY

nn crossover

Jan 2nd, 2025
55
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.21 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import copy
  4. import random
  5.  
  6. def crossover_and_mutate(net1, net2, mutation_rate=0.01, mutation_scale=0.1):
  7.     """
  8.    Perform crossover and mutation on two neural networks.
  9.  
  10.    Args:
  11.        net1 (nn.Module): The first parent neural network.
  12.        net2 (nn.Module): The second parent neural network.
  13.        mutation_rate (float): Probability of mutating a weight.
  14.        mutation_scale (float): Scale of the mutation (standard deviation of Gaussian noise).
  15.  
  16.    Returns:
  17.        nn.Module: A new neural network with combined and mutated weights.
  18.    """
  19.     # Ensure the two networks have the same architecture
  20.     assert type(net1) == type(net2), "Networks must have the same architecture"
  21.  
  22.     # Create a new network with the same architecture as the parents
  23.     child_net = copy.deepcopy(net1)
  24.  
  25.     # Perform crossover and mutation
  26.     with torch.no_grad():
  27.         for (child_param, param1, param2) in zip(child_net.parameters(), net1.parameters(), net2.parameters()):
  28.             # Crossover: randomly choose weights from either parent
  29.             mask = torch.rand_like(param1) > 0.5
  30.             child_param.data.copy_(torch.where(mask, param1.data, param2.data))
  31.  
  32.             # Mutation: add random noise to some weights
  33.             mutation_mask = torch.rand_like(child_param) < mutation_rate
  34.             noise = torch.randn_like(child_param) * mutation_scale
  35.             child_param.data.add_(mutation_mask * noise)
  36.  
  37.     return child_net
  38.  
  39. # Example usage
  40. if __name__ == "__main__":
  41.     # Define a simple neural network architecture
  42.     class SimpleNet(nn.Module):
  43.         def __init__(self):
  44.             super(SimpleNet, self).__init__()
  45.             self.fc1 = nn.Linear(10, 20)
  46.             self.fc2 = nn.Linear(20, 10)
  47.  
  48.         def forward(self, x):
  49.             x = torch.relu(self.fc1(x))
  50.             x = self.fc2(x)
  51.             return x
  52.  
  53.     # Create two parent networks
  54.     parent1 = SimpleNet()
  55.     parent2 = SimpleNet()
  56.  
  57.     # Perform crossover and mutation
  58.     child = crossover_and_mutate(parent1, parent2, mutation_rate=0.05, mutation_scale=0.2)
  59.  
  60.     # Print the child network's parameters
  61.     for param in child.parameters():
  62.         print(param)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement