Advertisement
sk82

Jax Continuing Training

Jan 12th, 2023
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.60 KB | None | 0 0
  1. ### Replace the `train_model` function with this:
  2.  
  def train_model(network_size, learning_rate, epochs, B, dataset, params=None):
      """Train the network on mini-batches, optionally resuming from ``params``.

      Args:
          network_size: Sequence unpacked into ``make_network(*network_size)``.
          learning_rate: Step size handed to ``optimizers.adam``.
          epochs: Number of full passes over ``dataset``.
          B: Mapping argument forwarded to ``input_mapping(x, B)``.
          dataset: Indexable collection supporting ``len()`` whose items are
              ``(X, Y)`` pairs (coordinates, pixel values). ``dataset[0][0][0]``
              is used only to infer the network's input width.
          params: Optional network parameters to continue training from. When
              ``None`` the network is freshly initialised with ``rand_key``.
              The Adam moment estimates are always re-initialised.

      Returns:
          dict with keys ``'state'`` (final parameters), ``'train_psnrs'``,
          ``'test_psnrs'`` (always empty here) and ``'xs'`` (epoch indices at
          which ``train_psnrs`` was sampled).
      """
      print("Train Model")
      init_fn, apply_fn = make_network(*network_size)

      model_pred = jit(lambda params, x: apply_fn(params, input_mapping(x, B)))
      model_loss = jit(lambda params, x, y: .5 * np.mean((model_pred(params, x) - y) ** 2))
      model_psnr = jit(lambda params, x, y: -10 * np.log10(2. * model_loss(params, x, y)))
      model_grad_loss = jit(grad(model_loss))

      opt_init, opt_update, get_params = optimizers.adam(learning_rate)
      opt_update = jit(opt_update)

      init_data = dataset[0][0][0]
      if params is None:
          _, params = init_fn(rand_key, (-1, input_mapping(init_data, B).shape[-1]))
      opt_state = opt_init(params)

      train_psnrs = []
      test_psnrs = []
      xs = []

      # Start below any attainable PSNR so the first evaluation is always
      # checkpointed, even when the PSNR is negative (loss > 0.5).
      max_psnr = float("-inf")

      # Adam's bias correction depends on the number of updates applied so far,
      # so the optimizer must receive a per-update counter, not the epoch index.
      step = 0

      for i in range(epochs):
          print("Epoch", i)
          for k in tqdm(range(len(dataset)), desc=f'Epoch {i}/{epochs}'):
              # X = coordinates
              # Y = pixel value (r,g,b)
              X, Y = dataset[k]

              grads = model_grad_loss(get_params(opt_state), X, Y)
              opt_state = opt_update(step, grads, opt_state)
              step += 1

          if i % 5 == 0:
              # PSNR is measured on the last batch seen in this epoch.
              train_psnrs.append(model_psnr(get_params(opt_state), X, Y))
              xs.append(i)

              # NOTE(review): the reshape assumes dataset[0] holds the coordinates
              # of one full (1024//DOWN_SCALE, 768//DOWN_SCALE) image - confirm.
              test_coords, _ = dataset[0]
              prediction = model_pred(get_params(opt_state), test_coords)
              img = prediction.reshape((1024//DOWN_SCALE, 768//DOWN_SCALE, 3))
              plt.imshow(img)
              plt.title(f"Epoch {i}")
              plt.show()

              current_psnr = train_psnrs[-1]
              print("PSNR", current_psnr)
              plt.plot(xs, train_psnrs)
              plt.show()

              # Checkpoint only when the sampled PSNR improves.
              if current_psnr > max_psnr:
                  max_psnr = current_psnr
                  save_network(opt_state, i, current_psnr)

      return {
          'state': get_params(opt_state),
          'train_psnrs': train_psnrs,
          'test_psnrs': test_psnrs,
          'xs': xs,
      }
  70.  
  71. ### In the "Train Networks" cell, call the model as follows (passing the previous `outputs['state']` as `params` resumes training from those parameters):
  72.  
  # Resume training: feed the parameters from the previous run back in.
  outputs = train_model(
      network_size,
      learning_rate,
      iters,
      B_dict['gauss_10.0'],
      ds,
      params=outputs['state'],
  )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement