Advertisement
sk82

working interpolator ???

Nov 10th, 2022 (edited)
89
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.65 KB | None | 0 0
  1. # Copyright 2022 Google LLC
  2.  
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6.  
  7. #     https://www.apache.org/licenses/LICENSE-2.0
  8.  
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A wrapper class for running a frame interpolation TF2 saved model.
  16.  
  17. Usage:
  18.  model_path='/tmp/saved_model/'
  19.  it = Interpolator(model_path)
  20.  result_batch = it.interpolate(image_batch_0, image_batch_1, batch_dt)
  21.  
  22.  Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
  23.  (B,H,W,C) layout, batch_dt is the sub-frame time in range [0,1], (B,) layout.
  24. """
  25. from typing import Optional
  26. import numpy as np
  27. import tensorflow as tf
  28.  
  29. def saveImage(image,name):
  30.     image_in_uint8_range = np.clip(image * 255, 0.0, 255).astype(np.uint8)
  31.     # image_in_uint8 = (image_in_uint8_range + .5).astype(np.uint8)
  32.     image_data = tf.io.encode_png(image_in_uint8_range)
  33.     tf.io.write_file(name, image_data)
  34.  
  35.  
  36. def _pad_to_align(x, align):
  37.   """Pad image batch x so width and height divide by align.
  38.  
  39.  Args:
  40.    x: Image batch to align.
  41.    align: Number to align to.
  42.  
  43.  Returns:
  44.    1) An image padded so width % align == 0 and height % align == 0.
  45.    2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
  46.      to undo the padding.
  47.  """
  48.   # Input checking.
  49.   assert np.ndim(x) == 4
  50.   assert align > 0, 'align must be a positive number.'
  51.  
  52.   height, width = x.shape[-3:-1]
  53.   height_to_pad = (align - height % align) if height % align != 0 else 0
  54.   width_to_pad = (align - width % align) if width % align != 0 else 0
  55.  
  56.   bbox_to_pad = {
  57.       'offset_height': height_to_pad // 2,
  58.       'offset_width': width_to_pad // 2,
  59.       'target_height': height + height_to_pad,
  60.       'target_width': width + width_to_pad
  61.   }
  62.   padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
  63.   bbox_to_crop = {
  64.       'offset_height': height_to_pad // 2,
  65.       'offset_width': width_to_pad // 2,
  66.       'target_height': height,
  67.       'target_width': width
  68.   }
  69.   return padded_x, bbox_to_crop
  70.  
  71.  
  72. class Interpolator:
  73.   """A class for generating interpolated frames between two input frames.
  74.  
  75.  Uses TF2 saved model format.
  76.  """
  77.  
  78.   def __init__(self, model_path: str,
  79.                align: Optional[int] = None) -> None:
  80.     """Loads a saved model.
  81.  
  82.    Args:
  83.      model_path: Path to the saved model. If none are provided, uses the
  84.        default model.
  85.      align: 'If >1, pad the input size so it divides with this before
  86.        inference.'
  87.    """
  88.     self._model = tf.compat.v2.saved_model.load(model_path)
  89.     import tensorflow_addons.image as tfa_image
  90.     self._modelKeras = tf.keras.models.load_model(model_path, custom_objects={'tfa_image' : tfa_image}, compile=False)
  91.     self._align = align
  92.  
  93.   def interpolate(self, x0: np.ndarray, x1: np.ndarray,
  94.                   dt: np.ndarray) -> np.ndarray:
  95.     """Generates an interpolated frame between given two batches of frames.
  96.  
  97.    All input tensors should be np.float32 datatype.
  98.  
  99.    Args:
  100.      x0: First image batch. Dimensions: (batch_size, height, width, channels)
  101.      x1: Second image batch. Dimensions: (batch_size, height, width, channels)
  102.      dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
  103.  
  104.    Returns:
  105.      The result with dimensions (batch_size, height, width, channels).
  106.    """
  107.    
  108.    
  109.     tf.keras.backend.set_floatx('float32')
  110.     tf.random.set_seed(42)
  111.     np.random.seed(42)
  112.    
  113.    
  114.     if self._align is not None:
  115.       x0, bbox_to_crop = _pad_to_align(x0, self._align)
  116.       x1, _ = _pad_to_align(x1, self._align)
  117.  
  118.     inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
  119.     result = self._model(inputs, training=False)
  120.     # image = result['image']
  121.  
  122.     # if self._align is not None:
  123.     #   image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
  124.     # return image.numpy()
  125.     from ..models.film_net import util
  126.     import time
  127.     from ..models.film_net import options as film_net_options
  128.     from ..models.film_net import feature_extractor
  129.     from ..models.film_net import fusion
  130.     from ..models.film_net import pyramid_flow_estimator
  131.  
  132.     config = film_net_options.Options()
  133.    
  134.     # from the training configs
  135.     config.pyramid_levels = 7
  136.     config.fusion_pyramid_levels = 5
  137.     config.specialized_levels = 3
  138.     config.sub_levels = 4
  139.     config.flow_convs = [3, 3, 3, 3]
  140.     config.flow_filters = [32, 64, 128, 256]
  141.     config.filters = 64
  142.  
  143.     x0_decoded = x0
  144.     x1_decoded = x1
  145.  
  146.     # shuffle images
  147.     image_pyramids = [
  148.         util.build_image_pyramid(x0_decoded, config),
  149.         util.build_image_pyramid(x1_decoded, config)
  150.     ]
  151.  
  152.     # Siamese feature pyramids:
  153.     # extract = feature_extractor.FeatureExtractor('feat_net', config)
  154.    
  155.     print("=== Extract ===")
  156.     # extract = self._modelKeras.get_layer('feat_net')
  157.     # extract = tf.keras.backend.function([extract.input], [extract.output])
  158.     extract_layer = self._modelKeras.get_layer('feat_net')
  159.     extract = tf.keras.Model(inputs = extract_layer.input, outputs = extract_layer.output)
  160.    
  161.     print("Building feature pyramids")
  162.     feature_pyramids = [extract(image_pyramids[0]), extract(image_pyramids[1])]
  163.     print("...done")
  164.  
  165.     # predict_flow = pyramid_flow_estimator.PyramidFlowEstimator(
  166.     #     'predict_flow', config)
  167.  
  168.     # # # Predict forward flow.
  169.     # forward_residual_flow_pyramid = predict_flow(feature_pyramids[0],
  170.     #                                              feature_pyramids[1])
  171.     # # # Predict backward flow.
  172.     # backward_residual_flow_pyramid = predict_flow(feature_pyramids[1],
  173.     #                                               feature_pyramids[0])
  174.  
  175.     fusion_pyramid_levels = config.fusion_pyramid_levels
  176.  
  177.     # forward_flow_pyramid = util.flow_pyramid_synthesis(forward_residual_flow_pyramid)[:fusion_pyramid_levels]
  178.     # backward_flow_pyramid = util.flow_pyramid_synthesis(backward_residual_flow_pyramid)[:fusion_pyramid_levels]
  179.  
  180.     interpolationTime = np.array(dt[..., np.newaxis], dtype=np.float32)
  181.  
  182.     # interpolationTime = tf.keras.Input(shape=(1,), batch_size=None, dtype=tf.float32, name='time')
  183.  
  184.       #Using the model output instead of computing
  185.     forward_flow_pyramid = result['forward_flow_pyramid']
  186.     backward_flow_pyramid = result['backward_flow_pyramid']
  187.  
  188.  
  189.  
  190.    
  191.     # mid_time = tf.keras.layers.Lambda(lambda x: tf.ones_like(x) * 0.5)(interpolationTime)
  192.     # backward_flow = util.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
  193.     # forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])
  194.    
  195.  
  196.  
  197.     backward_flow  = backward_flow_pyramid
  198.     forward_flow  = forward_flow_pyramid
  199.    
  200.     #print("Backward flow pyramid shape", len(backward_flow))
  201.     #print("Forward flow pyramid shape", len(forward_flow))
  202.  
  203.     # for i in range(len(backward_flow_pyramid)):
  204.     #     backward_flow[i] = backward_flow[i] * .5
  205.     #     forward_flow[i] = forward_flow[i]  * .5
  206.     # _time = interpolationTime
  207.     # mid_time = tf.keras.layers.Lambda(lambda x: tf.ones_like(x) * 0.5)(_time)
  208.     # backward_flow = util.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
  209.     # forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])
  210.  
  211.     pyramids_to_warp = [
  212.         util.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels],
  213.                                   feature_pyramids[0][:fusion_pyramid_levels]),
  214.         util.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels],
  215.                                   feature_pyramids[1][:fusion_pyramid_levels])
  216.     ]
  217.  
  218.     #These are the warped starting and ending images (x0_warped and x1_warped), which implies that everything that leads to this is correct
  219.     forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
  220.     backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)
  221.  
  222.     """
  223.    saveImage( forward_warped_pyramid[0][..., 0:3][0],'my_forwardwarpedpyramid.png')
  224.    saveImage( backward_warped_pyramid[0][..., 0:3][0],'my_backwardwarpedpyramid.png')
  225.    saveImage(result['x0_warped'][0],'their_forwardwarpedpyramid.png')
  226.    saveImage(result['x1_warped'][0],'their_backwardwarpedpyramid.png')
  227.    """
  228.  
  229.     aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
  230.                                                 backward_warped_pyramid)
  231.     aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
  232.     aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)
  233.  
  234.     print("==== creating fuses ==== ")
  235.    
  236.    
  237.     fusion_layer = self._modelKeras.get_layer('fusion')
  238.     # fuse = fusion.Fusion('fusion', config, fusion_layer)
  239.    
  240.     fuse = tf.keras.backend.function([fusion_layer.input], [fusion_layer.output])
  241.    
  242.     # fusionModel = tf.keras.Model(inputs = fusion_layer.input, outputs = fusion_layer.output)
  243.     # fusion = fusion()
  244.    
  245.    
  246.     print("===== predicting on aligned pyramid")
  247.     prediction = fuse(aligned_pyramid)[0]
  248.    
  249.     print("====== output image")
  250.     print(prediction.shape)
  251.    
  252.     output_color = prediction[0][..., :3]
  253.     saveImage(output_color, "my_interpolated.png")
  254.     saveImage(result['image'][0],"their_interpolated.png")
  255.     return prediction
  256.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement