diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index 7d3d2969..412f05df 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -78,7 +78,7 @@ def _ar_scanner(carry, next): ) last, ts = lax.scan(_ar_scanner, inits - self.mean, noise) - return (self.mean + ts.flatten(),) + return (jnp.hstack([inits, self.mean + ts.flatten()]),) @staticmethod def validate(): diff --git a/model/src/test/test_ar_process.py b/model/src/test/test_ar_process.py index af658bef..b1a6d9d2 100755 --- a/model/src/test/test_ar_process.py +++ b/model/src/test/test_ar_process.py @@ -1,7 +1,6 @@ -# -*- coding: utf-8 -*- - import jax.numpy as jnp import numpyro +from numpy.testing import assert_almost_equal from pyrenew.process import ARProcess @@ -12,11 +11,29 @@ def test_ar_can_be_sampled(): """ ar1 = ARProcess(5, jnp.array([0.95]), jnp.array([0.5])) with numpyro.handlers.seed(rng_seed=62): - ## can sample with and without inits + # can sample with and without inits ar1.sample(3532, inits=jnp.array([50.0])) ar1.sample(5023) ar3 = ARProcess(5, jnp.array([0.05, 0.025, 0.025]), jnp.array([0.5])) with numpyro.handlers.seed(rng_seed=62): + # can sample with and without inits ar3.sample(1230) ar3.sample(52, inits=jnp.array([50.0, 49.9, 48.2])) + + +def test_ar_samples_correctly_distributed(): + """ + Check that AR processes have correctly- + distributed steps. + """ + ar_mean = 5 + noise_sd = jnp.array([0.5]) + ar_inits = jnp.array([25.0]) + ar1 = ARProcess(ar_mean, jnp.array([0.75]), noise_sd) + with numpyro.handlers.seed(rng_seed=62): + # check it regresses to mean + # when started away from it + long_ts, *_ = ar1.sample(10000, inits=ar_inits) + assert_almost_equal(long_ts[0], ar_inits) + assert jnp.abs(long_ts[-1] - ar_mean) < 4 * noise_sd diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 33e4fa96..43dc9043 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- - import jax.numpy as jnp import numpyro import numpyro.distributions as dist +from numpy.testing import assert_almost_equal from pyrenew.process import SimpleRandomWalkProcess @@ -30,15 +29,13 @@ def test_rw_samples_correctly_distributed(): [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02] ): rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd)) - + init_arr = jnp.array([532.0]) with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal.sample(n_samples, init=jnp.array([50.0])) + samples, *_ = rw_normal.sample(n_samples, init=init_arr) # diffs should not be greater than # 4 sigma diffs = jnp.diff(samples) - print(samples) - print(diffs) assert jnp.all(jnp.abs(diffs - step_mean) < 4 * step_sd) # sample mean of diffs should be @@ -52,3 +49,6 @@ def test_rw_samples_correctly_distributed(): # should be approximately equal # to the step sd assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1) + + # first value should be the init value + assert_almost_equal(samples[0], init_arr)