import torch
import torch.nn as nn
import copy
import random


def crossover_and_mutate(net1, net2, mutation_rate=0.01, mutation_scale=0.1):
    """
    Perform crossover and mutation on two neural networks.

    Args:
        net1 (nn.Module): The first parent neural network.
        net2 (nn.Module): The second parent neural network.
        mutation_rate (float): Probability of mutating a weight.
        mutation_scale (float): Scale of the mutation (standard deviation of Gaussian noise).

    Returns:
        nn.Module: A new neural network with combined and mutated weights.
    """
    # Ensure the two networks have the same architecture
    assert type(net1) == type(net2), "Networks must have the same architecture"

    # Create a new network with the same architecture as the parents
    child_net = copy.deepcopy(net1)

    # Perform crossover and mutation
    with torch.no_grad():
        for child_param, param1, param2 in zip(child_net.parameters(), net1.parameters(), net2.parameters()):
            # Crossover: randomly choose each weight from either parent
            mask = torch.rand_like(param1) > 0.5
            child_param.data.copy_(torch.where(mask, param1.data, param2.data))

            # Mutation: add Gaussian noise to a randomly selected subset of weights
            mutation_mask = torch.rand_like(child_param) < mutation_rate
            noise = torch.randn_like(child_param) * mutation_scale
            child_param.data.add_(mutation_mask * noise)

    return child_net


# Example usage
if __name__ == "__main__":
    # Define a simple neural network architecture
    class SimpleNet(nn.Module):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.fc1 = nn.Linear(10, 20)
            self.fc2 = nn.Linear(20, 10)

        def forward(self, x):
            x = torch.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    # Create two parent networks
    parent1 = SimpleNet()
    parent2 = SimpleNet()

    # Perform crossover and mutation
    child = crossover_and_mutate(parent1, parent2, mutation_rate=0.05, mutation_scale=0.2)

    # Print the child network's parameters
    for param in child.parameters():
        print(param)
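
    # In practice, crossover_and_mutate would typically drive a small
    # evolutionary loop rather than a single breeding step. The sketch below
    # is illustrative only and is not part of the original function: the
    # fitness() placeholder simply rewards small output magnitudes on random
    # inputs, and a real use would substitute a task-specific evaluation.
    def fitness(net):
        # Hypothetical placeholder fitness: negative mean squared output on random data.
        with torch.no_grad():
            x = torch.randn(32, 10)
            return -net(x).pow(2).mean().item()

    def evolve(population, generations=5, mutation_rate=0.05, mutation_scale=0.2):
        for _ in range(generations):
            # Rank the population by fitness (higher is better) and keep the top half.
            ranked = sorted(population, key=fitness, reverse=True)
            survivors = ranked[: len(ranked) // 2]

            # Breed children from randomly paired survivors until the population is refilled.
            children = []
            while len(survivors) + len(children) < len(population):
                p1, p2 = random.sample(survivors, 2)
                children.append(crossover_and_mutate(p1, p2, mutation_rate, mutation_scale))

            population = survivors + children
        return max(population, key=fitness)

    # Evolve a small population of SimpleNet instances and report the best fitness found.
    best = evolve([SimpleNet() for _ in range(8)])
    print("Best fitness after evolution:", fitness(best))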