Advertisement
sk82

fusion.py

Nov 4th, 2022
121
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.11 KB | None | 0 0
  1.  
import logging
import math
from typing import List

import numpy as np
import tensorflow as tf

from . import options
  9.  
  10.  
  11. def _relu(x: tf.Tensor) -> tf.Tensor:
  12.   return tf.nn.leaky_relu(x, alpha=0.2)
  13.  
  14.  
  15. # def make_init(inp):
  16. #   def init_function(shape, dtype=None):
  17. #     print("=== Conv2d init ===")
  18. #     print("Init Shape", shape)
  19. #     print("Numpy Shape", inp.shape)
  20. #     kernel = np.zeros(shape)
  21. #     # print(dir(weights))
  22. #     return kernel
  23. #   return init_function
  24.  
  25. def make_init(idx, fusion_layer):
  26.   def init_function(shape, dtype=None):
  27.     print("=== Conv2d init idx (%d) ===" % (idx))
  28.     print("Init shape", shape)
  29.     fuseLayer = fusion_layer.variables[idx]
  30.     pretrained_shape = fuseLayer.shape
  31.     print("Pretrained shape", pretrained_shape)
  32.     # conv kernel
  33.     if len(shape) == 4:
  34.       kernel = fuseLayer
  35.     # bias
  36.     else:
  37.       kernel = fuseLayer[:shape[0]]
  38.     return kernel
  39.   return init_function
  40.  
  41. _NUMBER_OF_COLOR_CHANNELS = 3
  42.  
  43. class Fusion(tf.keras.layers.Layer):
  44.   """The decoder."""
  45.  
  46.   def __init__(self, name: str, config: options.Options, fusion_layer):
  47.     super().__init__(name=name)
  48.  
  49.     # kernels, biases = self.getModelFusionData(model)
  50.  
  51.     # Each item 'convs[i]' will contain the list of convolutions to be applied
  52.     # for pyramid level 'i'.
  53.     self.convs: List[List[tf.keras.layers.Layer]] = []
  54.  
  55.     # Store the levels, so we can verify right number of levels in call().
  56.     self.levels = config.fusion_pyramid_levels
  57.  
  58.     # Create the convolutions. Roughly following the feature extractor, we
  59.     # double the number of filters when the resolution halves, but only up to
  60.     # the specialized_levels, after which we use the same number of filters on
  61.     # all levels.
  62.     #
  63.     # We create the convs in fine-to-coarse order, so that the array index
  64.     # for the convs will correspond to our normal indexing (0=finest level).
  65.     idx = 0
  66.     for i in range(config.fusion_pyramid_levels - 1):
  67.       m = config.specialized_levels
  68.       k = config.filters
  69.       num_filters = (k << i) if i < m else (k << m)
  70.  
  71.       convs: List[tf.keras.layers.Layer] = []
  72.       convs.append(
  73.           tf.keras.layers.Conv2D(
  74.               filters=num_filters,
  75.               kernel_size=[2, 2],
  76.               padding='same',
  77.               kernel_initializer = make_init(idx, fusion_layer),
  78.               bias_initializer = make_init(idx+1, fusion_layer)))
  79.       idx += 2
  80.       convs.append(
  81.           tf.keras.layers.Conv2D(
  82.               filters=num_filters,
  83.               kernel_size=[3, 3],
  84.               padding='same',
  85.               activation=_relu,
  86.               kernel_initializer =  make_init(idx, fusion_layer),
  87.               bias_initializer = make_init(idx+1, fusion_layer)))
  88.       idx += 2
  89.       convs.append(
  90.           tf.keras.layers.Conv2D(
  91.               filters=num_filters,
  92.               kernel_size=[3, 3],
  93.               padding='same',
  94.               activation=_relu,
  95.               kernel_initializer = make_init(idx, fusion_layer),
  96.               bias_initializer = make_init(idx+1, fusion_layer)))
  97.       idx += 2
  98.       self.convs.append(convs)
  99.  
  100.  
  101.     # The final convolution that outputs RGB:
  102.     self.output_conv = tf.keras.layers.Conv2D(
  103.         filters=_NUMBER_OF_COLOR_CHANNELS, kernel_size=1,
  104.         kernel_initializer = make_init(idx, fusion_layer),
  105.         bias_initializer = make_init(idx+1, fusion_layer)
  106.         )
  107.  
  108.   def call(self, pyramid: List[tf.Tensor]) -> tf.Tensor:
  109.     """Runs the fusion module.
  110.  
  111.    Args:
  112.      pyramid: The input feature pyramid as list of tensors. Each tensor being
  113.        in (B x H x W x C) format, with finest level tensor first.
  114.  
  115.    Returns:
  116.      A batch of RGB images.
  117.    Raises:
  118.      ValueError, if len(pyramid) != config.fusion_pyramid_levels as provided in
  119.        the constructor.
  120.    """
  121.     if len(pyramid) != self.levels:
  122.       raise ValueError(
  123.           'Fusion called with different number of pyramid levels '
  124.           f'{len(pyramid)} than it was configured for, {self.levels}.')
  125.  
  126.     # As a slight difference to a conventional decoder (e.g. U-net), we don't
  127.     # apply any extra convolutions to the coarsest level, but just pass it
  128.     # to finer levels for concatenation. This choice has not been thoroughly
  129.     # evaluated, but is motivated by the educated guess that the fusion part
  130.     # probably does not need large spatial context, because at this point the
  131.     # features are spatially aligned by the preceding warp.
  132.     print("== [FUSE CALL]== ")
  133.     net = pyramid[-1]
  134.  
  135.     # Loop starting from the 2nd coarsest level:
  136.     for i in reversed(range(0, self.levels - 1)):
  137.       # Resize the tensor from coarser level to match for concatenation.
  138.       level_size = tf.shape(pyramid[i])[1:3]
  139.       net = tf.image.resize(net, level_size,
  140.                             tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  141.       # print("Level", i, "Net Shape", net.shape)
  142.       net = self.convs[i][0](net)
  143.       net = tf.concat([pyramid[i], net], axis=-1)
  144.       net = self.convs[i][1](net)
  145.       net = self.convs[i][2](net)
  146.     net = self.output_conv(net)
  147.     print("== [FUSE DONE] ==")
  148.     return net
  149.  
  150.  
  151.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement