Refactor out global hyperparameters #172

Merged
merged 1 commit on Sep 13, 2024
42 changes: 17 additions & 25 deletions notebooks/bayes3d_paper/run_ycbv_evaluation.py
@@ -1,36 +1,28 @@
#!/usr/bin/env python
import os

import b3d
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import fire
import genjax
import jax
import jax.numpy as jnp
from b3d import Mesh, Pose
from b3d.chisight.gen3d.model import (
dynamic_object_generative_model,
make_colors_choicemap,
make_depth_nonreturn_prob_choicemap,
make_visibility_prob_choicemap,
)
from genjax import Pytree
from tqdm import tqdm


def run_tracking(scene=None, object=None, debug=False):
import importlib
import os

import b3d
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import genjax
import jax
import jax.numpy as jnp
from b3d import Mesh, Pose
from b3d.chisight.gen3d.model import (
dynamic_object_generative_model,
make_colors_choicemap,
make_depth_nonreturn_prob_choicemap,
make_visibility_prob_choicemap,
)
from genjax import Pytree
from tqdm import tqdm

importlib.reload(b3d.mesh)
importlib.reload(b3d.io.data_loader)
importlib.reload(b3d.utils)
importlib.reload(b3d.renderer.renderer_original)

FRAME_RATE = 50

ycb_dir = os.path.join(b3d.get_assets_path(), "bop/ycbv")

b3d.rr_init("run_ycbv_evaluation")

if scene is None:
156 changes: 68 additions & 88 deletions notebooks/bayes3d_paper/tester.ipynb

Large diffs are not rendered by default.

208 changes: 100 additions & 108 deletions src/b3d/chisight/gen3d/image_kernel.py
@@ -7,8 +7,6 @@
from genjax import Pytree
from genjax.typing import FloatArray, PRNGKey

import b3d
from b3d.chisight.gen3d.pixel_kernels import is_unexplained
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import (
RenormalizedLaplacePixelColorDistribution,
UniformPixelColorDistribution,
@@ -20,6 +18,7 @@
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import (
FullPixelRGBDDistribution,
PixelRGBDDistribution,
is_unexplained,
)
from b3d.chisight.gen3d.projection import PixelsPointsAssociation

@@ -33,19 +32,14 @@ class ImageKernel(genjax.ExactDensity):
The support of the distribution is [0, 1]^3 x [near, far].
"""

near: float = Pytree.static()
far: float = Pytree.static()
image_height: int = Pytree.static()
image_width: int = Pytree.static()

def get_pixels_points_association(
self, transformed_points, hyperparams: Mapping
) -> PixelsPointsAssociation:
return PixelsPointsAssociation.from_points_and_intrinsics(
transformed_points,
hyperparams["intrinsics"],
self.image_height,
self.image_width,
hyperparams["intrinsics"]["image_height"].const,
hyperparams["intrinsics"]["image_width"].const,
)

@abstractmethod
@@ -64,11 +58,6 @@ def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:

@Pytree.dataclass
class NoOcclusionPerVertexImageKernel(ImageKernel):
near: float = Pytree.static()
far: float = Pytree.static()
image_height: int = Pytree.static()
image_width: int = Pytree.static()

def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArray:
"""Generate latent RGBD image by projecting the vertices directly to the image
plane, without checking for occlusions.
@@ -96,17 +85,24 @@ def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArray:
[pixel_latent_rgbd, pixel_latent_depth[..., None]], axis=-1
)

keys = jax.random.split(key, (self.image_height, self.image_width))
keys = jax.random.split(
key,
(
hyperparams["intrinsics"]["image_height"].const,
hyperparams["intrinsics"]["image_width"].const,
),
)
return jax.vmap(
jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0)),
in_axes=(0, 0, None, None, 0, 0),
jax.vmap(vertex_kernel.sample, in_axes=(0, 0, None, None, 0, 0, None)),
in_axes=(0, 0, None, None, 0, 0, None),
)(
keys,
pixel_latent_rgbd,
state["color_scale"],
state["depth_scale"],
pixel_visibility_prob,
pixel_depth_nonreturn_prob,
hyperparams["intrinsics"],
)

def logpdf(
@@ -122,13 +118,14 @@ def logpdf(
(state["colors"], transformed_points[..., 2, None]), axis=-1
)

scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0))(
scores = jax.vmap(vertex_kernel.logpdf, in_axes=(0, 0, None, None, 0, 0, None))(
observed_rgbd_per_point,
latent_rgbd_per_point,
state["color_scale"],
state["depth_scale"],
state["visibility_prob"],
state["depth_nonreturn_prob"],
hyperparams["intrinsics"],
)
# Points that don't hit the camera plane should not contribute to the score.
scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores)
@@ -146,93 +143,88 @@ def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
return FullPixelRGBDDistribution(
RenormalizedLaplacePixelColorDistribution(),
UniformPixelColorDistribution(),
RenormalizedLaplacePixelDepthDistribution(self.near, self.far),
UniformPixelDepthDistribution(self.near, self.far),
)


@Pytree.dataclass
class OldNoOcclusionPerVertexImageKernel(ImageKernel):
near: float = Pytree.static()
far: float = Pytree.static()
image_height: int = Pytree.static()
image_width: int = Pytree.static()

@jax.jit
def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArray:
return jnp.zeros(
(
self.image_height,
self.image_width,
4,
)
)

@jax.jit
def logpdf(
self, observed_rgbd: FloatArray, state: Mapping, hyperparams: Mapping
) -> FloatArray:
return self.info_func(observed_rgbd, state, hyperparams)["scores"].sum()

def info_from_trace(self, trace):
return self.info_func(
trace.get_choices()["rgbd"],
trace.get_retval()["new_state"],
trace.get_args()[0],
)

def info_func(self, observed_rgbd, state, hyperparams):
transformed_points = state["pose"].apply(hyperparams["vertices"])
projected_pixel_coordinates = jnp.rint(
b3d.xyz_to_pixel_coordinates(
transformed_points,
hyperparams["intrinsics"]["fx"],
hyperparams["intrinsics"]["fy"],
hyperparams["intrinsics"]["cx"],
hyperparams["intrinsics"]["cy"],
)
).astype(jnp.int32)

observed_rgbd_masked = observed_rgbd[
projected_pixel_coordinates[..., 0], projected_pixel_coordinates[..., 1]
]

color_visible_branch_score = jax.scipy.stats.laplace.logpdf(
observed_rgbd_masked[..., :3], state["colors"], state["color_scale"]
).sum(axis=-1)
color_not_visible_score = jnp.log(1 / 1.0**3)
color_score = jnp.logaddexp(
color_visible_branch_score + jnp.log(state["visibility_prob"]),
color_not_visible_score + jnp.log(1 - state["visibility_prob"]),
)

depth_visible_branch_score = jax.scipy.stats.laplace.logpdf(
observed_rgbd_masked[..., 3],
transformed_points[..., 2],
state["depth_scale"],
)
depth_not_visible_score = jnp.log(1 / 1.0)
_depth_score = jnp.logaddexp(
depth_visible_branch_score + jnp.log(state["visibility_prob"]),
depth_not_visible_score + jnp.log(1 - state["visibility_prob"]),
)
is_depth_non_return = observed_rgbd_masked[..., 3] < 0.0001

non_return_probability = 0.05
depth_score = jnp.where(
is_depth_non_return, jnp.log(non_return_probability), _depth_score
)

lmbda = 0.5
scores = lmbda * color_score + (1.0 - lmbda) * depth_score
return {
"scores": scores,
"observed_rgbd_masked": observed_rgbd_masked,
}

def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
# Note: The distributions were originally defined for per-pixel computation,
# but they should work for per-vertex computation as well, except that
# they don't expect observed_rgbd to be invalid, so we need to handle
# that manually.
raise NotImplementedError
RenormalizedLaplacePixelDepthDistribution(),
UniformPixelDepthDistribution(),
)


# @Pytree.dataclass
# class OldNoOcclusionPerVertexImageKernel(ImageKernel):
# @jax.jit
# def sample(self, key: PRNGKey, state: Mapping, hyperparams: Mapping) -> FloatArray:
# return jnp.zeros(
# (
# hyperparams["intrinsics"]["image_height"].const,
# hyperparams["intrinsics"]["image_width"].const,
# 4,
# )
# )

# @jax.jit
# def logpdf(
# self, observed_rgbd: FloatArray, state: Mapping, hyperparams: Mapping
# ) -> FloatArray:
# return self.info_func(observed_rgbd, state, hyperparams)["scores"].sum()

# def info_from_trace(self, trace):
# return self.info_func(
# trace.get_choices()["rgbd"],
# trace.get_retval()["new_state"],
# trace.get_args()[0],
# )

# def info_func(self, observed_rgbd, state, hyperparams):
# transformed_points = state["pose"].apply(hyperparams["vertices"])
# projected_pixel_coordinates = jnp.rint(
# b3d.xyz_to_pixel_coordinates(
# transformed_points,
# hyperparams["intrinsics"]["fx"],
# hyperparams["intrinsics"]["fy"],
# hyperparams["intrinsics"]["cx"],
# hyperparams["intrinsics"]["cy"],
# )
# ).astype(jnp.int32)

# observed_rgbd_masked = observed_rgbd[
# projected_pixel_coordinates[..., 0], projected_pixel_coordinates[..., 1]
# ]

# color_visible_branch_score = jax.scipy.stats.laplace.logpdf(
# observed_rgbd_masked[..., :3], state["colors"], state["color_scale"]
# ).sum(axis=-1)
# color_not_visible_score = jnp.log(1 / 1.0**3)
# color_score = jnp.logaddexp(
# color_visible_branch_score + jnp.log(state["visibility_prob"]),
# color_not_visible_score + jnp.log(1 - state["visibility_prob"]),
# )

# depth_visible_branch_score = jax.scipy.stats.laplace.logpdf(
# observed_rgbd_masked[..., 3],
# transformed_points[..., 2],
# state["depth_scale"],
# )
# depth_not_visible_score = jnp.log(1 / 1.0)
# _depth_score = jnp.logaddexp(
# depth_visible_branch_score + jnp.log(state["visibility_prob"]),
# depth_not_visible_score + jnp.log(1 - state["visibility_prob"]),
# )
# is_depth_non_return = observed_rgbd_masked[..., 3] < 0.0001

# non_return_probability = 0.05
# depth_score = jnp.where(
# is_depth_non_return, jnp.log(non_return_probability), _depth_score
# )

# lmbda = 0.5
# scores = lmbda * color_score + (1.0 - lmbda) * depth_score
# return {
# "scores": scores,
# "observed_rgbd_masked": observed_rgbd_masked,
# }

# def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
# # Note: The distributions were originally defined for per-pixel computation,
# # but they should work for per-vertex computation as well, except that
# # they don't expect observed_rgbd to be invalid, so we need to handle
# # that manually.
# raise NotImplementedError
3 changes: 3 additions & 0 deletions src/b3d/chisight/gen3d/inference_moves.py
@@ -164,6 +164,7 @@ def propose_a_points_attributes(
obs_rgbd_kernel=hyperparams["image_kernel"].get_rgbd_vertex_kernel(),
color_scale=new_state["color_scale"],
depth_scale=new_state["depth_scale"],
intrinsics=hyperparams["intrinsics"],
inference_hyperparams=inference_hyperparams,
)

@@ -181,6 +182,7 @@ def _propose_a_points_attributes(
obs_rgbd_kernel,
color_scale,
depth_scale,
intrinsics,
inference_hyperparams,
):
k1, k2 = split(key, 2)
@@ -198,6 +200,7 @@ def score_attribute_assignment(color, visprob, dnrprob):
depth_scale=depth_scale,
visibility_prob=visprob,
depth_nonreturn_prob=dnrprob,
intrinsics=intrinsics,
)
return (
visprob_transition_score
27 changes: 0 additions & 27 deletions src/b3d/chisight/gen3d/pixel_kernels/__init__.py

This file was deleted.
