Advertisement
iSach

nerfplayer_field.py

Apr 28th, 2023
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 16.74 KB | None | 0 0
  1. # Copyright 2022 The Nerfstudio Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14.  
  15. """
  16. Field implementations for NeRFPlayer (https://arxiv.org/abs/2210.15947) with a nerfacto backbone.
  17. """
  18.  
  19.  
  20. from typing import Dict, Optional, Tuple
  21.  
  22. import numpy as np
  23. import torch
  24. from torch.nn.parameter import Parameter
  25. from torchtyping import TensorType
  26.  
  27. from nerfstudio.cameras.rays import Frustums, RaySamples
  28. from nerfstudio.data.scene_box import SceneBox
  29. from nerfstudio.field_components.activations import trunc_exp
  30. from nerfstudio.field_components.embedding import Embedding
  31. from nerfstudio.field_components.field_heads import (
  32. FieldHeadNames,
  33. )
  34. from nerfstudio.field_components.spatial_distortions import SpatialDistortion
  35. from nerfstudio.field_components.temporal_grid import TemporalGridEncoder
  36. from nerfstudio.fields.base_field import Field, shift_directions_for_tcnn
  37.  
  38. try:
  39. import tinycudann as tcnn
  40. except ImportError:
  41. # tinycudann module doesn't exist
  42. pass
  43.  
  44.  
  45. class TemporalHashMLPDensityField(Field):
  46. """A lightweight temporal density field module.
  47.  
  48. Args:
  49. aabb: Parameters of scene aabb bounds
  50. temporal_dim: Hashing grid parameter. A higher temporal dim means a higher temporal frequency.
  51. num_layers: Number of hidden layers
  52. hidden_dim: Dimension of hidden layers
  53. spatial_distortion: Spatial distortion module
  54. num_levels: Hashing grid parameter. Used for initialize TemporalGridEncoder class.
  55. max_res: Hashing grid parameter. Used for initialize TemporalGridEncoder class.
  56. base_res: Hashing grid parameter. Used for initialize TemporalGridEncoder class.
  57. log2_hashmap_size: Hashing grid parameter. Used for initialize TemporalGridEncoder class.
  58. features_per_level: Hashing grid parameter. Used for initialize TemporalGridEncoder class.
  59. """
  60.  
  61. def __init__(
  62. self,
  63. aabb: TensorType,
  64. temporal_dim: int = 64,
  65. num_layers: int = 2,
  66. hidden_dim: int = 64,
  67. spatial_distortion: Optional[SpatialDistortion] = None,
  68. num_levels: int = 8,
  69. max_res: int = 1024,
  70. base_res: int = 16,
  71. log2_hashmap_size: int = 18,
  72. features_per_level: int = 2,
  73. ) -> None:
  74. super().__init__()
  75. # from .temporal_grid import test; test() # DEBUG
  76. self.aabb = Parameter(aabb, requires_grad=False)
  77. self.spatial_distortion = spatial_distortion
  78. growth_factor = np.exp((np.log(max_res) - np.log(base_res)) / (num_levels - 1))
  79.  
  80. self.encoding = TemporalGridEncoder(
  81. input_dim=3,
  82. temporal_dim=temporal_dim,
  83. num_levels=num_levels,
  84. level_dim=features_per_level,
  85. per_level_scale=growth_factor,
  86. base_resolution=base_res,
  87. log2_hashmap_size=log2_hashmap_size,
  88. )
  89. self.linear = tcnn.Network(
  90. n_input_dims=num_levels * features_per_level,
  91. n_output_dims=1,
  92. network_config={
  93. "otype": "FullyFusedMLP",
  94. "activation": "ReLU",
  95. "output_activation": "None",
  96. "n_neurons": hidden_dim,
  97. "n_hidden_layers": num_layers - 1,
  98. },
  99. )
  100.  
  101. # pylint: disable=arguments-differ
  102. def density_fn(self, positions: TensorType["bs":..., 3], times: TensorType["bs", 1]) -> TensorType["bs":..., 1]:
  103. """Returns only the density. Used primarily with the density grid.
  104.  
  105. Args:
  106. positions: the origin of the samples/frustums
  107. times: the time of rays
  108. """
  109. if len(positions.shape) == 3 and len(times.shape) == 2:
  110. # position is [ray, sample, 3]; times is [ray, 1]
  111. times = times[:, None] # RaySamples can handle the shape
  112. # Need to figure out a better way to descibe positions with a ray.
  113. ray_samples = RaySamples(
  114. frustums=Frustums(
  115. origins=positions,
  116. directions=torch.ones_like(positions),
  117. starts=torch.zeros_like(positions[..., :1]),
  118. ends=torch.zeros_like(positions[..., :1]),
  119. pixel_area=torch.ones_like(positions[..., :1]),
  120. ),
  121. times=times,
  122. )
  123. density, _ = self.get_density(ray_samples)
  124. return density
  125.  
  126. def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType, None]:
  127. if self.spatial_distortion is not None:
  128. positions = self.spatial_distortion(ray_samples.frustums.get_positions())
  129. positions = (positions + 2.0) / 4.0
  130. else:
  131. positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
  132. positions_flat = positions.view(-1, 3)
  133. time_flat = ray_samples.times.reshape(-1, 1)
  134. x = self.encoding(positions_flat, time_flat).to(positions)
  135. density_before_activation = self.linear(x).view(*ray_samples.frustums.shape, -1)
  136.  
  137. # Rectifying the density with an exponential is much more stable than a ReLU or
  138. # softplus, because it enables high post-activation (float32) density outputs
  139. # from smaller internal (float16) parameters.
  140. density = trunc_exp(density_before_activation)
  141. return density, None
  142.  
  143. def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None) -> dict:
  144. return {}
  145.  
  146.  
  147. class NerfplayerField(Field):
  148. """NeRFPlayer (https://arxiv.org/abs/2210.15947) field with nerfacto backbone.
  149.  
  150. Args:
  151. aabb: parameters of scene aabb bounds
  152. num_images: number of images in the dataset
  153. num_layers: number of hidden layers
  154. hidden_dim: dimension of hidden layers
  155. geo_feat_dim: output geo feat dimensions
  156. num_levels: number of levels of the hashmap for the base mlp
  157. max_res: maximum resolution of the hashmap for the base mlp
  158. log2_hashmap_size: size of the hashmap for the base mlp
  159. num_layers_color: number of hidden layers for color network
  160. num_layers_transient: number of hidden layers for transient network
  161. hidden_dim_color: dimension of hidden layers for color network
  162. hidden_dim_transient: dimension of hidden layers for transient network
  163. appearance_embedding_dim: dimension of appearance embedding
  164. transient_embedding_dim: dimension of transient embedding
  165. use_transient_embedding: whether to use transient embedding
  166. use_semantics: whether to use semantic segmentation
  167. num_semantic_classes: number of semantic classes
  168. use_pred_normals: whether to use predicted normals
  169. use_average_appearance_embedding: whether to use average appearance embedding or zeros for inference
  170. spatial_distortion: spatial distortion to apply to the scene
  171. """
  172.  
  173. def __init__(
  174. self,
  175. aabb: TensorType,
  176. num_images: int,
  177. num_layers: int = 3,
  178. hidden_dim: int = 64,
  179. geo_feat_dim: int = 15,
  180. temporal_dim: int = 64,
  181. num_levels: int = 16,
  182. features_per_level: int = 2,
  183. base_resolution: int = 16,
  184. log2_hashmap_size: int = 19,
  185. num_layers_color: int = 4,
  186. num_layers_transient: int = 2,
  187. hidden_dim_color: int = 64,
  188. hidden_dim_transient: int = 64,
  189. appearance_embedding_dim: int = 32,
  190. transient_embedding_dim: int = 16,
  191. use_transient_embedding: bool = False,
  192. use_semantics: bool = False,
  193. num_semantic_classes: int = 100,
  194. use_pred_normals: bool = False,
  195. use_average_appearance_embedding: bool = False,
  196. disable_viewing_dependent: bool = False,
  197. spatial_distortion: Optional[SpatialDistortion] = None,
  198. ) -> None:
  199. super().__init__()
  200.  
  201. self.aabb = Parameter(aabb, requires_grad=False)
  202. self.geo_feat_dim = geo_feat_dim
  203.  
  204. self.spatial_distortion = spatial_distortion
  205. self.num_images = num_images
  206. self.appearance_embedding_dim = appearance_embedding_dim
  207. self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim)
  208. self.use_average_appearance_embedding = use_average_appearance_embedding
  209. self.use_transient_embedding = use_transient_embedding
  210. self.use_semantics = use_semantics
  211. self.use_pred_normals = use_pred_normals
  212.  
  213. self.direction_encoding = tcnn.Encoding(
  214. n_input_dims=3,
  215. encoding_config={
  216. "otype": "SphericalHarmonics",
  217. "degree": 4,
  218. },
  219. )
  220.  
  221. self.position_encoding = tcnn.Encoding(
  222. n_input_dims=3,
  223. encoding_config={"otype": "Frequency", "n_frequencies": 2},
  224. )
  225.  
  226. feature_dim = num_levels * features_per_level
  227.  
  228. # deformation_field
  229. self.deformation_field = tcnn.Network(
  230. n_input_dims=3,
  231. n_output_dims=3,
  232. network_config={
  233. "otype": "FullyFusedMLP",
  234. "activation": "ReLU",
  235. "output_activation": "None",
  236. "n_neurons": 128,
  237. "n_hidden_layers": 3,
  238. },
  239. )
  240.  
  241. # Explicit encoding for the stationary field. Does not depend on time.
  242. self.stationary_field = tcnn.Encoding(
  243. n_input_dims=3,
  244. encoding_config={
  245. "otype": "HashGrid",
  246. "n_levels": num_levels,
  247. "n_features_per_level": features_per_level,
  248. "log2_hashmap_size": log2_hashmap_size,
  249. "base_resolution": base_resolution,
  250. "per_level_scale": 1.4472692012786865, # base_res * scale ** (level), base level = 0
  251. },
  252. )
  253.  
  254. # MLP for the stationary field.
  255. # (features, t) -> (features)
  256. self.stationary_field_mlp = tcnn.Network(
  257. n_input_dims=feature_dim + 1,
  258. n_output_dims=feature_dim,
  259. network_config={
  260. "otype": "FullyFusedMLP",
  261. "activation": "ReLU",
  262. "output_activation": "None",
  263. "n_neurons": 64,
  264. "n_hidden_layers": 1,
  265. },
  266. )
  267.  
  268. self.newness_field = TemporalGridEncoder(
  269. input_dim=3,
  270. temporal_dim=temporal_dim,
  271. num_levels=num_levels,
  272. level_dim=features_per_level,
  273. base_resolution=base_resolution,
  274. log2_hashmap_size=log2_hashmap_size,
  275. desired_resolution=1024 * (self.aabb.max() - self.aabb.min()),
  276. )
  277.  
  278. self.decomposition_field = TemporalGridEncoder(
  279. input_dim=3,
  280. temporal_dim=temporal_dim,
  281. num_levels=num_levels,
  282. level_dim=features_per_level,
  283. base_resolution=base_resolution,
  284. log2_hashmap_size=log2_hashmap_size,
  285. desired_resolution=1024 * (self.aabb.max() - self.aabb.min()),
  286. )
  287.  
  288. self.decomposition_mlp = tcnn.Network(
  289. n_input_dims=feature_dim,
  290. n_output_dims=3,
  291. network_config={
  292. "otype": "FullyFusedMLP",
  293. "activation": "ReLU",
  294. "output_activation": "None",
  295. "n_neurons": 64,
  296. "n_hidden_layers": 1,
  297. },
  298. )
  299. self._probs = None
  300.  
  301. # Radiance Field (first component for density, the rest for color)
  302. self.mlp_base_decode = tcnn.Network(
  303. n_input_dims=feature_dim,
  304. n_output_dims=1 + self.geo_feat_dim,
  305. network_config={
  306. "otype": "FullyFusedMLP",
  307. "activation": "ReLU",
  308. "output_activation": "None",
  309. "n_neurons": hidden_dim,
  310. "n_hidden_layers": num_layers - 1,
  311. },
  312. )
  313.  
  314. in_dim = self.direction_encoding.n_output_dims + self.geo_feat_dim
  315. if disable_viewing_dependent:
  316. in_dim = self.geo_feat_dim
  317. self.direction_encoding = None
  318. self.mlp_head = tcnn.Network(
  319. n_input_dims=in_dim,
  320. n_output_dims=3,
  321. network_config={
  322. "otype": "FullyFusedMLP",
  323. "activation": "ReLU",
  324. "output_activation": "Sigmoid",
  325. "n_neurons": hidden_dim_color,
  326. "n_hidden_layers": num_layers_color - 1,
  327. },
  328. )
  329.  
  330. def get_density(self, ray_samples: RaySamples) -> Tuple[TensorType, TensorType]:
  331. """Computes and returns the densities."""
  332. if self.spatial_distortion is not None:
  333. positions = ray_samples.frustums.get_positions()
  334. positions = self.spatial_distortion(positions)
  335. positions = (positions + 2.0) / 4.0
  336. else:
  337. positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
  338. positions_flat = positions.view(-1, 3)
  339. assert ray_samples.times is not None, "Time should be included in the input for NeRFPlayer"
  340. times_flat = ray_samples.times.reshape(-1, 1)
  341.  
  342. # 1. Get the deformation field
  343. deformation = self.deformation_field(positions_flat)
  344.  
  345. # Deform the positions
  346. deformed_positions = positions_flat + deformation
  347.  
  348. # 2. Get the stationary field
  349. v_stat = self.stationary_field(positions_flat)
  350. v_deform = self.stationary_field(deformed_positions)
  351. v_stat = self.stationary_field_mlp(torch.cat([v_stat, times_flat], dim=-1))
  352. v_deform = self.stationary_field_mlp(torch.cat([v_deform, times_flat], dim=-1))
  353.  
  354. # 3. Get the newness field
  355. v_new = self.newness_field(positions_flat, times_flat)
  356.  
  357. # 4. Get the decomposition field
  358. v_decomp = self.decomposition_field(positions_flat, times_flat)
  359. probs = self.decomposition_mlp(v_decomp)
  360. probs = torch.softmax(probs, dim=-1)
  361. self._probs = probs
  362.  
  363. # Mix features
  364. # Sizes:
  365. # probs: (batch_size, 3)
  366. # v: (batch_size, 32)
  367. v = (
  368. probs[:, 0].unsqueeze(-1) * v_stat
  369. + probs[:, 1].unsqueeze(-1) * v_deform
  370. + probs[:, 2].unsqueeze(-1) * v_new
  371. )
  372.  
  373. h = self.mlp_base_decode(v).view(*ray_samples.frustums.shape, -1)
  374. density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1)
  375.  
  376. # Rectifying the density with an exponential is much more stable than a ReLU or
  377. # softplus, because it enables high post-activation (float32) density outputs
  378. # from smaller internal (float16) parameters.
  379. density = trunc_exp(density_before_activation.to(positions))
  380. return density, base_mlp_out
  381.  
  382. def get_outputs(
  383. self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None
  384. ) -> Dict[FieldHeadNames, TensorType]:
  385. assert density_embedding is not None
  386. directions = shift_directions_for_tcnn(ray_samples.frustums.directions)
  387. directions_flat = directions.view(-1, 3)
  388.  
  389. if self.direction_encoding is not None:
  390. d = self.direction_encoding(directions_flat)
  391. if density_embedding is None:
  392. positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
  393. h = torch.cat([d, positions.view(-1, 3)], dim=-1)
  394. else:
  395. h = torch.cat([d, density_embedding.view(-1, self.geo_feat_dim)], dim=-1)
  396. else:
  397. # viewing direction is disabled
  398. if density_embedding is None:
  399. positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
  400. h = positions.view(-1, 3)
  401. else:
  402. h = density_embedding.view(-1, self.geo_feat_dim)
  403.  
  404. rgb = self.mlp_head(h).view(*ray_samples.frustums.directions.shape[:-1], -1).to(directions)
  405.  
  406. outputs = {FieldHeadNames.RGB: rgb}
  407.  
  408. if self._probs is not None:
  409. outputs[FieldHeadNames.PROBS] = self._probs.view(*ray_samples.frustums.directions.shape[:-1], -1).to(
  410. directions
  411. )
  412. self._probs = None
  413.  
  414. return outputs
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement