Skip to content

Commit

Permalink
Have tests cover multivariate case
Browse files Browse the repository at this point in the history
  • Loading branch information
emdupre committed Jan 6, 2024
1 parent 83eaaf4 commit 822695b
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions bayes_ca/tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.lax import conv
Expand All @@ -18,9 +19,11 @@
# Test
key = jr.PRNGKey(0)
num_timesteps = 1000
num_features = 3
K = 100 # max run length
lmbda = 0.1 # prob changepoint
mu0 = 0.0 # prior mean
mu0 = jnp.repeat(mu0, num_features)
sigmasq0 = 3**2 # prior variance
sigmasq = 0.5**2 # observation variance

Expand All @@ -37,8 +40,15 @@
this_key, key = jr.split(key)
xs = mus + jnp.sqrt(sigmasq) * jr.normal(this_key, mus.shape)

partial_sums, partial_counts = _compute_gaussian_stats(xs, K + 1)
lls = _compute_gaussian_lls(xs, K + 1, mu0, sigmasq0, sigmasq)
# partial_sums, partial_counts = _compute_gaussian_stats(xs, K + 1)
partial_sums, partial_counts = jax.vmap(_compute_gaussian_stats, in_axes=(-1, None), out_axes=-1)(
xs, K + 1
)
# lls = _compute_gaussian_lls(xs, K + 1, mu0, sigmasq0, sigmasq)
lls = jax.vmap(_compute_gaussian_lls, in_axes=(-1, None, 0, None, None))(
xs, K + 1, mu0, sigmasq0, sigmasq
)
lls = lls.sum(axis=0)
_, _, transition_probs = cp_smoother(hazard_rates, lls)


Expand Down

0 comments on commit 822695b

Please sign in to comment.