Advertisement
ridwan100

demo.py

Jan 8th, 2024
809
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.83 KB | None | 0 0
  1. # -*- encoding: utf-8 -*-
  2. import argparse
  3. import copy
  4. import json
  5. from pathlib import Path
  6. import warnings
  7.  
  8. import cv2
  9. import numpy as np
  10. import torch
  11. import torch.nn.functional as F
  12. from networks.paperedge import GlobalWarper, LocalWarper, WarperUtil
  13. # Suppress the torch.meshgrid warning
  14. warnings.filterwarnings("ignore", category=UserWarning, module="torch.functional")
  15. cv2.setNumThreads(0)
  16. cv2.ocl.setUseOpenCL(False)
  17.  
  18.  
  19. def load_img(img_path):
  20.     im = cv2.imread(img_path).astype(np.float32) / 255.0
  21.     im = im[:, :, (2, 1, 0)]
  22.     im = cv2.resize(im, (256, 256), interpolation=cv2.INTER_AREA)
  23.     im = torch.from_numpy(np.transpose(im, (2, 0, 1)))
  24.     return im
  25.  
  26.  
  27. if __name__ == '__main__':
  28.     parser = argparse.ArgumentParser()
  29.     parser.add_argument('--Enet_ckpt', type=str,
  30.                         default='models/G_w_checkpoint_13820.pt')
  31.     parser.add_argument('--Tnet_ckpt', type=str,
  32.                         default='models/L_w_checkpoint_27640.pt')
  33.     parser.add_argument('--img_path', type=str, default='images/3.jpg')
  34.     parser.add_argument('--out_dir', type=str, default='output')
  35.     args = parser.parse_args()
  36.  
  37.     img_path = args.img_path
  38.     dst_dir = args.out_dir
  39.     Path(dst_dir).mkdir(parents=True, exist_ok=True)
  40.  
  41.     netG = GlobalWarper().to('cuda')
  42.     netG.load_state_dict(torch.load(args.Enet_ckpt)['G'])
  43.     netG.eval()
  44.  
  45.     netL = LocalWarper().to('cuda')
  46.     netL.load_state_dict(torch.load(args.Tnet_ckpt)['L'])
  47.     netL.eval()
  48.  
  49.     warpUtil = WarperUtil(64).to('cuda')
  50.  
  51.     gs_d, ls_d = None, None
  52.     with torch.no_grad():
  53.         x = load_img(img_path)
  54.         x = x.unsqueeze(0)
  55.         x = x.to('cuda')
  56.         d = netG(x)  # d_E the edged-based deformation field
  57.         d = warpUtil.global_post_warp(d, 64)
  58.         gs_d = copy.deepcopy(d)
  59.  
  60.         d = F.interpolate(d, size=256, mode='bilinear', align_corners=True)
  61.         y0 = F.grid_sample(x, d.permute(0, 2, 3, 1), align_corners=True)
  62.         ls_d = netL(y0)
  63.         ls_d = F.interpolate(ls_d, size=256, mode='bilinear', align_corners=True)
  64.         ls_d = ls_d.clamp(-1.0, 1.0)
  65.  
  66.     im = cv2.imread(img_path).astype(np.float32) / 255.0
  67.     im = torch.from_numpy(np.transpose(im, (2, 0, 1)))
  68.     im = im.to('cuda').unsqueeze(0)
  69.  
  70.     gs_d = F.interpolate(gs_d, (im.size(2), im.size(3)), mode='bilinear', align_corners=True)
  71.     gs_y = F.grid_sample(im, gs_d.permute(0, 2, 3, 1), align_corners=True).detach()
  72.     tmp_y = gs_y.squeeze().permute(1, 2, 0).cpu().numpy()
  73.     cv2.imwrite(f'{dst_dir}/result_gs.png', tmp_y * 255.)
  74.  
  75.     ls_d = F.interpolate(ls_d, (im.size(2), im.size(3)), mode='bilinear', align_corners=True)
  76.     ls_y = F.grid_sample(gs_y, ls_d.permute(0, 2, 3, 1), align_corners=True).detach()
  77.     ls_y = ls_y.squeeze().permute(1, 2, 0).cpu().numpy()
  78.     cv2.imwrite(f'{dst_dir}/result_ls.png', ls_y * 255.)
  79.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement