diff --git a/bayes_ca/tests/test_inference.py b/bayes_ca/tests/test_inference.py index 72a9c34..707151d 100644 --- a/bayes_ca/tests/test_inference.py +++ b/bayes_ca/tests/test_inference.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp import jax.random as jr from jax.lax import conv @@ -18,9 +19,11 @@ # Test key = jr.PRNGKey(0) num_timesteps = 1000 +num_features = 3 K = 100 # max run length lmbda = 0.1 # prob changepoint mu0 = 0.0 # prior mean +mu0 = jnp.repeat(mu0, num_features) sigmasq0 = 3**2 # prior variance sigmasq = 0.5**2 # observation variance @@ -37,8 +40,15 @@ this_key, key = jr.split(key) xs = mus + jnp.sqrt(sigmasq) * jr.normal(this_key, mus.shape) -partial_sums, partial_counts = _compute_gaussian_stats(xs, K + 1) -lls = _compute_gaussian_lls(xs, K + 1, mu0, sigmasq0, sigmasq) +# partial_sums, partial_counts = _compute_gaussian_stats(xs, K + 1) +partial_sums, partial_counts = jax.vmap(_compute_gaussian_stats, in_axes=(-1, None), out_axes=-1)( + xs, K + 1 +) +# lls = _compute_gaussian_lls(xs, K + 1, mu0, sigmasq0, sigmasq) +lls = jax.vmap(_compute_gaussian_lls, in_axes=(-1, None, 0, None, None))( + xs, K + 1, mu0, sigmasq0, sigmasq +) +lls = lls.sum(axis=0) _, _, transition_probs = cp_smoother(hazard_rates, lls)