Advertisement
ridwan100

doc3d.py

Jan 8th, 2024
757
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.99 KB | None | 0 0
  1. import cv2
  2. import numpy as np
  3. import scipy.interpolate
  4. import os
  5. import csv
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from torch.utils.data import Dataset, DataLoader
  10. from scipy.io import loadmat
  11. import random
  12. import time
  13. import matplotlib.pyplot as plt
  14. from hdf5storage import loadmat
  15.  
  16. import kornia.augmentation as KA
  17. import kornia.geometry.transform as KG
  18. os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
  19. class Doc3D(Dataset):
  20.     def __init__(self, root_dir, is_train=True, num=0):
  21.         super(Doc3D, self).__init__()
  22.         # self.is_train = is_train
  23.         self.num = num
  24.         # load the list of doc3d images
  25.         if is_train:
  26.             with open('/home/ridwan/thesis/PaperEdge/doc3d_root/doc3d_train.txt', 'r') as fid:
  27.                 self.X = fid.read().splitlines()
  28.         else:
  29.             with open('/home/ridwan/thesis/PaperEdge/doc3d_root/doc3d_val.txt', 'r') as fid:
  30.                 self.X = fid.read().splitlines()
  31.         self.X = [root_dir + '/img/' + t + '.png' for t in self.X]
  32.        
  33.         # load the background images
  34.         with open('bgtex.txt', 'r') as fid:
  35.             self.bgtex = fid.read().splitlines()        
  36.  
  37.     def __len__(self):
  38.         if self.num:
  39.             return self.num
  40.         else:
  41.             return len(self.X)
  42.  
  43.     def __getitem__(self, index):
  44.         index = index % 10
  45.         t = self.X[index]
  46.         t1 = '/home/ridwan/thesis/doc3D-renderer/tex/000752a1-1637-4256-83fd-62ce48f5f88b(1).jpg'
  47.         t2= '/home/ridwan/thesis/PaperEdge/output/result_ls.png'
  48.         print(t)
  49.         im = cv2.imread(t).astype(np.float32) / 255.0
  50.         im = im[..., ::-1]
  51.         plt.figure(figsize=(12, 8))
  52.  
  53.         # Original Image
  54.         plt.subplot(2, 3, 1)
  55.         im = cv2.imread(t).astype(np.float32) / 255.0
  56.         im = im[..., ::-1]
  57.         plt.imshow(im)
  58.         plt.title('Curved Image')
  59.         plt.subplot(2, 3, 2)
  60.         im1 = cv2.imread(t1).astype(np.float32) / 255.0
  61.         im1 = im1[..., ::-1]
  62.         plt.imshow(im1)
  63.         plt.title('badlad image')
  64.  
  65.         # Image 2
  66.         plt.subplot(2, 3, 3)
  67.         im2 = cv2.imread(t2).astype(np.float32) / 255.0
  68.         im2 = im2[..., ::-1]
  69.         plt.imshow(im2)
  70.         plt.title('flattend image')
  71.  
  72.         t = t.replace('/img/', '/wc/')
  73.         t = t[:-3] + 'exr'
  74.         wc = cv2.imread(t, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED).astype(np.float32)
  75.  
  76.         t = t.replace('/wc/', '/bm/')
  77.         t = t[:-3] + 'mat'
  78.         bm = loadmat(t)['bm']
  79.         plt.subplot(2, 3, 4)
  80.  
  81.         x_coords = bm[:, :, 0]
  82.         y_coords = bm[:, :, 1]
  83.  
  84.         # Flatten the coordinates to 1D arrays
  85.         flat_x_coords = x_coords.flatten()
  86.         flat_y_coords = y_coords.flatten()
  87.  
  88.         # Remove NaN values (optional, depending on your requirements)
  89.         valid_indices = ~np.isnan(flat_x_coords)
  90.         flat_x_coords = flat_x_coords[valid_indices]
  91.         flat_y_coords = flat_y_coords[valid_indices]
  92.  
  93.         # Plot using scatter plot
  94.         plt.scatter(flat_x_coords, flat_y_coords, s=1)  # Adjust the marker size (s) as needed
  95.         plt.title('Scatter Plot of Coordinates')
  96.         plt.xlabel('X-coordinate')
  97.         plt.ylabel('Y-coordinate')
  98.         plt.show()        # print(bm)
  99.        
  100.         # random sample a background image
  101.         ind = random.randint(0, len(self.bgtex) - 1)
  102.         bg = cv2.imread(self.bgtex[ind]).astype(np.float32) / 255.0
  103.         bg = cv2.resize(bg, (200, 200))
  104.         bg = np.tile(bg, (3, 3, 1))
  105.  
  106.         im = torch.from_numpy(im.transpose((2, 0, 1)).copy())
  107.         wc = torch.from_numpy(wc.transpose((2, 0, 1)).copy())
  108.         bm = torch.from_numpy(bm.transpose((2, 0, 1)).copy())
  109.         bg = torch.from_numpy(bg.transpose((2, 0, 1)).copy())
  110.  
  111.         return im, wc, bm, bg
  112.  
  113.  
  114.  
  115. class Doc3DDataAug(nn.Module):
  116.     def __init__(self):
  117.         super(Doc3DDataAug, self).__init__()
  118.         self.cj = KA.ColorJitter(0.1, 0.1, 0.1, 0.1)
  119.    
  120.     def forward(self, img, wc, bm, bg):
  121.         # tight crop
  122.         mask = (wc[:, 0] != 0) & (wc[:, 1] != 0) & (wc[:, 2] != 0)
  123.        
  124.         B = img.size(0)
  125.         c = torch.randint(20, (B, 5))
  126.         img_list = []
  127.         bm_list = []
  128.         for ii in range(B):
  129.             x_img = img[ii]
  130.             x_bm = bm[ii]
  131.             x_msk = mask[ii]
  132.             y, x = x_msk.nonzero(as_tuple=True)
  133.             minx = x.min()
  134.             maxx = x.max()
  135.             miny = y.min()
  136.             maxy = y.max()
  137.             x_img = x_img[:, miny : maxy + 1, minx : maxx + 1]
  138.             x_msk = x_msk[None, miny : maxy + 1, minx : maxx + 1]
  139.  
  140.             # padding
  141.             x_img = F.pad(x_img, c[ii, : 4].tolist())
  142.             x_msk = F.pad(x_msk, c[ii, : 4].tolist())
  143.  
  144.             x_bm[0, :, :] = (x_bm[0, :, :] - minx + c[ii][0]) / x_img.size(2) * 2 - 1
  145.             x_bm[1, :, :] = (x_bm[1, :, :] - miny + c[ii][2]) / x_img.size(1) * 2 - 1
  146.  
  147.             # replace bg
  148.             if c[ii][-1] > 2:
  149.                 x_bg = bg[ii][:, :x_img.size(1), :x_img.size(2)]
  150.             else:
  151.                 x_bg = torch.ones_like(x_img) * torch.rand((3, 1, 1), device=x_img.device)
  152.             x_msk = x_msk.float()
  153.             x_img = x_img * x_msk + x_bg * (1. - x_msk)
  154.  
  155.             # resize
  156.             x_img = KG.resize(x_img[None, :], (256, 256))
  157.             img_list.append(x_img)
  158.             bm_list.append(x_bm)
  159.         img = torch.cat(img_list)
  160.         bm = torch.stack(bm_list)
  161.         # jitter color
  162.         img = self.cj(img)
  163.         return img, bm
  164.  
  165.  
  166. if __name__ == '__main__':
  167.     dt = Doc3D(root_dir='/home/ridwan/thesis/PaperEdge/doc3d_root')
  168.     from visdom import Visdom
  169.     vis = Visdom(port=8097)
  170.     x, xt, y, yt, t = dt[999]
  171.  
  172.     vis.image(x.clamp(0, 1), opts={'caption': 'x'}, win='x')
  173.     vis.image(xt.clamp(0, 1), opts={'caption': 'xt'}, win='xt')
  174.     vis.image(y.clamp(0, 1), opts={'caption': 'y'}, win='y')
  175.     vis.image(yt.clamp(0, 1), opts={'caption': 'yt'}, win='yt')
  176.     vis.image(t, opts={'caption': 't'}, win='t')
  177.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement