Advertisement
sk82

interpolator.py

Nov 4th, 2022
109
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.54 KB | None | 0 0
  1. # Copyright 2022 Google LLC
  2.  
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6.  
  7. #     https://www.apache.org/licenses/LICENSE-2.0
  8.  
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A wrapper class for running a frame interpolation TF2 saved model.
  16.  
  17. Usage:
  18.  model_path='/tmp/saved_model/'
  19.  it = Interpolator(model_path)
  20.  result_batch = it.interpolate(image_batch_0, image_batch_1, batch_dt)
  21.  
  22.  Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
  23.  (B,H,W,C) layout, batch_dt is the sub-frame time in range [0,1], (B,) layout.
  24. """
  25. from typing import Optional
  26. import numpy as np
  27. import tensorflow as tf
  28.  
  29. def saveImage(image,name):
  30.     image_in_uint8_range = np.clip(image * 255, 0.0, 255)
  31.     image_in_uint8 = (image_in_uint8_range + .5).astype(np.uint8)
  32.     image_data = tf.io.encode_png(image_in_uint8)
  33.     tf.io.write_file(name, image_data)
  34.  
  35.  
  36. def _pad_to_align(x, align):
  37.   """Pad image batch x so width and height divide by align.
  38.  
  39.  Args:
  40.    x: Image batch to align.
  41.    align: Number to align to.
  42.  
  43.  Returns:
  44.    1) An image padded so width % align == 0 and height % align == 0.
  45.    2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
  46.      to undo the padding.
  47.  """
  48.   # Input checking.
  49.   assert np.ndim(x) == 4
  50.   assert align > 0, 'align must be a positive number.'
  51.  
  52.   height, width = x.shape[-3:-1]
  53.   height_to_pad = (align - height % align) if height % align != 0 else 0
  54.   width_to_pad = (align - width % align) if width % align != 0 else 0
  55.  
  56.   bbox_to_pad = {
  57.       'offset_height': height_to_pad // 2,
  58.       'offset_width': width_to_pad // 2,
  59.       'target_height': height + height_to_pad,
  60.       'target_width': width + width_to_pad
  61.   }
  62.   padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
  63.   bbox_to_crop = {
  64.       'offset_height': height_to_pad // 2,
  65.       'offset_width': width_to_pad // 2,
  66.       'target_height': height,
  67.       'target_width': width
  68.   }
  69.   return padded_x, bbox_to_crop
  70.  
  71.  
  72. class Interpolator:
  73.   """A class for generating interpolated frames between two input frames.
  74.  
  75.  Uses TF2 saved model format.
  76.  """
  77.  
  78.   def __init__(self, model_path: str,
  79.                align: Optional[int] = None) -> None:
  80.     """Loads a saved model.
  81.  
  82.    Args:
  83.      model_path: Path to the saved model. If none are provided, uses the
  84.        default model.
  85.      align: 'If >1, pad the input size so it divides with this before
  86.        inference.'
  87.    """
  88.     self._model = tf.compat.v2.saved_model.load(model_path)
  89.     import tensorflow_addons.image as tfa_image
  90.     self._modelKeras = tf.keras.models.load_model(model_path, custom_objects={'tfa_image' : tfa_image})
  91.     self._align = align
  92.  
  93.   def interpolate(self, x0: np.ndarray, x1: np.ndarray,
  94.                   dt: np.ndarray) -> np.ndarray:
  95.     """Generates an interpolated frame between given two batches of frames.
  96.  
  97.    All input tensors should be np.float32 datatype.
  98.  
  99.    Args:
  100.      x0: First image batch. Dimensions: (batch_size, height, width, channels)
  101.      x1: Second image batch. Dimensions: (batch_size, height, width, channels)
  102.      dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
  103.  
  104.    Returns:
  105.      The result with dimensions (batch_size, height, width, channels).
  106.    """
  107.     if self._align is not None:
  108.       x0, bbox_to_crop = _pad_to_align(x0, self._align)
  109.       x1, _ = _pad_to_align(x1, self._align)
  110.  
  111.     inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
  112.     result = self._model(inputs, training=False)
  113.     # image = result['image']
  114.  
  115.     # if self._align is not None:
  116.     #   image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
  117.     # return image.numpy()
  118.     from ..models.film_net import util
  119.     import time
  120.     from ..models.film_net import options as film_net_options
  121.     from ..models.film_net import feature_extractor
  122.     from ..models.film_net import fusion
  123.     from ..models.film_net import pyramid_flow_estimator
  124.  
  125.     config = film_net_options.Options()
  126.    
  127.     # from the training configs
  128.     config.pyramid_levels = 7
  129.     config.fusion_pyramid_levels = 5
  130.     config.specialized_levels = 3
  131.     config.sub_levels = 4
  132.     config.flow_convs = [3, 3, 3, 3]
  133.     config.flow_filters = [32, 64, 128, 256]
  134.     config.filters = 64
  135.  
  136.     x0_decoded = x0
  137.     x1_decoded = x1
  138.  
  139.     # shuffle images
  140.     image_pyramids = [
  141.         util.build_image_pyramid(x0_decoded, config),
  142.         util.build_image_pyramid(x1_decoded, config)
  143.     ]
  144.  
  145.     # Siamese feature pyramids:
  146.     extract = feature_extractor.FeatureExtractor('feat_net', config)
  147.     feature_pyramids = [extract(image_pyramids[0]), extract(image_pyramids[1])]
  148.  
  149.     predict_flow = pyramid_flow_estimator.PyramidFlowEstimator(
  150.         'predict_flow', config)
  151.  
  152.     # Predict forward flow.
  153.     forward_residual_flow_pyramid = predict_flow(feature_pyramids[0],
  154.                                                  feature_pyramids[1])
  155.     # Predict backward flow.
  156.     backward_residual_flow_pyramid = predict_flow(feature_pyramids[1],
  157.                                                   feature_pyramids[0])
  158.  
  159.     fusion_pyramid_levels = config.fusion_pyramid_levels
  160.  
  161.     #forward_flow_pyramid = util.flow_pyramid_synthesis(forward_residual_flow_pyramid)[:fusion_pyramid_levels]
  162.     #backward_flow_pyramid = util.flow_pyramid_synthesis(backward_residual_flow_pyramid)[:fusion_pyramid_levels]
  163.  
  164.     #interpolationTime = dt[..., np.newaxis]
  165.  
  166.     #interpolationTime = tf.keras.Input(shape=(1,), batch_size=None, dtype=tf.float32, name='time')
  167.  
  168.       #Using the model output instead of computing
  169.     forward_flow_pyramid = result['forward_flow_pyramid']
  170.     backward_flow_pyramid = result['backward_flow_pyramid']
  171.  
  172.  
  173.  
  174.     """
  175.    mid_time = tf.keras.layers.Lambda(lambda x: tf.ones_like(x) * 0.5)(interpolationTime)
  176.    backward_flow = util.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
  177.    forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])
  178.    """
  179.  
  180.  
  181.     backward_flow  = backward_flow_pyramid
  182.     forward_flow  = forward_flow_pyramid
  183.  
  184.     for i in range(len(backward_flow_pyramid)):
  185.         backward_flow[i] = backward_flow[i] * .5
  186.         forward_flow[i] = forward_flow[i] * .5
  187.  
  188.     pyramids_to_warp = [
  189.         util.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels],
  190.                                   feature_pyramids[0][:fusion_pyramid_levels]),
  191.         util.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels],
  192.                                   feature_pyramids[1][:fusion_pyramid_levels])
  193.     ]
  194.  
  195.     #These are the warped starting and ending images (x0_warped and x1_warped), which implies that everything that leads to this is correct
  196.     forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
  197.     backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)
  198.  
  199.     """
  200.    saveImage( forward_warped_pyramid[0][..., 0:3][0],'my_forwardwarpedpyramid.png')
  201.    saveImage( backward_warped_pyramid[0][..., 0:3][0],'my_backwardwarpedpyramid.png')
  202.    saveImage(result['x0_warped'][0],'their_forwardwarpedpyramid.png')
  203.    saveImage(result['x1_warped'][0],'their_backwardwarpedpyramid.png')
  204.    """
  205.  
  206.     aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
  207.                                                 backward_warped_pyramid)
  208.     aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
  209.     aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)
  210.  
  211.     print("==== creating fuses ==== ")
  212.    
  213.    
  214.     fusion_layer = self._modelKeras.get_layer('fusion')
  215.     fuse = fusion.Fusion('fusion', config, fusion_layer)
  216.     # fusionModel = tf.keras.Model(inputs = fusion_layer.input, outputs = fusion_layer.output)
  217.     # fusion = fusion()
  218.    
  219.    
  220.     print("===== predicting on aligned pyramid")
  221.     prediction = fuse(aligned_pyramid)
  222.    
  223.     print("====== output image")
  224.     print(prediction)
  225.    
  226.     output_color = prediction[..., :3]
  227.     saveImage(output_color[0], "my_interpolated.png")
  228.     saveImage(result['image'][0],"their_interpolated.png")
  229.     return prediction
  230.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement