import time

import numpy as np
import torch
from tensorboardX import SummaryWriter  # `logdir=` keyword matches tensorboardX, not torch.utils.tensorboard


def train(model, train_loader, test_loader, optimizer, loss_fn):
    # Relies on module-level settings defined elsewhere in the script:
    # resume, the saved_* checkpoint fields, device, init_lr, steps, config,
    # ablation, start_time, last_data_path, lr_schedule_cosdecay() and test().
    losses = []
    start_step = 0
    max_ssim = 0
    max_psnr = 0
    ssims = []
    psnrs = []
    if resume:
        # Restore training state from a previously loaded checkpoint.
        losses = saved_losses
        start_step = saved_step
        max_ssim = saved_max_ssim
        max_psnr = saved_max_psnr
        ssims = saved_ssims
        psnrs = saved_psnrs
        optimizer.load_state_dict(saved_optimizer)
        model.load_state_dict(saved_model)
        print(f'model loaded from {last_data_path}')

    for step in range(start_step + 1, steps + 1):
        model.train()

        # Cosine-decay learning-rate schedule, recomputed every step.
        lr = lr_schedule_cosdecay(step, steps, init_lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Draw one batch; a fresh iterator is created each step, so this
        # assumes train_loader shuffles its samples.
        x, y = next(iter(train_loader))
        x = x.to(device)
        y = y.to(device)

        out = model(x)
        # Total loss is currently just the L1 term.
        l1_loss = loss_fn[0](out, y)
        loss = l1_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.item())
        print(f'loss: {loss.item():.5f}, L1_loss: {l1_loss.item():.5f} | step: {step}/{steps} | lr: {lr:.7f} | time_used: {(time.time() - start_time) / 60:.1f}', end='', flush=True)

        # Log scalars to TensorBoard.
        with SummaryWriter(logdir=f'./data/{ablation}-logs', comment=f'./data/{ablation}-logs') as writer:
            writer.add_scalar('runs-loss' + ablation, loss, step)
            writer.add_scalar('runs-loss_l1' + ablation, l1_loss, step)

        # Periodic evaluation and checkpointing.
        if step % config['eval_step'] == 0:
            epoch = step // config['eval_step']
            save_model_dir = f'./data/trained_models_{ablation}/{epoch}.ok'
            best_model_dir = f'./data/trained_models_{ablation}/trained_model.best'

            with torch.no_grad():
                ssim_eval, psnr_eval = test(model, test_loader)

            log = f'\nstep :{step} | epoch: {epoch} | ssim:{ssim_eval:.4f}| psnr:{psnr_eval:.4f}'
            model_name = config['model_name']
            print(log)
            with open(f'./data/{ablation}-logs/{ablation + "_" + model_name}.txt', 'a') as f:
                f.write(log + '\n')

            ssims.append(ssim_eval)
            psnrs.append(psnr_eval)

            if psnr_eval > max_psnr:
                # New best PSNR: update the running bests and save a "best" checkpoint.
                max_ssim = max(max_ssim, ssim_eval)
                max_psnr = max(max_psnr, psnr_eval)
                print(f'\n model saved at step :{step}| epoch: {epoch} | max_psnr:{max_psnr:.4f}| max_ssim:{max_ssim:.4f}')
                torch.save({
                    'epoch': epoch,
                    'step': step,
                    'max_psnr': max_psnr,
                    'max_ssim': max_ssim,
                    'ssims': ssims,
                    'psnrs': psnrs,
                    'losses': losses,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, best_model_dir)

            # Always keep a per-epoch checkpoint as well.
            torch.save({
                'epoch': epoch,
                'step': step,
                'max_psnr': max_psnr,
                'max_ssim': max_ssim,
                'ssims': ssims,
                'psnrs': psnrs,
                'losses': losses,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, save_model_dir)

            # Persist the metric curves at every evaluation point.
            np.save(f'./data/{ablation}_numpy_files/{ablation + model_name}_{steps}_losses.npy', losses)
            np.save(f'./data/{ablation}_numpy_files/{ablation + model_name}_{steps}_ssims.npy', ssims)
            np.save(f'./data/{ablation}_numpy_files/{ablation + model_name}_{steps}_psnrs.npy', psnrs)
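

# The helpers `lr_schedule_cosdecay` and `test` are called above but not included
# in this paste. Minimal sketches follow; both are assumptions about what the
# originals likely do, not the original implementations.

# Cosine decay from init_lr down to 0 over the full run (no warmup or lr floor,
# which the original schedule may have):
import math

def lr_schedule_cosdecay(step, total_steps, init_lr):
    # Half-cosine curve: lr = init_lr at step 0, 0 at step == total_steps.
    return init_lr * 0.5 * (1 + math.cos(math.pi * step / total_steps))


# Rough evaluation loop returning mean SSIM and PSNR over the test set,
# assuming the pytorch_msssim package for SSIM and images scaled to [0, 1];
# the original metric code may differ. `device` is the same module-level
# global used by train().
from pytorch_msssim import ssim as ssim_fn

def test(model, test_loader):
    model.eval()
    ssims, psnrs = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).clamp(0, 1)
            ssims.append(ssim_fn(pred, y, data_range=1.0).item())
            mse = torch.mean((pred - y) ** 2)
            psnrs.append((10 * torch.log10(1.0 / mse)).item())
    return np.mean(ssims), np.mean(psnrs)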