From e40de895cf6329f87fb7faa1c5e6356ed010ab94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Wed, 18 Dec 2024 21:53:46 +0000 Subject: [PATCH] revert scaling factor, add ss event API --- python/sdist/amici/jax/model.py | 36 +++++++++++++------ python/sdist/amici/jax/petab.py | 24 ++++++++++++- .../benchmark-models/test_petab_benchmark.py | 21 +++-------- tests/petab_test_suite/test_petab_suite.py | 15 ++++++-- 4 files changed, 65 insertions(+), 31 deletions(-) diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 4692524070..8b2c09fcc6 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -12,6 +12,8 @@ import jax import jaxtyping as jt +from collections.abc import Callable + class ReturnValue(enum.Enum): llh = "log-likelihood" @@ -39,15 +41,11 @@ class JAXModel(eqx.Module): API version of the base class. :ivar jax_py_file: Path to the JAX model file. - :ivar ss_tol_scale_factor: - Tolerance scale factor for the steady state termination check. Multiplied with tolerances of the user-provided - step size controller. """ MODEL_API_VERSION = "0.0.2" api_version: str jax_py_file: Path - ss_tol_scale_factor: jnp.float_ = 10.0 def __init__(self): if self.api_version != self.MODEL_API_VERSION: @@ -259,6 +257,9 @@ def _eq( x0: jt.Float[jt.Array, "nxs"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "1 nxs"], dict]: """ @@ -290,10 +291,7 @@ def _eq( max_steps=max_steps, adjoint=diffrax.DirectAdjoint(), event=diffrax.Event( - cond_fn=diffrax.steady_state_event( - rtol=controller.rtol * self.ss_tol_scale_factor, - atol=controller.atol * self.ss_tol_scale_factor, - ) + cond_fn=steady_state_event, ), throw=False, ) @@ -474,6 +472,9 @@ def simulate_condition( solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: int | jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), @@ -549,7 +550,13 @@ def simulate_condition( # Post-equilibration if ts_posteq.shape[0]: x_solver, stats_posteq = self._eq( - p, tcl, x_solver, solver, controller, max_steps + p, + tcl, + x_solver, + solver, + controller, + steady_state_event, + max_steps, ) else: stats_posteq = None @@ -620,6 +627,9 @@ def preequilibrate_condition( mask_reinit: jt.Bool[jt.Array, "*nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: int | jnp.int_, ) -> tuple[jt.Float[jt.Array, "nx"], dict]: r""" @@ -647,7 +657,13 @@ def preequilibrate_condition( tcl = self._tcl(x0, p) current_x = self._x_solver(x0) current_x, stats_preeq = self._eq( - p, tcl, current_x, solver, controller, max_steps + p, + tcl, + current_x, + solver, + controller, + steady_state_event, + max_steps, ) return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 6a7da4b42f..8db5055fdf 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -3,6 +3,7 @@ from numbers import Number from collections.abc import Iterable from pathlib import Path +from collections.abc import Callable import diffrax @@ -465,6 +466,9 @@ def run_simulation( simulation_condition: tuple[str, ...], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 ret: ReturnValue = ReturnValue.llh, @@ -507,6 +511,7 @@ def run_simulation( solver=solver, controller=controller, max_steps=max_steps, + steady_state_event=steady_state_event, adjoint=diffrax.RecursiveCheckpointAdjoint() if ret in (ReturnValue.llh, ReturnValue.chi2) else diffrax.DirectAdjoint(), @@ -518,6 +523,9 @@ def run_preequilibration( simulation_condition: str, solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821 """ @@ -545,6 +553,7 @@ def run_preequilibration( solver=solver, controller=controller, max_steps=max_steps, + steady_state_event=steady_state_event, ) @@ -555,6 +564,9 @@ def run_simulations( controller: diffrax.AbstractStepSizeController = diffrax.PIDController( **DEFAULT_CONTROLLER_SETTINGS ), + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ] = diffrax.steady_state_event(), max_steps: int = 2**10, ret: ReturnValue | str = ReturnValue.llh, ): @@ -569,6 +581,9 @@ def run_simulations( ODE solver to use for simulation. :param controller: Step size controller to use for simulation. + :param steady_state_event: + Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state + condition, see :func:`diffrax.steady_state_event` for details. :param max_steps: Maximum number of steps to take during simulation. :param ret: @@ -583,7 +598,9 @@ def run_simulations( simulation_conditions = problem.get_all_simulation_conditions() preeqs = { - sc: problem.run_preequilibration(sc, solver, controller, max_steps) + sc: problem.run_preequilibration( + sc, solver, controller, steady_state_event, max_steps + ) # only run preequilibration once per condition for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1} } @@ -593,6 +610,7 @@ def run_simulations( sc, solver, controller, + steady_state_event, max_steps, preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]), ret=ret, @@ -617,6 +635,9 @@ def petab_simulate( controller: diffrax.AbstractStepSizeController = diffrax.PIDController( **DEFAULT_CONTROLLER_SETTINGS ), + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ] = diffrax.steady_state_event(), max_steps: int = 2**10, ): """ @@ -637,6 +658,7 @@ def petab_simulate( problem, solver=solver, controller=controller, + steady_state_event=steady_state_event, max_steps=max_steps, ret=ReturnValue.y, ) diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 2a8aa72659..d9f836b0b4 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -8,7 +8,6 @@ from functools import partial from pathlib import Path -import diffrax import fiddy import amici import numpy as np @@ -287,12 +286,8 @@ def test_jax_llh(benchmark_problem): amici_solver = amici_model.getSolver() cur_settings = settings[problem_id] - if problem_id in ("Zheng_PNAS2012",): - tol = 1e-12 - else: - tol = 1e-8 - amici_solver.setAbsoluteTolerance(tol) - amici_solver.setRelativeTolerance(tol) + amici_solver.setAbsoluteTolerance(1e-8) + amici_solver.setRelativeTolerance(1e-8) amici_solver.setMaxSteps(10_000) simulate_amici = partial( @@ -348,17 +343,9 @@ def test_jax_llh(benchmark_problem): [problem_parameters[pid] for pid in jax_problem.parameter_ids] ), ) - llh_jax, _ = beartype(run_simulations)(jax_problem) + if problem_id in problems_for_gradient_check: - kwargs = {} - if problem_id in ("Zheng_PNAS2012",): - kwargs["controller"] = diffrax.PIDController( - atol=1e-14, - rtol=1e-14, - pcoeff=0.4, - icoeff=0.3, - dcoeff=0.0, - ) + beartype(run_simulations)(jax_problem) (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( run_simulations, has_aux=True )(jax_problem) diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 5fe61adcf2..4fcbe0b631 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -4,6 +4,8 @@ import logging import sys +import diffrax + import amici import pandas as pd import petab.v1 as petab @@ -68,10 +70,17 @@ def _test_case(case, model_type, version, jax): if jax: from amici.jax import JAXProblem, run_simulations, petab_simulate + steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) jax_problem = JAXProblem(model, problem) - llh, ret = run_simulations(jax_problem) - chi2, _ = run_simulations(jax_problem, ret="chi2") - simulation_df = petab_simulate(jax_problem) + llh, ret = run_simulations( + jax_problem, steady_state_event=steady_state_event + ) + chi2, _ = run_simulations( + jax_problem, ret="chi2", steady_state_event=steady_state_event + ) + simulation_df = petab_simulate( + jax_problem, steady_state_event=steady_state_event + ) simulation_df.rename( columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True )