From e7cfd12989eaf8ee98a97a86ac65331d91fac4bd Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 11 Sep 2024 19:50:45 +0000 Subject: [PATCH 1/2] Add renormalized laplace and gaussian distribution and kernels --- .../pixel_kernels/pixel_color_kernels.py | 46 ++++++++++++++ .../pixel_kernels/pixel_depth_kernels.py | 61 +++++++++++++++++++ src/b3d/modeling_utils.py | 35 +++++++++++ tests/gen3d/test_pixel_color_kernels.py | 8 ++- tests/gen3d/test_pixel_depth_kernels.py | 8 ++- 5 files changed, 154 insertions(+), 4 deletions(-) diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py index 867f84d5..69c66b0b 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py @@ -6,11 +6,13 @@ import jax.numpy as jnp from genjax import Pytree from genjax.typing import FloatArray, PRNGKey +from jax.random import split from tensorflow_probability.substrates import jax as tfp from b3d.modeling_utils import ( _FIXED_COLOR_UNIFORM_WINDOW, PythonMixtureDistribution, + renormalized_laplace, truncated_laplace, ) @@ -73,6 +75,50 @@ def logpdf_per_channel( raise NotImplementedError +@Pytree.dataclass +class RenormalizedGaussianPixelColorDistribution(PixelColorDistribution): + """ + Sample a color from a renormalized Gaussian distribution centered around the given + latent_color (rgb value), given the color_scale (stddev of the Gaussian). + + The support of the distribution is ([0, 1]^3). + """ + + def sample(self, key, latent_color, color_scale, *args, **kwargs): + return jax.vmap( + genjax.truncated_normal.sample, in_axes=(0, 0, None, None, None) + )(split(key, latent_color.shape[0]), latent_color, color_scale, 0.0, 1.0) + + def logpdf_per_channel( + self, observed_color, latent_color, color_scale, *args, **kwargs + ): + return jax.vmap( + genjax.truncated_normal.logpdf, in_axes=(0, 0, None, None, None) + )(observed_color, latent_color, color_scale, 0.0, 1.0) + + +@Pytree.dataclass +class RenormalizedLaplacePixelColorDistribution(PixelColorDistribution): + """ + Sample a color from a renormalized Laplace distribution centered around the given + latent_color (rgb value), given the color_scale (scale of the laplace). + + The support of the distribution is ([0, 1]^3). + """ + + def sample(self, key, latent_color, color_scale, *args, **kwargs): + return jax.vmap(renormalized_laplace.sample, in_axes=(0, 0, None, None, None))( + split(key, latent_color.shape[0]), latent_color, color_scale, 0.0, 1.0 + ) + + def logpdf_per_channel( + self, observed_color, latent_color, color_scale, *args, **kwargs + ): + return jax.vmap(renormalized_laplace.logpdf, in_axes=(0, 0, None, None, None))( + observed_color, latent_color, color_scale, 0.0, 1.0 + ) + + @Pytree.dataclass class TruncatedLaplacePixelColorDistribution(PixelColorDistribution): """A distribution that generates the color of a pixel from a truncated diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py index 186e9931..9d4e4adf 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py @@ -12,6 +12,7 @@ from b3d.modeling_utils import ( _FIXED_DEPTH_UNIFORM_WINDOW, PythonMixtureDistribution, + renormalized_laplace, truncated_laplace, ) @@ -47,6 +48,66 @@ def logpdf( raise NotImplementedError +@Pytree.dataclass +class RenormalizedGaussianPixelDepthDistribution(PixelDepthDistribution): + """A distribution that generates the depth of a pixel from a Gaussian + distribution centered around the latent depth, with the spread controlled + by depth_scale. The support of the distribution is [near, far]. + """ + + near: float = Pytree.static() + far: float = Pytree.static() + + def sample( + self, key: PRNGKey, latent_depth: float, depth_scale: float, *args, **kwargs + ) -> float: + return genjax.truncated_normal.sample( + key, latent_depth, depth_scale, self.near, self.far + ) + + def logpdf( + self, + observed_depth: float, + latent_depth: float, + depth_scale: float, + *args, + **kwargs, + ) -> float: + return genjax.truncated_normal.logpdf( + observed_depth, latent_depth, depth_scale, self.near, self.far + ) + + +@Pytree.dataclass +class RenormalizedLaplacePixelDepthDistribution(PixelDepthDistribution): + """A distribution that generates the depth of a pixel from a Laplace + distribution centered around the latent depth, with the spread controlled + by depth_scale. The support of the distribution is [near, far]. + """ + + near: float = Pytree.static() + far: float = Pytree.static() + + def sample( + self, key: PRNGKey, latent_depth: float, depth_scale: float, *args, **kwargs + ) -> float: + return renormalized_laplace.sample( + key, latent_depth, depth_scale, self.near, self.far + ) + + def logpdf( + self, + observed_depth: float, + latent_depth: float, + depth_scale: float, + *args, + **kwargs, + ) -> float: + return renormalized_laplace.logpdf( + observed_depth, latent_depth, depth_scale, self.near, self.far + ) + + @Pytree.dataclass class TruncatedLaplacePixelDepthDistribution(PixelDepthDistribution): """A distribution that generates the depth of a pixel from a truncated diff --git a/src/b3d/modeling_utils.py b/src/b3d/modeling_utils.py index e10431f4..e35ec16a 100644 --- a/src/b3d/modeling_utils.py +++ b/src/b3d/modeling_utils.py @@ -1,3 +1,5 @@ +import warnings + import genjax import jax import jax.numpy as jnp @@ -74,6 +76,39 @@ def logpdf(v, *args, **kwargs): bernoulli = tfp_distribution(lambda logits: tfp.distributions.Bernoulli(logits=logits)) normal = tfp_distribution(tfp.distributions.Normal) + +### + + +@Pytree.dataclass +class RenormalizedLaplace(genjax.ExactDensity): + def sample(self, key, loc, scale, low, high): + warnings.warn( + "RenormalizedLaplace sampling is currently not implemented correctly." + ) + x = tfp.distributions.Laplace(loc, scale).sample(seed=key) + return jnp.clip(x, low, high) + + def logpdf(self, obs, loc, scale, low, high): + laplace_logpdf = tfp.distributions.Laplace(loc, scale).log_prob(obs) + p_below_low = tfp.distributions.Laplace(loc, scale).cdf(low) + p_below_high = tfp.distributions.Laplace(loc, scale).cdf(high) + log_integral_of_laplace_pdf_over_this_range = jnp.log( + p_below_high - p_below_low + ) + logpdf_if_in_range = ( + laplace_logpdf - log_integral_of_laplace_pdf_over_this_range + ) + + return jnp.where( + jnp.logical_and(obs >= low, obs <= high), + logpdf_if_in_range, + -jnp.inf, + ) + + +renormalized_laplace = RenormalizedLaplace() + ### Mixture distribution combinator ### diff --git a/tests/gen3d/test_pixel_color_kernels.py b/tests/gen3d/test_pixel_color_kernels.py index e861b4b8..10505756 100644 --- a/tests/gen3d/test_pixel_color_kernels.py +++ b/tests/gen3d/test_pixel_color_kernels.py @@ -8,6 +8,8 @@ COLOR_MIN_VAL, FullPixelColorDistribution, MixturePixelColorDistribution, + RenormalizedGaussianPixelColorDistribution, + RenormalizedLaplacePixelColorDistribution, TruncatedLaplacePixelColorDistribution, UniformPixelColorDistribution, ) @@ -36,6 +38,8 @@ def generate_color_grid(n_grid_steps: int): sample_kernels_to_test = [ (UniformPixelColorDistribution(), ()), (TruncatedLaplacePixelColorDistribution(), (0.1,)), + (RenormalizedLaplacePixelColorDistribution(), (0.1,)), + (RenormalizedGaussianPixelColorDistribution(), (0.1,)), ( MixturePixelColorDistribution(), ( @@ -80,8 +84,8 @@ def test_sample_in_valid_color_range(kernel_spec, latent_color): keys ) assert colors.shape == (num_samples, 3) - assert jnp.all(colors > 0) - assert jnp.all(colors < 1) + assert jnp.all(colors >= 0) + assert jnp.all(colors <= 1) def test_relative_logpdf(): diff --git a/tests/gen3d/test_pixel_depth_kernels.py b/tests/gen3d/test_pixel_depth_kernels.py index 03aff1c9..4c4aceac 100644 --- a/tests/gen3d/test_pixel_depth_kernels.py +++ b/tests/gen3d/test_pixel_depth_kernels.py @@ -6,6 +6,8 @@ UNEXPLAINED_DEPTH_NONRETURN_PROB, FullPixelDepthDistribution, MixturePixelDepthDistribution, + RenormalizedGaussianPixelDepthDistribution, + RenormalizedLaplacePixelDepthDistribution, TruncatedLaplacePixelDepthDistribution, UnexplainedPixelDepthDistribution, UniformPixelDepthDistribution, @@ -19,6 +21,8 @@ (UniformPixelDepthDistribution(near, far), ()), (TruncatedLaplacePixelDepthDistribution(near, far), (0.25,)), (UnexplainedPixelDepthDistribution(near, far), ()), + (RenormalizedLaplacePixelDepthDistribution(near, far), (0.25,)), + (RenormalizedGaussianPixelDepthDistribution(near, far), (0.25,)), ( MixturePixelDepthDistribution(near, far), ( @@ -72,8 +76,8 @@ def test_sample_in_valid_depth_range(kernel_spec, latent_depth): keys ) assert depths.shape == (num_samples,) - assert jnp.all((depths > near) | (depths == DEPTH_NONRETURN_VAL)) - assert jnp.all((depths < far) | (depths == DEPTH_NONRETURN_VAL)) + assert jnp.all((depths >= near) | (depths == DEPTH_NONRETURN_VAL)) + assert jnp.all((depths <= far) | (depths == DEPTH_NONRETURN_VAL)) def test_relative_logpdf(): From c0a49296c53216c19ccaa05f5bcf2aa73b04c67a Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 11 Sep 2024 20:29:58 +0000 Subject: [PATCH 2/2] minor improvements --- .../gen3d/pixel_kernels/pixel_color_kernels.py | 18 ++++++++++++++---- src/b3d/modeling_utils.py | 4 ++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py index 69c66b0b..7c7f7a9b 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py @@ -87,14 +87,20 @@ class RenormalizedGaussianPixelColorDistribution(PixelColorDistribution): def sample(self, key, latent_color, color_scale, *args, **kwargs): return jax.vmap( genjax.truncated_normal.sample, in_axes=(0, 0, None, None, None) - )(split(key, latent_color.shape[0]), latent_color, color_scale, 0.0, 1.0) + )( + split(key, latent_color.shape[0]), + latent_color, + color_scale, + COLOR_MIN_VAL, + COLOR_MAX_VAL, + ) def logpdf_per_channel( self, observed_color, latent_color, color_scale, *args, **kwargs ): return jax.vmap( genjax.truncated_normal.logpdf, in_axes=(0, 0, None, None, None) - )(observed_color, latent_color, color_scale, 0.0, 1.0) + )(observed_color, latent_color, color_scale, COLOR_MIN_VAL, COLOR_MAX_VAL) @Pytree.dataclass @@ -108,14 +114,18 @@ class RenormalizedLaplacePixelColorDistribution(PixelColorDistribution): def sample(self, key, latent_color, color_scale, *args, **kwargs): return jax.vmap(renormalized_laplace.sample, in_axes=(0, 0, None, None, None))( - split(key, latent_color.shape[0]), latent_color, color_scale, 0.0, 1.0 + split(key, latent_color.shape[0]), + latent_color, + color_scale, + COLOR_MIN_VAL, + COLOR_MAX_VAL, ) def logpdf_per_channel( self, observed_color, latent_color, color_scale, *args, **kwargs ): return jax.vmap(renormalized_laplace.logpdf, in_axes=(0, 0, None, None, None))( - observed_color, latent_color, color_scale, 0.0, 1.0 + observed_color, latent_color, color_scale, COLOR_MIN_VAL, COLOR_MAX_VAL ) diff --git a/src/b3d/modeling_utils.py b/src/b3d/modeling_utils.py index e35ec16a..e2d9c0e5 100644 --- a/src/b3d/modeling_utils.py +++ b/src/b3d/modeling_utils.py @@ -93,11 +93,11 @@ def logpdf(self, obs, loc, scale, low, high): laplace_logpdf = tfp.distributions.Laplace(loc, scale).log_prob(obs) p_below_low = tfp.distributions.Laplace(loc, scale).cdf(low) p_below_high = tfp.distributions.Laplace(loc, scale).cdf(high) - log_integral_of_laplace_pdf_over_this_range = jnp.log( + log_integral_of_laplace_pdf_within_this_range = jnp.log( p_below_high - p_below_low ) logpdf_if_in_range = ( - laplace_logpdf - log_integral_of_laplace_pdf_over_this_range + laplace_logpdf - log_integral_of_laplace_pdf_within_this_range ) return jnp.where(