Add correct transient rendering and loss to nerfacto
Disable by default
Dawars committed Jan 17, 2025
1 parent 84881f4 commit db4b462
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions nerfstudio/models/nerfacto.py
@@ -41,7 +41,13 @@
scale_gradients_by_distance_squared,
)
from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler, UniformSampler
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, NormalsRenderer, RGBRenderer
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
DepthRenderer,
NormalsRenderer,
RGBRenderer,
UncertaintyRenderer,
)
from nerfstudio.model_components.scene_colliders import NearFarCollider
from nerfstudio.model_components.shaders import NormalsShader
from nerfstudio.models.base_model import Model, ModelConfig
@@ -63,9 +69,9 @@ class NerfactoModelConfig(ModelConfig):
"""Dimension of hidden layers"""
hidden_dim_color: int = 64
"""Dimension of hidden layers for color network"""
use_transient_embedding: bool = True
use_transient_embedding: bool = False
"""Whether to use an transient embedding."""
hidden_dim_transient: int = 128
hidden_dim_transient: int = 64
"""Dimension of hidden layers for transient network"""
transient_embed_dim: int = 16
"""Dimension of the transient embedding."""
@@ -240,6 +246,7 @@ def update_schedule(step):
self.renderer_accumulation = AccumulationRenderer()
self.renderer_depth = DepthRenderer(method="median")
self.renderer_expected_depth = DepthRenderer(method="expected")
self.renderer_uncertainty = UncertaintyRenderer()
self.renderer_normals = NormalsRenderer()

# shaders
@@ -311,11 +318,25 @@ def get_outputs(self, ray_bundle: RayBundle):
if self.config.use_gradient_scaling:
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)

weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
if self.training and self.config.use_transient_embedding:
static_density = field_outputs[FieldHeadNames.DENSITY]
transient_density = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]
weights_static = ray_samples.get_weights(static_density)
weights_transient = ray_samples.get_weights(transient_density)
weights = weights_static
rgb_static_component = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights_static)
rgb_transient_component = self.renderer_rgb(
rgb=field_outputs[FieldHeadNames.TRANSIENT_RGB], weights=weights_transient
)
rgb = rgb_static_component + rgb_transient_component
else:
weights_static = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
weights = weights_static
rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)

weights_list.append(weights)
ray_samples_list.append(ray_samples)

rgb = self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights)
with torch.no_grad():
depth = self.renderer_depth(weights=weights, ray_samples=ray_samples)
expected_depth = self.renderer_expected_depth(weights=weights, ray_samples=ray_samples)
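
A note on the rendering branch above: when transients are enabled during training, the static and transient densities are converted to separate ray weights, each is rendered to a color, and the two colors are summed (NeRF-W style compositing); depth and accumulation keep using only the static weights. A rough sketch of what the RGB compositing amounts to, ignoring background blending and assuming the tensor names from the hunk:

# Sketch only (not the renderer implementation): per-ray color as a
# weighted sum over samples of static plus transient contributions.
w_s = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
w_t = ray_samples.get_weights(field_outputs[FieldHeadNames.TRANSIENT_DENSITY])
rgb = (w_s * field_outputs[FieldHeadNames.RGB]).sum(dim=-2) + \
      (w_t * field_outputs[FieldHeadNames.TRANSIENT_RGB]).sum(dim=-2)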
@@ -351,6 +372,13 @@ def get_outputs(self, ray_bundle: RayBundle):

for i in range(self.config.num_proposal_iterations):
outputs[f"prop_depth_{i}"] = self.renderer_depth(weights=weights_list[i], ray_samples=ray_samples_list[i])

# transients
if self.training and self.config.use_transient_embedding:
uncertainty = self.renderer_uncertainty(field_outputs[FieldHeadNames.UNCERTAINTY], weights_transient)
outputs["uncertainty"] = uncertainty + 0.1 # NOTE(ethan): this is the uncertainty min
outputs["density_transient"] = field_outputs[FieldHeadNames.TRANSIENT_DENSITY]

return outputs

def get_metrics_dict(self, outputs, batch):
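
For context, the uncertainty output added above follows the NeRF-W formulation: per-sample uncertainties (betas) from the transient head are composited along each ray using the transient weights, and a fixed floor of 0.1 is added so the per-ray variance used in the loss never collapses to zero. Assuming UncertaintyRenderer reduces the same way as the color renderer (an assumption, not verified here), the output is roughly:

# Sketch of the assumed reduction inside UncertaintyRenderer:
beta_ray = (weights_transient * field_outputs[FieldHeadNames.UNCERTAINTY]).sum(dim=-2) + 0.1  # 0.1 = uncertainty floor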
@@ -375,7 +403,15 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None):
gt_image=image,
)

loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)
if self.training and self.config.use_transient_embedding:
# transient loss
betas = outputs["uncertainty"]
loss_dict["uncertainty_loss"] = 3 + torch.log(betas).mean()
loss_dict["density_loss"] = 0.01 * outputs["density_transient"].mean()
loss_dict["rgb_loss"] = (((gt_rgb - pred_rgb) ** 2).sum(-1) / (betas[..., 0] ** 2)).mean()
else:
loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)

if self.training:
loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss(
outputs["weights_list"], outputs["ray_samples_list"]
