From db4b4628cf8f52408e93695f84e922ce699a8f8f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?D=C3=A1vid=20Komorowicz?=
Date: Thu, 16 Jan 2025 19:01:47 +0100
Subject: [PATCH] Add correct transient rendering and loss to nerfacto

Disable by default
---
 nerfstudio/models/nerfacto.py | 48 ++++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/nerfstudio/models/nerfacto.py b/nerfstudio/models/nerfacto.py
index 4c73fe7f6d..62d311dffc 100644
--- a/nerfstudio/models/nerfacto.py
+++ b/nerfstudio/models/nerfacto.py
@@ -41,7 +41,13 @@
     scale_gradients_by_distance_squared,
 )
 from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler, UniformSampler
-from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, NormalsRenderer, RGBRenderer
+from nerfstudio.model_components.renderers import (
+    AccumulationRenderer,
+    DepthRenderer,
+    NormalsRenderer,
+    RGBRenderer,
+    UncertaintyRenderer,
+)
 from nerfstudio.model_components.scene_colliders import NearFarCollider
 from nerfstudio.model_components.shaders import NormalsShader
 from nerfstudio.models.base_model import Model, ModelConfig
@@ -63,9 +69,9 @@ class NerfactoModelConfig(ModelConfig):
     """Dimension of hidden layers"""
     hidden_dim_color: int = 64
     """Dimension of hidden layers for color network"""
-    use_transient_embedding: bool = True
+    use_transient_embedding: bool = False
     """Whether to use an transient embedding."""
-    hidden_dim_transient: int = 128
+    hidden_dim_transient: int = 64
     """Dimension of hidden layers for transient network"""
     transient_embed_dim: int = 16
     """Dimension of the transient embedding."""
@@ -240,6 +246,7 @@ def update_schedule(step):
         self.renderer_accumulation = AccumulationRenderer()
         self.renderer_depth = DepthRenderer(method="median")
         self.renderer_expected_depth = DepthRenderer(method="expected")
+        self.renderer_uncertainty = UncertaintyRenderer()
         self.renderer_normals = NormalsRenderer()
 
         # shaders
@@ -311,11 +318,25 @@ def get_outputs(self, ray_bundle: RayBundle):
         if self.config.use_gradient_scaling:
             field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)
 
-        weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
+        if self.training and self.config.use_transient_embedding:
+            static_density = field_outputs[FieldHeadNames.DENSITY]
+            transient_density = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]
+            weights_static = ray_samples.get_weights(static_density)
+            weights_transient = ray_samples.get_weights(transient_density)
+            weights = weights_static
+            rgb_static_component = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights_static)
+            rgb_transient_component = self.renderer_rgb(
+                rgb=field_outputs[FieldHeadNames.TRANSIENT_RGB], weights=weights_transient
+            )
+            rgb = rgb_static_component + rgb_transient_component
+        else:
+            weights_static = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
+            weights = weights_static
+            rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
+
         weights_list.append(weights)
         ray_samples_list.append(ray_samples)
 
-        rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
         with torch.no_grad():
             depth = self.renderer_depth(weights=weights, ray_samples=ray_samples)
         expected_depth = self.renderer_expected_depth(weights=weights, ray_samples=ray_samples)
@@ -351,6 +372,13 @@ def get_outputs(self, ray_bundle: RayBundle):
 
         for i in range(self.config.num_proposal_iterations):
             outputs[f"prop_depth_{i}"] = self.renderer_depth(weights=weights_list[i], ray_samples=ray_samples_list[i])
+
+        # transients
+        if self.training and self.config.use_transient_embedding:
+            uncertainty = self.renderer_uncertainty(field_outputs[FieldHeadNames.UNCERTAINTY], weights_transient)
+            outputs["uncertainty"] = uncertainty + 0.1  # NOTE(ethan): this is the uncertainty min
+            outputs["density_transient"] = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]
+
         return outputs
 
     def get_metrics_dict(self, outputs, batch):
@@ -375,7 +403,15 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None):
             gt_image=image,
         )
 
-        loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)
+        if self.training and self.config.use_transient_embedding:
+            # transient loss
+            betas = outputs["uncertainty"]
+            loss_dict["uncertainty_loss"] = 3 + torch.log(betas).mean()
+            loss_dict["density_loss"] = 0.01 * outputs["density_transient"].mean()
+            loss_dict["rgb_loss"] = (((gt_rgb - pred_rgb) ** 2).sum(-1) / (betas[..., 0] ** 2)).mean()
+        else:
+            loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)
+
         if self.training:
             loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss(
                 outputs["weights_list"], outputs["ray_samples_list"]
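
For reviewers, the added branch follows a NeRF-W-style transient formulation: the rendered
uncertainty beta down-weights the per-ray RGB error, a log(beta) term keeps beta from growing
without bound, and a small penalty discourages the transient density from explaining the whole
scene. The snippet below is a minimal, self-contained sketch of the three loss terms on dummy
tensors; the shapes and the betas/density_transient names mirror the patch, but the snippet
itself is illustrative only and not part of the diff.

    import torch

    # Dummy ray batch: 4096 rays, 48 samples per ray.
    gt_rgb = torch.rand(4096, 3)
    pred_rgb = torch.rand(4096, 3)
    betas = torch.rand(4096, 1) + 0.1            # rendered uncertainty, floored at 0.1 as in get_outputs
    density_transient = torch.rand(4096, 48, 1)  # per-sample transient density along each ray

    # Same three terms as the patched get_loss_dict:
    rgb_loss = (((gt_rgb - pred_rgb) ** 2).sum(-1) / (betas[..., 0] ** 2)).mean()
    uncertainty_loss = 3 + torch.log(betas).mean()   # constant offset keeps the term positive for beta >= ~0.05
    density_loss = 0.01 * density_transient.mean()   # small regularizer on transient density

    print(rgb_loss.item(), uncertainty_loss.item(), density_loss.item())

With the default flipped to False the behaviour is opt-in; assuming the usual tyro-generated CLI,
the flag would be --pipeline.model.use-transient-embedding True (name inferred from the config
field, not verified against this branch).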