Advertisement
alkkofficial

Untitled

Jun 25th, 2023
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.84 KB | None | 0 0
  1. def cycle(self, X_train, y_train, X_val, y_val, best_score, l1_lambda=0.001, l2_lambda=0.001):
  2. model = Agent().to(d)
  3. X_train, y_train, X_val, y_val = X_train.to(
  4. d), y_train.to(d), X_val.to(d), y_val.to(d)
  5. # Weight initialization
  6. try:
  7. weights_path = "./zlv7_full.pt"
  8. state_dict = torch.load(weights_path, map_location=d)
  9. model.load_state_dict(state_dict)
  10. except FileNotFoundError:
  11. for m in model.modules():
  12. if isinstance(m, nn.Linear):
  13. nn.init.xavier_uniform_(m.weight)
  14. if m.bias is not None:
  15. nn.init.constant_(m.bias, 0)
  16.  
  17. # loss function and optimizer
  18. loss_fn = nn.MSELoss() # mean square error
  19. # loss_fn2 = nn.HuberLoss()
  20. # loss_fn3 =
  21. # Set weight_decay to 0 for L2 regularization
  22. optimizer = optim.AdamW(
  23. model.parameters(), lr=1e-5, weight_decay=0.003)
  24. scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  25. optimizer, factor=0.98, patience=3, verbose=True
  26. )
  27. n_epochs = 300
  28. batch_size = 8192 # size of each batch
  29. batch_start = torch.arange(0, len(X_train), batch_size)
  30.  
  31. # Hold the best model
  32. best_mse = np.inf # initialise value as infinite
  33. best_weights = None
  34. history = []
  35. accumulation_steps = 2 # accumulate gradients over 2 batches
  36. for _ in tqdm.tqdm(range(n_epochs), desc="Epochs"):
  37. model.train()
  38. epoch_loss = 0.0
  39. for i, batch_idx in enumerate(batch_start):
  40. batch_X, batch_y = (
  41. X_train[batch_idx: batch_idx + batch_size],
  42. y_train[batch_idx: batch_idx + batch_size],
  43. )
  44. batch_X, batch_y = batch_X.to(dtype=torch.float32), batch_y.to(dtype=torch.float32)
  45. optimizer.zero_grad()
  46. y_pred = model.forward(batch_X).to(d)
  47. loss = loss_fn(y_pred, batch_y.view(-1, 1)).to(d)
  48. # L1 regularization
  49. l1_reg = torch.tensor(0.).to(d)
  50. for name, param in model.named_parameters():
  51. if 'weight' in name:
  52. l1_reg += torch.norm(param, 1)
  53. loss += l1_lambda * l1_reg
  54.  
  55. # L2 regularization
  56. l2_reg = torch.tensor(0.).to(d)
  57. for name, param in model.named_parameters():
  58. if 'weight' in name:
  59. l2_reg += torch.norm(param, 2)
  60. loss += l2_lambda * l2_reg
  61.  
  62. if d == torch.device("cuda"):
  63. scaler.scale(loss).backward() # NEED GPU
  64.  
  65. # accumulate gradients over several batches
  66. if (i + 1) % accumulation_steps == 0 or (i + 1) == len(batch_start):
  67. scaler.step(optimizer) # NEED GPU
  68. scaler.update() # NEED GPU
  69. model.zero_grad()
  70. y_pred = model(batch_X).to(d)
  71. loss = loss_fn(y_pred, batch_y.view(-1, 1)).to(d)
  72. # L1 regularization
  73. l1_reg = torch.tensor(0.).to(d)
  74. for name, param in model.named_parameters():
  75. if 'weight' in name:
  76. l1_reg += torch.norm(param, 1)
  77. loss += l1_lambda * l1_reg
  78.  
  79. # L2 regularization
  80. l2_reg = torch.tensor(0.).to(d)
  81. for name, param in model.named_parameters():
  82. if 'weight' in name:
  83. l2_reg += torch.norm(param, 2)
  84. loss += l2_lambda * l2_reg
  85.  
  86. loss.backward()
  87. optimizer.step()
  88. epoch_loss += loss.item() * batch_X.shape[0]
  89. epoch_loss /= len(X_train)
  90. scheduler.step(epoch_loss)
  91. history.append(epoch_loss)
  92. if epoch_loss < best_mse:
  93. best_mse = epoch_loss
  94.  
  95. print("MSE: %.2f" % best_mse)
  96. print("RMSE: %.2f" % np.sqrt(best_mse))
  97. plt.plot(history)
  98. plt.title("Epoch loss for ZL")
  99. plt.xlabel("Number of Epochs")
  100. plt.ylabel("Epoch Loss")
  101. plt.draw()
  102. plt.savefig("ai-eval-losses.jpg")
  103. best_weights = copy.deepcopy(model.state_dict())
  104. torch.save(best_weights, "zlv7_full.pt")
  105. if best_score > epoch_loss:
  106. best_weights = copy.deepcopy(model.state_dict())
  107. torch.save(best_weights, "zlv7_full.pt")
  108. if d == torch.device("cuda"):
  109. torch.cuda.empty_cache()
  110. del X_train
  111. del X_val
  112. del y_train
  113. del y_val
  114. gc.enable()
  115. gc.collect()
  116. gc.disable()
  117. return epoch_loss
  118.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement