Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add renormalized laplace and gaussian distribution and kernels #162

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import jax.numpy as jnp
from genjax import Pytree
from genjax.typing import FloatArray, PRNGKey
from jax.random import split
from tensorflow_probability.substrates import jax as tfp

from b3d.modeling_utils import (
_FIXED_COLOR_UNIFORM_WINDOW,
PythonMixtureDistribution,
renormalized_laplace,
truncated_laplace,
)

Expand Down Expand Up @@ -73,6 +75,60 @@ def logpdf_per_channel(
raise NotImplementedError


@Pytree.dataclass
class RenormalizedGaussianPixelColorDistribution(PixelColorDistribution):
"""
Sample a color from a renormalized Gaussian distribution centered around the given
latent_color (rgb value), given the color_scale (stddev of the Gaussian).

The support of the distribution is ([0, 1]^3).
"""

def sample(self, key, latent_color, color_scale, *args, **kwargs):
return jax.vmap(
genjax.truncated_normal.sample, in_axes=(0, 0, None, None, None)
)(
split(key, latent_color.shape[0]),
latent_color,
color_scale,
COLOR_MIN_VAL,
COLOR_MAX_VAL,
)

def logpdf_per_channel(
self, observed_color, latent_color, color_scale, *args, **kwargs
):
return jax.vmap(
genjax.truncated_normal.logpdf, in_axes=(0, 0, None, None, None)
)(observed_color, latent_color, color_scale, COLOR_MIN_VAL, COLOR_MAX_VAL)


@Pytree.dataclass
class RenormalizedLaplacePixelColorDistribution(PixelColorDistribution):
"""
Sample a color from a renormalized Laplace distribution centered around the given
latent_color (rgb value), given the color_scale (scale of the laplace).

The support of the distribution is ([0, 1]^3).
"""

def sample(self, key, latent_color, color_scale, *args, **kwargs):
return jax.vmap(renormalized_laplace.sample, in_axes=(0, 0, None, None, None))(
split(key, latent_color.shape[0]),
latent_color,
color_scale,
COLOR_MIN_VAL,
COLOR_MAX_VAL,
)

def logpdf_per_channel(
self, observed_color, latent_color, color_scale, *args, **kwargs
):
return jax.vmap(renormalized_laplace.logpdf, in_axes=(0, 0, None, None, None))(
observed_color, latent_color, color_scale, COLOR_MIN_VAL, COLOR_MAX_VAL
)


@Pytree.dataclass
class TruncatedLaplacePixelColorDistribution(PixelColorDistribution):
"""A distribution that generates the color of a pixel from a truncated
Expand Down
61 changes: 61 additions & 0 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from b3d.modeling_utils import (
_FIXED_DEPTH_UNIFORM_WINDOW,
PythonMixtureDistribution,
renormalized_laplace,
truncated_laplace,
)

Expand Down Expand Up @@ -47,6 +48,66 @@ def logpdf(
raise NotImplementedError


@Pytree.dataclass
class RenormalizedGaussianPixelDepthDistribution(PixelDepthDistribution):
"""A distribution that generates the depth of a pixel from a Gaussian
distribution centered around the latent depth, with the spread controlled
by depth_scale. The support of the distribution is [near, far].
"""

near: float = Pytree.static()
far: float = Pytree.static()

def sample(
self, key: PRNGKey, latent_depth: float, depth_scale: float, *args, **kwargs
) -> float:
return genjax.truncated_normal.sample(
key, latent_depth, depth_scale, self.near, self.far
)

def logpdf(
self,
observed_depth: float,
latent_depth: float,
depth_scale: float,
*args,
**kwargs,
) -> float:
return genjax.truncated_normal.logpdf(
observed_depth, latent_depth, depth_scale, self.near, self.far
)


@Pytree.dataclass
class RenormalizedLaplacePixelDepthDistribution(PixelDepthDistribution):
"""A distribution that generates the depth of a pixel from a Laplace
distribution centered around the latent depth, with the spread controlled
by depth_scale. The support of the distribution is [near, far].
"""

near: float = Pytree.static()
far: float = Pytree.static()

def sample(
self, key: PRNGKey, latent_depth: float, depth_scale: float, *args, **kwargs
) -> float:
return renormalized_laplace.sample(
key, latent_depth, depth_scale, self.near, self.far
)

def logpdf(
self,
observed_depth: float,
latent_depth: float,
depth_scale: float,
*args,
**kwargs,
) -> float:
return renormalized_laplace.logpdf(
observed_depth, latent_depth, depth_scale, self.near, self.far
)


@Pytree.dataclass
class TruncatedLaplacePixelDepthDistribution(PixelDepthDistribution):
"""A distribution that generates the depth of a pixel from a truncated
Expand Down
35 changes: 35 additions & 0 deletions src/b3d/modeling_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import genjax
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -74,6 +76,39 @@ def logpdf(v, *args, **kwargs):
bernoulli = tfp_distribution(lambda logits: tfp.distributions.Bernoulli(logits=logits))
normal = tfp_distribution(tfp.distributions.Normal)


###


@Pytree.dataclass
class RenormalizedLaplace(genjax.ExactDensity):
def sample(self, key, loc, scale, low, high):
warnings.warn(
"RenormalizedLaplace sampling is currently not implemented correctly."
)
x = tfp.distributions.Laplace(loc, scale).sample(seed=key)
return jnp.clip(x, low, high)

def logpdf(self, obs, loc, scale, low, high):
laplace_logpdf = tfp.distributions.Laplace(loc, scale).log_prob(obs)
p_below_low = tfp.distributions.Laplace(loc, scale).cdf(low)
p_below_high = tfp.distributions.Laplace(loc, scale).cdf(high)
log_integral_of_laplace_pdf_within_this_range = jnp.log(
p_below_high - p_below_low
)
logpdf_if_in_range = (
laplace_logpdf - log_integral_of_laplace_pdf_within_this_range
)

return jnp.where(
jnp.logical_and(obs >= low, obs <= high),
logpdf_if_in_range,
-jnp.inf,
)


renormalized_laplace = RenormalizedLaplace()

### Mixture distribution combinator ###


Expand Down
8 changes: 6 additions & 2 deletions tests/gen3d/test_pixel_color_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
COLOR_MIN_VAL,
FullPixelColorDistribution,
MixturePixelColorDistribution,
RenormalizedGaussianPixelColorDistribution,
RenormalizedLaplacePixelColorDistribution,
TruncatedLaplacePixelColorDistribution,
UniformPixelColorDistribution,
)
Expand Down Expand Up @@ -36,6 +38,8 @@ def generate_color_grid(n_grid_steps: int):
sample_kernels_to_test = [
(UniformPixelColorDistribution(), ()),
(TruncatedLaplacePixelColorDistribution(), (0.1,)),
(RenormalizedLaplacePixelColorDistribution(), (0.1,)),
(RenormalizedGaussianPixelColorDistribution(), (0.1,)),
(
MixturePixelColorDistribution(),
(
Expand Down Expand Up @@ -80,8 +84,8 @@ def test_sample_in_valid_color_range(kernel_spec, latent_color):
keys
)
assert colors.shape == (num_samples, 3)
assert jnp.all(colors > 0)
assert jnp.all(colors < 1)
assert jnp.all(colors >= 0)
georgematheos marked this conversation as resolved.
Show resolved Hide resolved
assert jnp.all(colors <= 1)


def test_relative_logpdf():
Expand Down
8 changes: 6 additions & 2 deletions tests/gen3d/test_pixel_depth_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
UNEXPLAINED_DEPTH_NONRETURN_PROB,
FullPixelDepthDistribution,
MixturePixelDepthDistribution,
RenormalizedGaussianPixelDepthDistribution,
RenormalizedLaplacePixelDepthDistribution,
TruncatedLaplacePixelDepthDistribution,
UnexplainedPixelDepthDistribution,
UniformPixelDepthDistribution,
Expand All @@ -19,6 +21,8 @@
(UniformPixelDepthDistribution(near, far), ()),
(TruncatedLaplacePixelDepthDistribution(near, far), (0.25,)),
(UnexplainedPixelDepthDistribution(near, far), ()),
(RenormalizedLaplacePixelDepthDistribution(near, far), (0.25,)),
(RenormalizedGaussianPixelDepthDistribution(near, far), (0.25,)),
(
MixturePixelDepthDistribution(near, far),
(
Expand Down Expand Up @@ -72,8 +76,8 @@ def test_sample_in_valid_depth_range(kernel_spec, latent_depth):
keys
)
assert depths.shape == (num_samples,)
assert jnp.all((depths > near) | (depths == DEPTH_NONRETURN_VAL))
assert jnp.all((depths < far) | (depths == DEPTH_NONRETURN_VAL))
assert jnp.all((depths >= near) | (depths == DEPTH_NONRETURN_VAL))
assert jnp.all((depths <= far) | (depths == DEPTH_NONRETURN_VAL))


def test_relative_logpdf():
Expand Down
Loading