Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add renormalized laplace and gaussian distribution and kernels #162

Merged
merged 2 commits into from
Sep 11, 2024

Conversation

georgematheos
Copy link
Collaborator

No description provided.

Copy link
Contributor

@horizon-blue horizon-blue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow thanks for putting this together in this short amount of time! This looks great to me overall -- I just left some inline comments, most of which are just cosmetics stuff. Feel free to merge this whenever you're ready.

Comment on lines 94 to 98
p_below_low = tfp.distributions.Laplace(loc, scale).cdf(low)
p_below_high = tfp.distributions.Laplace(loc, scale).cdf(high)
log_integral_of_laplace_pdf_over_this_range = jnp.log(
p_below_high - p_below_low
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For numerical stability, perhaps we can consider using Laplace.log_cdf and Laplace. log_survival_function?

(btw also just for clarity: I think log_integral_of_laplace_pdf_over_this_range is actually referring to the integral of laplace within the range?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@horizon-blue can you please suggest the right code snippet we need to do log_cdf and log_survival_function here?

tests/gen3d/test_pixel_color_kernels.py Show resolved Hide resolved
def sample(self, key, latent_color, color_scale, *args, **kwargs):
return jax.vmap(
genjax.truncated_normal.sample, in_axes=(0, 0, None, None, None)
)(split(key, latent_color.shape[0]), latent_color, color_scale, 0.0, 1.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How to you feel about usingCOLOR_MIN_VAL and COLOR_MAX_VAL instead of hard-coding the magic constant?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do!

@georgematheos
Copy link
Collaborator Author

@horizon-blue I made the suggested changes that I knew how to make!

@georgematheos georgematheos merged commit ac49e93 into gen3d Sep 11, 2024
6 of 7 checks passed
@georgematheos georgematheos deleted the gm/gen3d/renormalized_dists branch September 11, 2024 21:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants