Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
### Replace the `train_model` function with this:
def train_model(network_size, learning_rate, epochs, B, dataset, params=None):
    """Train the coordinate network on ``dataset`` with Adam, optionally resuming.

    Every 5th epoch (0, 5, 10, ...) this records the PSNR of the most recent
    batch, shows the prediction for ``dataset[0]`` as an image, plots the PSNR
    curve so far, and calls ``save_network`` whenever that PSNR is the best
    seen in this call.

    Args:
        network_size: Positional arguments forwarded to ``make_network``.
        learning_rate: Learning rate passed to ``optimizers.adam``.
        epochs: Number of full passes over ``dataset``.
        B: Matrix forwarded to ``input_mapping`` (presumably the Fourier
            feature projection - confirm against ``input_mapping``).
        dataset: Indexable sequence of ``(X, Y)`` batches supporting ``len``.
            Per the original comments, ``X`` holds pixel coordinates and ``Y``
            the (r, g, b) values.
        params: Optional starting parameters, e.g. ``outputs['state']`` from a
            previous call. When ``None``, fresh parameters are created with
            ``init_fn`` and the module-level ``rand_key``. Note that the Adam
            state (moments and step count) always starts fresh.

    Returns:
        dict with ``'state'`` (final params), ``'train_psnrs'`` (recorded PSNRs),
        ``'test_psnrs'`` (always empty in this version) and ``'xs'`` (the epoch
        indices at which ``train_psnrs`` were recorded).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("dataset must contain at least one (X, Y) batch")

    print("Train Model")
    init_fn, apply_fn = make_network(*network_size)
    model_pred = jit(lambda params, x: apply_fn(params, input_mapping(x, B)))
    model_loss = jit(lambda params, x, y: .5 * np.mean((model_pred(params, x) - y) ** 2))
    # The loss is 0.5 * MSE, so 2 * loss recovers the MSE that PSNR is defined on.
    model_psnr = jit(lambda params, x, y: -10 * np.log10(2. * model_loss(params, x, y)))
    model_grad_loss = jit(lambda params, x, y: grad(model_loss)(params, x, y))

    opt_init, opt_update, get_params = optimizers.adam(learning_rate)
    opt_update = jit(opt_update)

    if params is None:
        # Probe the mapped feature width using one sample: the first row of
        # the first batch's X.
        init_data = dataset[0][0][0]
        _, params = init_fn(rand_key, (-1, input_mapping(init_data, B).shape[-1]))
    opt_state = opt_init(params)

    train_psnrs = []
    test_psnrs = []
    xs = []
    # Start below any real PSNR so the first evaluation always checkpoints.
    best_psnr = float("-inf")
    # Adam's bias correction (and any step-indexed schedule) needs a counter
    # that advances on every update; the epoch index would repeat the same
    # step for every batch in an epoch.
    step = 0
    img_shape = (1024 // DOWN_SCALE, 768 // DOWN_SCALE, 3)

    for i in range(epochs):
        print("Epoch", i)
        for k in tqdm(range(len(dataset)), desc=f'Epoch {i}/{epochs}'):
            X, Y = dataset[k]  # X = coordinates, Y = pixel value (r, g, b)
            opt_state = opt_update(step, model_grad_loss(get_params(opt_state), X, Y), opt_state)
            step += 1

        # NOTE(review): the pasted source lost its indentation; this
        # evaluation is assumed to run once per epoch, after the batch loop.
        if i % 5 == 0:
            current_params = get_params(opt_state)
            # PSNR of the last batch only, not an average over the epoch.
            current_psnr = model_psnr(current_params, X, Y)
            train_psnrs.append(current_psnr)
            xs.append(i)

            test_coords, _ = dataset[0]
            img = model_pred(current_params, test_coords).reshape(img_shape)
            plt.imshow(img)
            plt.title(f"Epoch {i}")
            plt.show()

            print("PSNR", current_psnr)
            plt.plot(xs, train_psnrs)
            plt.show()

            if current_psnr > best_psnr:
                best_psnr = current_psnr
                save_network(opt_state, i, current_psnr)

    return {
        'state': get_params(opt_state),
        'train_psnrs': train_psnrs,
        'test_psnrs': test_psnrs,
        'xs': xs,
    }
### In the "Train Networks" cell, call the function as follows:
# Continue training from the parameters returned by an earlier train_model
# call (`outputs` must already exist); `iters` is passed as the `epochs` count.
outputs = train_model(
    network_size,
    learning_rate,
    iters,
    B_dict['gauss_10.0'],
    ds,
    params=outputs['state'],
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement