Skip to content

Commit

Permalink
revert scaling factor, add ss event API
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 18, 2024
1 parent 9016bc0 commit e40de89
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 31 deletions.
36 changes: 26 additions & 10 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import jax
import jaxtyping as jt

from collections.abc import Callable


class ReturnValue(enum.Enum):
llh = "log-likelihood"
Expand Down Expand Up @@ -39,15 +41,11 @@ class JAXModel(eqx.Module):
API version of the base class.
:ivar jax_py_file:
Path to the JAX model file.
:ivar ss_tol_scale_factor:
Tolerance scale factor for the steady state termination check. Multiplied with tolerances of the user-provided
step size controller.
"""

MODEL_API_VERSION = "0.0.2"
api_version: str
jax_py_file: Path
ss_tol_scale_factor: jnp.float_ = 10.0

def __init__(self):
if self.api_version != self.MODEL_API_VERSION:
Expand Down Expand Up @@ -259,6 +257,9 @@ def _eq(
x0: jt.Float[jt.Array, "nxs"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "1 nxs"], dict]:
"""
Expand Down Expand Up @@ -290,10 +291,7 @@ def _eq(
max_steps=max_steps,
adjoint=diffrax.DirectAdjoint(),
event=diffrax.Event(
cond_fn=diffrax.steady_state_event(
rtol=controller.rtol * self.ss_tol_scale_factor,
atol=controller.atol * self.ss_tol_scale_factor,
)
cond_fn=steady_state_event,
),
throw=False,
)
Expand Down Expand Up @@ -474,6 +472,9 @@ def simulate_condition(
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
Expand Down Expand Up @@ -549,7 +550,13 @@ def simulate_condition(
# Post-equilibration
if ts_posteq.shape[0]:
x_solver, stats_posteq = self._eq(
p, tcl, x_solver, solver, controller, max_steps
p,
tcl,
x_solver,
solver,
controller,
steady_state_event,
max_steps,
)
else:
stats_posteq = None
Expand Down Expand Up @@ -620,6 +627,9 @@ def preequilibrate_condition(
mask_reinit: jt.Bool[jt.Array, "*nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Expand Down Expand Up @@ -647,7 +657,13 @@ def preequilibrate_condition(
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p, tcl, current_x, solver, controller, max_steps
p,
tcl,
current_x,
solver,
controller,
steady_state_event,
max_steps,
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)
Expand Down
24 changes: 23 additions & 1 deletion python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numbers import Number
from collections.abc import Iterable
from pathlib import Path
from collections.abc import Callable


import diffrax
Expand Down Expand Up @@ -465,6 +466,9 @@ def run_simulation(
simulation_condition: tuple[str, ...],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
ret: ReturnValue = ReturnValue.llh,
Expand Down Expand Up @@ -507,6 +511,7 @@ def run_simulation(
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
adjoint=diffrax.RecursiveCheckpointAdjoint()
if ret in (ReturnValue.llh, ReturnValue.chi2)
else diffrax.DirectAdjoint(),
Expand All @@ -518,6 +523,9 @@ def run_preequilibration(
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Expand Down Expand Up @@ -545,6 +553,7 @@ def run_preequilibration(
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
)


Expand All @@ -555,6 +564,9 @@ def run_simulations(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
ret: ReturnValue | str = ReturnValue.llh,
):
Expand All @@ -569,6 +581,9 @@ def run_simulations(
ODE solver to use for simulation.
:param controller:
Step size controller to use for simulation.
:param steady_state_event:
Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state
condition, see :func:`diffrax.steady_state_event` for details.
:param max_steps:
Maximum number of steps to take during simulation.
:param ret:
Expand All @@ -583,7 +598,9 @@ def run_simulations(
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
sc: problem.run_preequilibration(
sc, solver, controller, steady_state_event, max_steps
)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}
Expand All @@ -593,6 +610,7 @@ def run_simulations(
sc,
solver,
controller,
steady_state_event,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
ret=ret,
Expand All @@ -617,6 +635,9 @@ def petab_simulate(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
):
"""
Expand All @@ -637,6 +658,7 @@ def petab_simulate(
problem,
solver=solver,
controller=controller,
steady_state_event=steady_state_event,
max_steps=max_steps,
ret=ReturnValue.y,
)
Expand Down
21 changes: 4 additions & 17 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from functools import partial
from pathlib import Path

import diffrax
import fiddy
import amici
import numpy as np
Expand Down Expand Up @@ -287,12 +286,8 @@ def test_jax_llh(benchmark_problem):

amici_solver = amici_model.getSolver()
cur_settings = settings[problem_id]
if problem_id in ("Zheng_PNAS2012",):
tol = 1e-12
else:
tol = 1e-8
amici_solver.setAbsoluteTolerance(tol)
amici_solver.setRelativeTolerance(tol)
amici_solver.setAbsoluteTolerance(1e-8)
amici_solver.setRelativeTolerance(1e-8)
amici_solver.setMaxSteps(10_000)

simulate_amici = partial(
Expand Down Expand Up @@ -348,17 +343,9 @@ def test_jax_llh(benchmark_problem):
[problem_parameters[pid] for pid in jax_problem.parameter_ids]
),
)
llh_jax, _ = beartype(run_simulations)(jax_problem)

if problem_id in problems_for_gradient_check:
kwargs = {}
if problem_id in ("Zheng_PNAS2012",):
kwargs["controller"] = diffrax.PIDController(
atol=1e-14,
rtol=1e-14,
pcoeff=0.4,
icoeff=0.3,
dcoeff=0.0,
)
beartype(run_simulations)(jax_problem)
(llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
run_simulations, has_aux=True
)(jax_problem)
Expand Down
15 changes: 12 additions & 3 deletions tests/petab_test_suite/test_petab_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import sys

import diffrax

import amici
import pandas as pd
import petab.v1 as petab
Expand Down Expand Up @@ -68,10 +70,17 @@ def _test_case(case, model_type, version, jax):
if jax:
from amici.jax import JAXProblem, run_simulations, petab_simulate

steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
jax_problem = JAXProblem(model, problem)
llh, ret = run_simulations(jax_problem)
chi2, _ = run_simulations(jax_problem, ret="chi2")
simulation_df = petab_simulate(jax_problem)
llh, ret = run_simulations(
jax_problem, steady_state_event=steady_state_event
)
chi2, _ = run_simulations(
jax_problem, ret="chi2", steady_state_event=steady_state_event
)
simulation_df = petab_simulate(
jax_problem, steady_state_event=steady_state_event
)
simulation_df.rename(
columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True
)
Expand Down

0 comments on commit e40de89

Please sign in to comment.