# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
NeRFPlayer (https://arxiv.org/abs/2210.15947) complete implementation with a nerfacto backbone.
"""

from __future__ import annotations

import functools
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type

import numpy as np
import torch
from torchmetrics import PeakSignalNoiseRatio
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchtyping import TensorType
from typing_extensions import Literal

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.spatial_distortions import SceneContraction
from nerfstudio.fields.nerfplayer_field import (
    NerfplayerField,
    TemporalHashMLPDensityField,
)
from nerfstudio.model_components.losses import (
    MSELoss,
    interlevel_loss,
    orientation_loss,
    pred_normal_loss,
    DepthLossType,
    depth_loss,
)
from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler
from nerfstudio.model_components.renderers import (
    AccumulationRenderer,
    DepthRenderer,
    NormalsRenderer,
    RGBRenderer,
    DecompositionRenderer,
)
from nerfstudio.model_components.scene_colliders import NearFarCollider, AABBBoxCollider
from nerfstudio.model_components.shaders import NormalsShader
from nerfstudio.models.base_model import Model
from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig
from nerfstudio.utils import colormaps
from nerfstudio.utils.dynmetric import DynMetric


@dataclass
class NerfplayerModelConfig(NerfactoModelConfig):
    """Nerfplayer Model Config with Nerfacto backbone"""

    _target: Type = field(default_factory=lambda: NerfplayerModel)
    near_plane: float = 0.05
    """How far along the ray to start sampling."""
    far_plane: float = 1000.0
    """How far along the ray to stop sampling."""
    train_background_color: Literal["random", "black", "white"] = "random"
    """The background color rendered for untrained areas during training."""
    eval_background_color: Literal["random", "black", "white", "last_sample"] = "white"
    """The background color rendered for untrained areas during evaluation."""
    num_levels: int = 16
    """Hashing grid parameter."""
    features_per_level: int = 2
    """Hashing grid parameter."""
    log2_hashmap_size: int = 17
    """Hashing grid parameter."""
    temporal_dim: int = 64
    """Hashing grid parameter. A higher temporal dim means a higher temporal frequency."""
    proposal_net_args_list: List[Dict] = field(
        default_factory=lambda: [
            {"hidden_dim": 16, "temporal_dim": 32, "log2_hashmap_size": 17, "num_levels": 5, "max_res": 64},
            {"hidden_dim": 16, "temporal_dim": 32, "log2_hashmap_size": 17, "num_levels": 5, "max_res": 256},
        ]
    )
    """Arguments for the proposal density fields."""
    disable_viewing_dependent: bool = True
    """Disable view-dependent effects."""
    distortion_loss_mult: float = 1e-3
    """Distortion loss multiplier."""
    temporal_tv_weight: float = 1.0
    """Temporal TV balancing weight for feature channels."""
    depth_weight: float = 0.05
    """Depth loss balancing weight for feature channels."""
    is_euclidean_depth: bool = True
    """Whether input depth maps are Euclidean distances (or z-distances)."""
    depth_sigma: float = 0.01
    """Uncertainty around depth values in meters (defaults to 1cm)."""
    should_decay_sigma: bool = False
    """Whether to exponentially decay sigma."""
    starting_depth_sigma: float = 0.2
    """Starting uncertainty around depth values in meters (defaults to 0.2m)."""
    sigma_decay_rate: float = 0.99985
    """Rate of exponential decay."""
    depth_loss_type: DepthLossType = DepthLossType.DS_NERF
    """Depth loss type."""
    prob_reg_loss_mult: float = 0.001  # Paper uses 0.1, which seems very high compared to the experiments done here.
    """Probability regularization loss multiplier."""

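# A minimal usage sketch (hypothetical variable names; in practice the nerfstudio
# trainer instantiates the model from this config via `InstantiateConfig.setup`):
#
#   config = NerfplayerModelConfig(temporal_dim=64, depth_weight=0.05)
#   model = config.setup(scene_box=scene_box, num_train_data=num_train_data)
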
class NerfplayerModel(NerfactoModel):
    """Nerfplayer model with Nerfacto backbone.

    Args:
        config: Nerfplayer configuration to instantiate model
    """

    config: NerfplayerModelConfig

    def populate_modules(self):
        """Set the fields and modules."""
        Model.populate_modules(self)

        if self.config.disable_scene_contraction:
            scene_contraction = None
        else:
            scene_contraction = SceneContraction(order=float("inf"))

        if self.config.should_decay_sigma:
            self.depth_sigma = torch.tensor([self.config.starting_depth_sigma])
        else:
            self.depth_sigma = torch.tensor([self.config.depth_sigma])

        # Fields
        self.field = NerfplayerField(
            self.scene_box.aabb,
            temporal_dim=self.config.temporal_dim,
            num_levels=self.config.num_levels,
            features_per_level=self.config.features_per_level,
            log2_hashmap_size=self.config.log2_hashmap_size,
            spatial_distortion=scene_contraction,
            num_images=self.num_train_data,
            use_pred_normals=self.config.predict_normals,
            use_average_appearance_embedding=self.config.use_average_appearance_embedding,
            disable_viewing_dependent=self.config.disable_viewing_dependent,
        )

        self.density_fns = []
        num_prop_nets = self.config.num_proposal_iterations
        # Build the proposal network(s)
        self.proposal_networks = torch.nn.ModuleList()
        if self.config.use_same_proposal_network:
            assert len(self.config.proposal_net_args_list) == 1, "Only one proposal network is allowed."
            prop_net_args = self.config.proposal_net_args_list[0]
            network = TemporalHashMLPDensityField(
                self.scene_box.aabb, spatial_distortion=scene_contraction, **prop_net_args
            )
            self.proposal_networks.append(network)
            self.density_fns.extend([network.density_fn for _ in range(num_prop_nets)])
        else:
            for i in range(num_prop_nets):
                prop_net_args = self.config.proposal_net_args_list[min(i, len(self.config.proposal_net_args_list) - 1)]
                network = TemporalHashMLPDensityField(
                    self.scene_box.aabb,
                    spatial_distortion=scene_contraction,
                    **prop_net_args,
                )
                self.proposal_networks.append(network)
            self.density_fns.extend([network.density_fn for network in self.proposal_networks])

        # Samplers
        update_schedule = lambda step: np.clip(
            np.interp(step, [0, self.config.proposal_warmup], [0, self.config.proposal_update_every]),
            1,
            self.config.proposal_update_every,
        )
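        # The schedule above linearly ramps the interval between proposal-network updates
        # from every step up to every `proposal_update_every` steps across the first
        # `proposal_warmup` steps of training (np.interp), clipped to at least 1.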
        self.proposal_sampler = ProposalNetworkSampler(
            num_nerf_samples_per_ray=self.config.num_nerf_samples_per_ray,
            num_proposal_samples_per_ray=self.config.num_proposal_samples_per_ray,
            num_proposal_network_iterations=self.config.num_proposal_iterations,
            single_jitter=self.config.use_single_jitter,
            update_sched=update_schedule,
        )

        # Collider
        if self.config.disable_scene_contraction:
            self.collider = AABBBoxCollider(self.scene_box)
        else:
            self.collider = NearFarCollider(near_plane=self.config.near_plane, far_plane=self.config.far_plane)

        self.background_color = self.config.train_background_color

        # renderers
        self.renderer_rgb = RGBRenderer(background_color=self.config.train_background_color)
        self.renderer_accumulation = AccumulationRenderer()
        self.renderer_depth = DepthRenderer(method="expected")  # for depth loss
        self.renderer_normals = NormalsRenderer()
        self.renderer_probs = DecompositionRenderer()
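        # DecompositionRenderer (assumed from its use below) composites the per-sample
        # NeRFPlayer decomposition probabilities (static / deform / new) into a per-ray map.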
        # shaders
        self.normals_shader = NormalsShader()

        # losses
        self.rgb_loss = MSELoss()

        # metrics
        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = structural_similarity_index_measure
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)
        self.dynmetric = DynMetric(self.psnr, self.ssim, self.lpips, "cuda")
        self.temporal_distortion = True  # for viewer

    def get_outputs(self, ray_bundle: RayBundle):
        assert ray_bundle.times is not None, "Time not provided."
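        # Bind the rays' timestamps into every proposal density function so the temporal
        # density fields are evaluated at the same time instant as the rays.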
        ray_samples, weights_list, ray_samples_list = self.proposal_sampler(
            ray_bundle, density_fns=[functools.partial(f, times=ray_bundle.times) for f in self.density_fns]
        )
        field_outputs = self.field(ray_samples, compute_normals=self.config.predict_normals)
        weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
        weights_list.append(weights)
        ray_samples_list.append(ray_samples)

        if self.training:
            self.renderer_rgb.background_color = self.config.train_background_color
        else:
            self.renderer_rgb.background_color = self.config.eval_background_color
        rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
        depth = self.renderer_depth(weights=weights, ray_samples=ray_samples)
        accumulation = self.renderer_accumulation(weights=weights)

        outputs = {
            "rgb": rgb,
            "accumulation": accumulation,
            "depth": depth,
        }
        if FieldHeadNames.PROBS in field_outputs:
            probs = self.renderer_probs(
                probs=field_outputs[FieldHeadNames.PROBS],
                weights=weights,
            )
            outputs["probs"] = probs

        if self.config.predict_normals:
            outputs["normals"] = self.normals_shader(
                self.renderer_normals(normals=field_outputs[FieldHeadNames.NORMALS], weights=weights)
            )
            outputs["pred_normals"] = self.normals_shader(
                self.renderer_normals(field_outputs[FieldHeadNames.PRED_NORMALS], weights=weights)
            )

        # These use a lot of GPU memory, so we avoid storing them for eval.
        if self.training:
            outputs["weights_list"] = weights_list
            outputs["ray_samples_list"] = ray_samples_list

        if self.training and self.config.predict_normals:
            outputs["rendered_orientation_loss"] = orientation_loss(
                weights.detach(), field_outputs[FieldHeadNames.NORMALS], ray_bundle.directions
            )

            outputs["rendered_pred_normal_loss"] = pred_normal_loss(
                weights.detach(),
                field_outputs[FieldHeadNames.NORMALS].detach(),
                field_outputs[FieldHeadNames.PRED_NORMALS],
            )

        for i in range(self.config.num_proposal_iterations):
            outputs[f"prop_depth_{i}"] = self.renderer_depth(weights=weights_list[i], ray_samples=ray_samples_list[i])

        if ray_bundle.metadata is not None and "directions_norm" in ray_bundle.metadata:
            outputs["directions_norm"] = ray_bundle.metadata["directions_norm"]

        return outputs
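    # A sketch of a manual forward pass (hypothetical tensors; the nerfstudio
    # datamanager normally builds the RayBundle, including per-ray `times`):
    #
    #   outputs = model.get_outputs(ray_bundle)
    #   rgb, depth, probs = outputs["rgb"], outputs["depth"], outputs.get("probs")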
    def _get_sigma(self) -> TensorType[0]:
        if not self.config.should_decay_sigma:
            return self.depth_sigma
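        # Exponential decay toward the configured floor: after t calls this returns
        # max(sigma_decay_rate**t * starting_depth_sigma, depth_sigma). With the default
        # rate of 0.99985, sigma roughly halves every ~4.6k steps.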
        self.depth_sigma = torch.maximum(  # pylint: disable=attribute-defined-outside-init
            self.config.sigma_decay_rate * self.depth_sigma, torch.tensor([self.config.depth_sigma])
        )
        return self.depth_sigma

    def get_metrics_dict(self, outputs, batch):
        metrics_dict = super().get_metrics_dict(outputs, batch)

        if "depth_image" in batch.keys() and self.training:
            metrics_dict["depth_loss"] = 0.0
            sigma = self._get_sigma().to(self.device)
            termination_depth = batch["depth_image"].to(self.device)
            # Iterate over networks (proposal + nerf)
            for i in range(len(outputs["weights_list"])):
                metrics_dict["depth_loss"] += depth_loss(
                    weights=outputs["weights_list"][i],
                    ray_samples=outputs["ray_samples_list"][i],
                    termination_depth=termination_depth,
                    predicted_depth=outputs["depth"],
                    sigma=sigma,
                    directions_norm=outputs["directions_norm"],
                    is_euclidean=self.config.is_euclidean_depth,
                    depth_loss_type=self.config.depth_loss_type,
                ) / len(outputs["weights_list"])

        return metrics_dict
    def get_loss_dict(self, outputs, batch, metrics_dict=None):
        loss_dict = {}
        image = batch["image"].to(self.device)
        loss_dict["rgb_loss"] = self.rgb_loss(image, outputs["rgb"])
        if self.training:
            loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss(
                outputs["weights_list"], outputs["ray_samples_list"]
            )
            assert metrics_dict is not None and "distortion" in metrics_dict
            loss_dict["distortion_loss"] = self.config.distortion_loss_mult * metrics_dict["distortion"]

            if "depth_image" in batch.keys() and self.config.depth_weight > 0:
                loss_dict["depth_loss"] = self.config.depth_weight * metrics_dict["depth_loss"]

            if self.config.temporal_tv_weight > 0:
                loss_dict["temporal_tv_loss"] = self.field.newness_field.get_temporal_tv_loss()
                loss_dict["temporal_tv_loss"] += self.field.decomposition_field.get_temporal_tv_loss()
                for net in self.proposal_networks:
                    loss_dict["temporal_tv_loss"] += net.encoding.get_temporal_tv_loss()
                loss_dict["temporal_tv_loss"] *= self.config.temporal_tv_weight
                loss_dict["temporal_tv_loss"] /= (
                    len(self.proposal_networks) + 2
                )  # Average over all networks: 2 for the field, 1 for each proposal
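            # Regularize the rendered decomposition probabilities: deformation is penalized
            # lightly (0.01) and newness at full weight, nudging the model to explain the
            # scene as static or deforming before allocating "new" content (cf. the paper's
            # probability regularization, whose 0.1 multiplier is noted in the config above).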
            if "probs" in outputs:
                # 0=static, 1=deform, 2=new
                probs = outputs["probs"].view(-1, 3)
                probs_mean = probs.mean(dim=0)
                prob_loss = 0.01 * probs_mean[1] + probs_mean[2]
                loss_dict["prob_loss"] = prob_loss * self.config.prob_reg_loss_mult

        return loss_dict

    def get_image_metrics_and_images(
        self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
        metrics_dict, images_dict = super().get_image_metrics_and_images(outputs, batch)
        if "depth_image" in batch.keys():
            ground_truth_depth = batch["depth_image"]
            if not self.config.is_euclidean_depth:
                ground_truth_depth = ground_truth_depth * outputs["directions_norm"]

            # Show the ground-truth and predicted depth colormaps side by side.
            ground_truth_depth_colormap = colormaps.apply_depth_colormap(ground_truth_depth)
            depth = images_dict["depth"]
            images_dict["depth"] = torch.cat([ground_truth_depth_colormap, depth], dim=1)

        image = batch["image"].to(outputs["rgb"].device)
        rgb = outputs["rgb"]
        # Reshape HWC -> NCHW for torchmetrics; note these tensors are not used further here.
        image = torch.moveaxis(image, -1, 0)[None, ...]
        rgb = torch.moveaxis(rgb, -1, 0)[None, ...]

        return metrics_dict, images_dict
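# Sketch of how the nerfstudio trainer exercises this model each step (illustrative
# names; the real loop lives in nerfstudio's Trainer/Pipeline):
#
#   outputs = model.get_outputs(ray_bundle)                # needs ray_bundle.times
#   metrics = model.get_metrics_dict(outputs, batch)       # adds depth_loss when depth GT is given
#   losses = model.get_loss_dict(outputs, batch, metrics)  # rgb + interlevel + distortion + ...
#   loss = functools.reduce(torch.add, losses.values())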