Advertisement
glerium

train.py 0714

Jul 13th, 2024
110
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.71 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. from data import get_data
  4. from model import Model
  5.  
  6. model = Model().cuda()
  7. # model = torch.load('epoch0_1994_384224.07.model')
  8. loss_fn = nn.MSELoss()
  9. optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)
  10.  
  11. data_iter = get_data(range(1995, 2017))
  12. print('Start training...')
  13. for epoch in range(10):
  14.     print(f'\nEpoch {epoch}:')
  15.     overall_loss = 0
  16.     overall_cnt = 0
  17.     for year, catdata in enumerate(data_iter, start=1979):
  18.         print(f'Year {year}')
  19.         year_loss = 0
  20.         year_cnt = 0
  21.         for i in range(catdata.shape[0] - 168 * 2):
  22.             if i != 0 and i % 100 == 0:
  23.                 print(f"step={i}, loss={year_loss/year_cnt:.4f}")
  24.             optimizer.zero_grad()
  25.             input = catdata[None, i:i+168, :, :, :].cuda()
  26.             input[:, :, :, :, 1] = (input[:, :, :, :, 1] + 30) / 5300
  27.             input[:, :, :, :, 2] = input[:, :, :, :, 2] / 180 + 1
  28.             input[:, :, :, :, 3] = (input[:, :, :, :, 3] + 8000) / 14000
  29.             input[:, :, :, :, 4] = (input[:, :, :, :, 4] + 10) / 190
  30.             input[:, :, :, :, -3] = (input[:, :, :, :, -3] - 190) / 140
  31.             input[:, :, :, :, -2:] = (input[:, :, :, :, -2:] + 50) / 100
  32.             # print(input.min(), input.max())
  33.            
  34.             target = catdata[None, i+168:i+168*2, :, :, :].cuda()
  35.             target[:, :, :, :, 1] = (target[:, :, :, :, 1] + 30) / 5300
  36.             target[:, :, :, :, 2] = target[:, :, :, :, 2] / 180 + 1
  37.             target[:, :, :, :, 3] = (target[:, :, :, :, 3] + 8000) / 14000
  38.             target[:, :, :, :, 4] = (target[:, :, :, :, 4] + 10) / 190
  39.             target[:, :, :, :, -3] = (target[:, :, :, :, -3] - 190) / 140
  40.             target[:, :, :, :, -2:] = (target[:, :, :, :, -2:] + 50) / 100
  41.            
  42.             output = model(input, target)
  43.            
  44.             # output[:, :, :, :, -3] = output[:, :, :, :, -3] * 140 + 190
  45.             # output[:, :, :, :, -2:] = (output[:, :, :, :, -2:] * 100) - 50
  46.             # target[:, :, :, :, -3] = target[:, :, :, :, -3] * 140 + 190
  47.             # target[:, :, :, :, -2:] = (target[:, :, :, :, -2:] * 100) - 50
  48.            
  49.             output = output[:, :, :, :, -5:]
  50.             target = target[:, :, :, :, -5:]
  51.             loss = loss_fn(output, target)
  52.            
  53.             loss.backward()
  54.             optimizer.step()
  55.             year_loss += loss.item()
  56.             year_cnt += 1
  57.             overall_loss += loss.item()
  58.             overall_cnt += 1
  59.         print(f'Year {year}, Loss={year_loss / year_cnt}')
  60.         torch.save(model, f'epoch{epoch}_{year}_{year_loss / year_cnt:.2f}.model')
  61.     print(f'Epoch {epoch}, Loss={overall_loss / overall_cnt}')
  62.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement