Skip to content

Commit

Permalink
Porting changes from CDCent (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon authored Apr 1, 2024
1 parent fe31ef8 commit f6cd324
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion model/src/pyrenew/process/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _ar_scanner(carry, next):
)

last, ts = lax.scan(_ar_scanner, inits - self.mean, noise)
return (self.mean + ts.flatten(),)
return (jnp.hstack([inits, self.mean + ts.flatten()]),)

@staticmethod
def validate():
Expand Down
23 changes: 20 additions & 3 deletions model/src/test/test_ar_process.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
import numpyro
from numpy.testing import assert_almost_equal
from pyrenew.process import ARProcess


Expand All @@ -12,11 +11,29 @@ def test_ar_can_be_sampled():
"""
ar1 = ARProcess(5, jnp.array([0.95]), jnp.array([0.5]))
with numpyro.handlers.seed(rng_seed=62):
## can sample with and without inits
# can sample with and without inits
ar1.sample(3532, inits=jnp.array([50.0]))
ar1.sample(5023)

ar3 = ARProcess(5, jnp.array([0.05, 0.025, 0.025]), jnp.array([0.5]))
with numpyro.handlers.seed(rng_seed=62):
# can sample with and without inits
ar3.sample(1230)
ar3.sample(52, inits=jnp.array([50.0, 49.9, 48.2]))


def test_ar_samples_correctly_distributed():
"""
Check that AR processes have correctly-
distributed steps.
"""
ar_mean = 5
noise_sd = jnp.array([0.5])
ar_inits = jnp.array([25.0])
ar1 = ARProcess(ar_mean, jnp.array([0.75]), noise_sd)
with numpyro.handlers.seed(rng_seed=62):
# check it regresses to mean
# when started away from it
long_ts, *_ = ar1.sample(10000, inits=ar_inits)
assert_almost_equal(long_ts[0], ar_inits)
assert jnp.abs(long_ts[-1] - ar_mean) < 4 * noise_sd
12 changes: 6 additions & 6 deletions model/src/test/test_random_walk.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpy.testing import assert_almost_equal
from pyrenew.process import SimpleRandomWalkProcess


Expand Down Expand Up @@ -30,15 +29,13 @@ def test_rw_samples_correctly_distributed():
[0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02]
):
rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd))

init_arr = jnp.array([532.0])
with numpyro.handlers.seed(rng_seed=62):
samples, *_ = rw_normal.sample(n_samples, init=jnp.array([50.0]))
samples, *_ = rw_normal.sample(n_samples, init=init_arr)

# diffs should not be greater than
# 4 sigma
diffs = jnp.diff(samples)
print(samples)
print(diffs)
assert jnp.all(jnp.abs(diffs - step_mean) < 4 * step_sd)

# sample mean of diffs should be
Expand All @@ -52,3 +49,6 @@ def test_rw_samples_correctly_distributed():
# should be approximately equal
# to the step sd
assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1)

# first value should be the init value
assert_almost_equal(samples[0], init_arr)

0 comments on commit f6cd324

Please sign in to comment.