Advertisement
UF6

Neural Network Differential Training Data

UF6
Oct 20th, 2024
55
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.77 KB | Source Code | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4.  
  5. # Define the neural network to approximate the solution
  6. class Net(nn.Module):
  7.     def __init__(self):
  8.         super(Net, self).__init__()
  9.         # A simple neural network with one hidden layer
  10.         self.hidden = nn.Linear(1, 10)
  11.         self.output = nn.Linear(10, 1)
  12.    
  13.     def forward(self, x):
  14.         # Apply a non-linearity (ReLU in this case)
  15.         x = torch.relu(self.hidden(x))
  16.         # Output the approximation for y(x)
  17.         x = self.output(x)
  18.         return x
  19.  
  20. # Define the differential equation dy/dx = -y
  21. def ode_loss(x, model):
  22.     y = model(x)  # Network prediction for y
  23.     dy_dx = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(x), create_graph=True)[0]
  24.     # Loss is based on the equation: dy/dx + y = 0
  25.     loss = torch.mean((dy_dx + y) ** 2)
  26.     return loss
  27.  
  28. # Generate training data
  29. x_train = torch.linspace(0, 2, 100).view(-1, 1)  # x from 0 to 2
  30. x_train.requires_grad = True  # We need gradients for x
  31.  
  32. # Initialize the neural network and optimizer
  33. model = Net()
  34. optimizer = optim.Adam(model.parameters(), lr=0.01)
  35.  
  36. # Training loop
  37. epochs = 2000
  38. for epoch in range(epochs):
  39.     optimizer.zero_grad()  # Zero the gradients
  40.     loss = ode_loss(x_train, model)  # Compute the loss
  41.     loss.backward()  # Backpropagate
  42.     optimizer.step()  # Update the weights
  43.    
  44.     if epoch % 100 == 0:
  45.         print(f'Epoch {epoch}, Loss: {loss.item()}')
  46.  
  47. # Test the model
  48. with torch.no_grad():
  49.     x_test = torch.linspace(0, 2, 100).view(-1, 1)
  50.     y_pred = model(x_test)
  51.    
  52. # Plotting the result
  53. import matplotlib.pyplot as plt
  54. plt.plot(x_test, y_pred, label="NN solution")
  55. plt.plot(x_test, torch.exp(-x_test), label="Exact solution")
  56. plt.legend()
  57. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement