Skip to content

Commit

Permalink
Fix an edge case (1-D Gaussians) (#5)
Browse files Browse the repository at this point in the history
The optimal covariance function had a legacy use of numpy and not jax.numpy which was not caught by the test due to the lack of jitting of this function...
  • Loading branch information
AdrienCorenflos authored May 24, 2022
1 parent 48a3dfc commit 63d3117
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions coupled_rejection_sampling/mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax.random
import jax.scipy.linalg as jlinalg
import jax.scipy.stats as jstats
import numpy as np
from jax import numpy as jnp
from jax.scipy.linalg import cho_solve

Expand Down Expand Up @@ -33,7 +32,7 @@ def get_optimal_covariance(chol_P, chol_Sig):
"""
d = chol_P.shape[0]
if d == 1:
return np.maximum(chol_P, chol_Sig)
return jnp.maximum(chol_P, chol_Sig)

right_Y = jlinalg.solve_triangular(chol_P, chol_Sig, lower=True) # Y = RY.T RY
w_Y, v_Y = jlinalg.eigh(right_Y.T @ right_Y)
Expand Down

0 comments on commit 63d3117

Please sign in to comment.