Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Copyright 2022 Google LLC
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- # https://www.apache.org/licenses/LICENSE-2.0
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """A wrapper class for running a frame interpolation TF2 saved model.
- Usage:
- model_path='/tmp/saved_model/'
- it = Interpolator(model_path)
- result_batch = it.interpolate(image_batch_0, image_batch_1, batch_dt)
- Where image_batch_0 and image_batch_1 are numpy arrays with TF standard
- (B,H,W,C) layout, batch_dt is the sub-frame time in range [0,1], (B,) layout.
- """
- from typing import Optional
- import numpy as np
- import tensorflow as tf
def saveImage(image, name):
    """Writes a float image with values in [0, 1] to `name` as a PNG.

    Args:
      image: A single (height, width, channels) image, either a NumPy array or
        anything `np.asarray` accepts (e.g. an eager tensor). Values are
        expected in [0, 1]; anything outside is clipped.
      name: Output file path.
    """
    # Scale to [0, 255] and clip. Then round to nearest rather than truncate:
    # a bare astype(np.uint8) floors, biasing every channel down by up to one
    # level.
    image_in_uint8_range = np.clip(np.asarray(image) * 255, 0.0, 255.0)
    image_in_uint8 = (image_in_uint8_range + 0.5).astype(np.uint8)
    image_data = tf.io.encode_png(image_in_uint8)
    tf.io.write_file(name, image_data)
def _pad_to_align(x, align):
    """Pads an image batch so its height and width are multiples of `align`.

    The padding on each axis is split between both sides. The leading side
    (top/left) gets `pad // 2` and the trailing side gets the rest.

    Args:
      x: Image batch to align. Must have rank 4, with height and width at
        axes -3 and -2.
      align: Positive number that the padded height and width must divide by.

    Returns:
      A tuple `(padded_x, bbox_to_crop)`:
        1) the batch padded so height % align == 0 and width % align == 0;
        2) keyword arguments for tf.image.crop_to_bounding_box that undo the
           padding.
    """
    assert np.ndim(x) == 4
    assert align > 0, 'align must be a positive number.'

    def _shortfall(size):
        # Amount needed to round `size` up to the next multiple of `align`.
        remainder = size % align
        return 0 if remainder == 0 else align - remainder

    height, width = x.shape[-3:-1]
    pad_h = _shortfall(height)
    pad_w = _shortfall(width)
    top = pad_h // 2
    left = pad_w // 2

    padded_x = tf.image.pad_to_bounding_box(
        x,
        offset_height=top,
        offset_width=left,
        target_height=height + pad_h,
        target_width=width + pad_w,
    )
    bbox_to_crop = dict(
        offset_height=top,
        offset_width=left,
        target_height=height,
        target_width=width,
    )
    return padded_x, bbox_to_crop
class Interpolator:
    """A class for generating interpolated frames between two input frames.

    Uses TF2 saved model format. The model is loaded twice: as a plain saved
    model that is run end to end, and as a Keras model whose 'feat_net' and
    'fusion' layers are re-run individually by `interpolate` so that the
    stage-by-stage result can be compared with the end-to-end one.
    """

    def __init__(self, model_path: str,
                 align: Optional[int] = None) -> None:
        """Loads a saved model.

        Args:
          model_path: Path to the saved model directory. Required; there is no
            default model.
          align: If not None, pad the input height and width up to a multiple
            of this value before inference, and crop the returned prediction
            back to the input size afterwards.
        """
        self._model = tf.compat.v2.saved_model.load(model_path)
        # Local import: tensorflow_addons is only imported when a model is
        # loaded. It is passed to the Keras loader as the 'tfa_image' custom
        # object.
        import tensorflow_addons.image as tfa_image
        self._modelKeras = tf.keras.models.load_model(
            model_path, custom_objects={'tfa_image': tfa_image}, compile=False)
        self._align = align
        # (extract, fuse) callables, built lazily by _stage_functions().
        self._stages = None

    def _stage_functions(self):
        """Returns `(extract, fuse)` callables for the Keras sub-layers.

        `extract` runs the 'feat_net' layer and `fuse` runs the 'fusion'
        layer. Both are built on first use and cached, so repeated
        `interpolate` calls do not rebuild the Keras graphs.
        """
        stages = getattr(self, '_stages', None)
        if stages is None:
            extract_layer = self._modelKeras.get_layer('feat_net')
            extract = tf.keras.Model(inputs=extract_layer.input,
                                     outputs=extract_layer.output)
            fusion_layer = self._modelKeras.get_layer('fusion')
            # Returns a list with one entry per requested output; callers
            # take [0].
            fuse = tf.keras.backend.function([fusion_layer.input],
                                             [fusion_layer.output])
            stages = (extract, fuse)
            self._stages = stages
        return stages

    def interpolate(self, x0: np.ndarray, x1: np.ndarray,
                    dt: np.ndarray) -> np.ndarray:
        """Generates an interpolated frame between given two batches of frames.

        Runs the saved model end to end, then recomputes the fused prediction
        stage by stage:
          1) image pyramids;
          2) Keras 'feat_net' features;
          3) warping with the saved model's flow pyramids;
          4) Keras 'fusion'.
        `dt` is only fed to the end-to-end saved model.

        Side effects on every call:
          - sets the Keras floatx to float32;
          - seeds the TF and NumPy global RNGs with 42;
          - prints progress;
          - writes 'my_interpolated.png' (stage-by-stage) and
            'their_interpolated.png' (end-to-end) to the current directory.

        All input tensors should be np.float32 datatype.

        Args:
          x0: First image batch. Dimensions: (batch_size, height, width, channels)
          x1: Second image batch. Dimensions: (batch_size, height, width, channels)
          dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)

        Returns:
          The stage-by-stage fused prediction with dimensions
          (batch_size, height, width, channels). When `align` is set, it is
          cropped back to the unpadded input size.
        """
        from ..models.film_net import options as film_net_options
        from ..models.film_net import util

        tf.keras.backend.set_floatx('float32')
        tf.random.set_seed(42)
        np.random.seed(42)

        bbox_to_crop = None
        if self._align is not None:
            x0, bbox_to_crop = _pad_to_align(x0, self._align)
            x1, _ = _pad_to_align(x1, self._align)

        inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
        result = self._model(inputs, training=False)

        # Pyramid settings taken from the training configs. Presumably they
        # must match the saved model for the slicing below to line up.
        config = film_net_options.Options()
        config.pyramid_levels = 7
        config.fusion_pyramid_levels = 5
        config.specialized_levels = 3
        config.sub_levels = 4
        config.flow_convs = [3, 3, 3, 3]
        config.flow_filters = [32, 64, 128, 256]
        config.filters = 64
        fusion_pyramid_levels = config.fusion_pyramid_levels

        extract, fuse = self._stage_functions()

        image_pyramids = [
            util.build_image_pyramid(x0, config),
            util.build_image_pyramid(x1, config),
        ]

        # Siamese feature pyramids: the same extractor is applied to both frames.
        print("=== Extract ===")
        print("Building feature pyramids")
        feature_pyramids = [extract(image_pyramids[0]), extract(image_pyramids[1])]
        print("...done")

        # Reuse the flow pyramids produced by the end-to-end saved model
        # instead of re-running the flow estimator.
        # NOTE(review): an earlier version of this code multiplied the backward
        # flow by the mid time (0.5) and the forward flow by (1 - 0.5) before
        # warping. Here they are used unscaled; confirm whether the saved
        # model's flow-pyramid outputs are already scaled.
        backward_flow = result['backward_flow_pyramid']
        forward_flow = result['forward_flow_pyramid']

        pyramids_to_warp = [
            util.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels],
                                      feature_pyramids[0][:fusion_pyramid_levels]),
            util.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels],
                                      feature_pyramids[1][:fusion_pyramid_levels]),
        ]
        # Frame 0's pyramid is warped with the backward flow and frame 1's with
        # the forward flow.
        forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
        backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)

        aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
                                                    backward_warped_pyramid)
        aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
        aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)

        print("==== creating fuses ==== ")
        print("===== predicting on aligned pyramid")
        prediction = fuse(aligned_pyramid)[0]
        print("====== output image")
        print(prediction.shape)

        # Debug comparison images of the first batch element. Both are saved at
        # the (possibly padded) model resolution so they stay comparable.
        saveImage(prediction[0][..., :3], "my_interpolated.png")
        saveImage(result['image'][0], "their_interpolated.png")

        # Undo the alignment padding so the result matches the input size.
        if bbox_to_crop is not None:
            prediction = tf.image.crop_to_bounding_box(
                prediction, **bbox_to_crop).numpy()
        return prediction
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement