Advertisement
alkkofficial

Untitled

Jun 25th, 2023
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.91 KB | None | 0 0
  def cycle(self, X_train, y_train, X_val, y_val, best_score, l1_lambda=0.001, l2_lambda=0.001):
      """Train an ``Agent`` for one cycle and checkpoint it if it improved.

      The model starts from ``./zlv7_full.pt`` when that file exists (a
      ``load_state_dict`` mismatch still raises); otherwise every ``nn.Linear``
      gets Xavier-uniform weights and zero biases. It is then trained for
      ``n_epochs`` epochs with AdamW on an MSE objective plus explicit L1/L2
      weight penalties, stepping the optimizer once per ``accumulation_steps``
      mini-batches (gradients are averaged over each group).

      Args:
          X_train: training inputs; sliced along dim 0 into mini-batches.
          y_train: training targets; each batch is reshaped to ``(-1, 1)`` to
              match the model output (assumes one scalar target per sample —
              TODO confirm).
          X_val: validation inputs; only the first (up to) 5 rows are used, as
              an eval-mode inference smoke test after training.
          y_val: validation targets (unused).
          best_score: best epoch loss from earlier cycles; the checkpoint is
              overwritten only when this cycle's final epoch loss is lower.
          l1_lambda: coefficient of the L1 norm of every parameter whose name
              contains ``"weight"``.
          l2_lambda: coefficient of the (non-squared) L2 norm of the same
              parameters.

      Returns:
          float: mean per-sample data MSE (regularisation excluded) over the
          last epoch.

      Raises:
          ValueError: if ``X_train`` is empty.
      """
      if len(X_train) == 0:
          raise ValueError("X_train is empty; nothing to train on")

      # `d` is the module-level device; normalise it so "cuda:0" also counts.
      device = torch.device(d)
      use_cuda = device.type == "cuda"

      model = Agent().to(device)
      X_train = X_train.to(device)
      y_train = y_train.to(device)

      weights_path = "./zlv7_full.pt"
      try:
          model.load_state_dict(torch.load(weights_path, map_location=device))
      except FileNotFoundError:
          for m in model.modules():
              if isinstance(m, nn.Linear):
                  nn.init.xavier_uniform_(m.weight)
                  if m.bias is not None:
                      nn.init.constant_(m.bias, 0)

      loss_fn = nn.MSELoss()
      # Regularisation is added to the loss explicitly (weight_penalty below),
      # so AdamW's own decoupled weight decay is switched off.
      optimizer = optim.AdamW(model.parameters(), lr=5e-6, weight_decay=0)
      scheduler = optim.lr_scheduler.ReduceLROnPlateau(
          optimizer, factor=0.98, patience=3, verbose=True
      )
      # Loss scaling only on CUDA. When disabled, scale() returns the loss
      # unchanged, step() calls optimizer.step() and update() does nothing.
      scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

      n_epochs = 300
      batch_size = 8192
      accumulation_steps = 2  # one optimizer step per this many batches
      n_train = len(X_train)
      batch_starts = range(0, n_train, batch_size)
      n_batches = len(batch_starts)

      def weight_penalty():
          """L1 + L2 penalty over every parameter whose name contains 'weight'."""
          l1 = torch.zeros((), device=device)
          l2 = torch.zeros((), device=device)
          for name, param in model.named_parameters():
              if "weight" in name:
                  l1 = l1 + torch.norm(param, 1)
                  l2 = l2 + torch.norm(param, 2)
          return l1_lambda * l1 + l2_lambda * l2

      best_mse = np.inf
      history = []
      epoch_loss = 0.0
      for _ in tqdm.tqdm(range(n_epochs), desc="Epochs"):
          model.train()
          epoch_loss = 0.0
          optimizer.zero_grad()
          for i, start in enumerate(batch_starts):
              batch_X = X_train[start:start + batch_size]
              batch_y = y_train[start:start + batch_size].reshape(-1, 1)

              mse = loss_fn(model(batch_X), batch_y)
              loss = mse + weight_penalty()

              # Average gradients over this accumulation group; the last group
              # of an epoch may hold fewer than accumulation_steps batches.
              group_start = i - i % accumulation_steps
              group_size = min(accumulation_steps, n_batches - group_start)
              scaler.scale(loss / group_size).backward()

              if (i + 1) % accumulation_steps == 0 or i + 1 == n_batches:
                  scaler.step(optimizer)
                  scaler.update()
                  optimizer.zero_grad()

              # Weight by batch size so the epoch value is a per-sample mean.
              epoch_loss += mse.item() * batch_X.shape[0]
          epoch_loss /= n_train
          scheduler.step(epoch_loss)
          history.append(epoch_loss)
          best_mse = min(best_mse, epoch_loss)

      print("MSE: %.2f" % best_mse)
      print("RMSE: %.2f" % np.sqrt(best_mse))

      # A fresh figure per call keeps curves from earlier cycles off the plot.
      fig, ax = plt.subplots()
      ax.plot(history)
      ax.set_title("Epoch loss for ZL")
      ax.set_xlabel("Number of Epochs")
      ax.set_ylabel("Epoch Loss")
      fig.savefig("ai-eval-losses.jpg")
      plt.close(fig)

      model.eval()
      with torch.no_grad():
          # Inference smoke test on up to 5 validation samples; output unused.
          model(X_val[:5].to(device))

      # The checkpoint is also the starting point of the next cycle, so only
      # replace it when this cycle beat the best score seen so far.
      if epoch_loss < best_score:
          torch.save(model.state_dict(), weights_path)

      # Release the device copies and the model *before* emptying the CUDA
      # cache; otherwise the cached blocks are still referenced.
      del model, optimizer, scheduler, scaler, X_train, y_train, X_val, y_val
      gc.enable()
      gc.collect()
      # NOTE(review): leaves automatic GC disabled for the rest of the
      # process — presumably the caller relies on this; confirm.
      gc.disable()
      torch.cuda.empty_cache()
      return epoch_loss
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement