Advertisement
iSach

Untitled

Nov 22nd, 2023
41
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.47 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torch.autograd as ag
  5.  
  6. from tqdm import tqdm
  7.  
  8. import numpy as np
  9. from matplotlib import pyplot as plt
  10.  
  11. # y'+y=x, y(0)=1
  12. # Analytical solution: y(x) = -1 + 2 * exp(-x) + x
  13.  
  14. x_min = 0
  15. x_max = 2
  16. xs = torch.linspace(x_min, x_max, 100_000).reshape(-1, 1)
  17.  
  18. model = nn.Sequential(
  19. nn.Linear(1, 256),
  20. nn.Sigmoid(),
  21. nn.Linear(256, 1),
  22. )
  23.  
  24. x_ds = torch.utils.data.TensorDataset(xs)
  25. x_dl = torch.utils.data.DataLoader(x_ds, batch_size=128, shuffle=True)
  26. zero = torch.tensor([0.0]).reshape(1, 1)
  27.  
  28. losses = []
  29. opt = optim.Adam(model.parameters(), lr=1e-3)
  30. for x in (bar := tqdm(x_dl)):
  31. x = x[0].requires_grad_(True)
  32. y = model(x)
  33. dy_dx = torch.autograd.grad(
  34. outputs=y,
  35. inputs=x,
  36. grad_outputs=torch.ones_like(x),
  37. create_graph=True,
  38. )[0]
  39.  
  40. y_0 = model(zero)
  41.  
  42. l = (dy_dx + y - x)**2 + (y_0 - 1.0)**2
  43. l = l.mean()
  44.  
  45. bar.set_description(f'Loss: {l.item():.2f}')
  46. losses.append(l.item())
  47.  
  48. opt.zero_grad()
  49. l.backward()
  50. opt.step()
  51.  
  52. xs = torch.linspace(x_min - 1.5, x_max + 1.5, 1000).reshape(-1, 1)
  53. ys = model(xs).detach().numpy()
  54. xs = xs.detach().numpy()
  55. ys_true = -1 + 2 * np.exp(-xs) + xs
  56.  
  57. plt.plot(xs, ys, label='NN')
  58. plt.plot(xs, ys_true, label='Analytical')
  59. plt.plot([x_min, x_min], [0, 1], color='black')
  60. plt.plot([x_max, x_max], [0, 1], color='black')
  61. plt.legend()
  62. plt.show()
  63. plt.plot(losses)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement