From 64de63aee3a116a13f89f4979484186f9c8e8f6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 30 Jan 2025 10:37:28 +0000 Subject: [PATCH] jax vectorisation (#2636) * fix vectorisation * decorators, preequilibration * add documentation * fix petab_simulate * fix notebook * update/fix doc? * fix rtd * fix docstring * Update python/sdist/amici/jax/petab.py Co-authored-by: Daniel Weindl --------- Co-authored-by: Daniel Weindl --- documentation/rtd_requirements.txt | 1 + .../example_jax_petab/ExampleJaxPEtab.ipynb | 452 +++++++++++------- python/sdist/amici/jax/model.py | 20 +- python/sdist/amici/jax/petab.py | 435 +++++++++++++---- .../benchmark-models/test_petab_benchmark.py | 10 - 5 files changed, 646 insertions(+), 272 deletions(-) diff --git a/documentation/rtd_requirements.txt b/documentation/rtd_requirements.txt index 54a35f9f94..a8c914d5f1 100644 --- a/documentation/rtd_requirements.txt +++ b/documentation/rtd_requirements.txt @@ -5,6 +5,7 @@ setuptools>=67.7.2 pysb>=1.11.0 jax>=0.4.26 diffrax>=0.5.0 +interpax>=0.3.4 matplotlib==3.7.1 nbsphinx==0.9.1 nbformat==5.8.0 diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index e3677af346..848f6521ff 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -25,10 +25,16 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "c71c96da0da3144a", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:12.220374Z", + "start_time": "2025-01-29T15:49:12.114366Z" + } + }, + "outputs": [], "source": [ "from amici.petab.petab_import import import_petab_problem\n", "import petab.v1 as petab\n", @@ -49,24 +55,29 @@ " verbose=False, # no text output\n", " jax=True, # return jax model\n", ")" - ], - "id": "c71c96da0da3144a" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "7e0f1c27bd71ee1f", + "metadata": {}, "source": [ "## Simulation\n", "\n", "In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome as timepoints, data etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." - ], - "id": "7e0f1c27bd71ee1f" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "ccecc9a29acc7b73", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.455948Z", + "start_time": "2025-01-29T15:49:12.224414Z" + } + }, + "outputs": [], "source": [ "from amici.jax import JAXProblem, run_simulations\n", "\n", @@ -75,44 +86,53 @@ "\n", "# Run simulations and compute the log-likelihood\n", "llh, results = run_simulations(jax_problem)" - ], - "id": "ccecc9a29acc7b73" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results.", - "id": "415962751301c64a" + "id": "415962751301c64a", + "metadata": {}, + "source": [ + "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "596b86e45e18fe3d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.469126Z", + "start_time": "2025-01-29T15:49:13.464492Z" + } + }, + "outputs": [], "source": [ - "# Define the simulation condition\n", - "simulation_condition = (\"model1_data1\",)\n", - "\n", - "# Access the results for the specified condition\n", - "results[simulation_condition]" - ], - "id": "596b86e45e18fe3d" + "# Access the results\n", + "results" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "a1b173e013f9210a", + "metadata": {}, "source": [ - "Unfortunately, the simulation failed! As seen in the output, the simulation broke down after the initial timepoint, indicated by the `inf` values in the state variables `results[simulation_condition][1].x` and the `nan` likelihood value. A closer inspection of this variable provides additional clues about what might have gone wrong.\n", + "Unfortunately, the simulation failed! As seen in the output, the simulation broke down after the initial timepoint, indicated by the `inf` values in the state variables `results['x']` and the `nan` likelihood value. A closer inspection of this variable provides additional clues about what might have gone wrong.\n", "\n", "The issue stems from using single precision, as indicated by the `float32` dtype of state variables. Single precision is generally a [bad idea](https://docs.kidger.site/diffrax/examples/stiff_ode/) for stiff systems like the Böhm model. Let’s retry the simulation with double precision." - ], - "id": "a1b173e013f9210a" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "f4f5ff705a3f7402", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.517447Z", + "start_time": "2025-01-29T15:49:13.498128Z" + } + }, + "outputs": [], "source": [ "import jax\n", "\n", @@ -123,25 +143,35 @@ "llh, results = run_simulations(jax_problem)\n", "\n", "results" - ], - "id": "f4f5ff705a3f7402" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories.", - "id": "fe4d3b40ee3efdf2" + "id": "fe4d3b40ee3efdf2", + "metadata": {}, + "source": [ + "Success! The simulation completed successfully, and we can now plot the resulting state trajectories." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "72f1ed397105e14a", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.626555Z", + "start_time": "2025-01-29T15:49:13.540193Z" + } + }, + "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", + "# Define the simulation condition\n", + "simulation_condition = (\"model1_data1\",)\n", + "\n", "\n", "def plot_simulation(results):\n", " \"\"\"\n", @@ -151,14 +181,14 @@ " results (dict): Simulation results from run_simulations.\n", " \"\"\"\n", " # Extract the simulation results for the specific condition\n", - " sim_results = results[simulation_condition]\n", + " ic = results[\"simulation_conditions\"].index(simulation_condition)\n", "\n", " # Create a new figure for the state trajectories\n", " plt.figure(figsize=(8, 6))\n", - " for idx in range(sim_results[\"x\"].shape[1]):\n", - " time_points = np.array(sim_results[\"ts\"])\n", - " state_values = np.array(sim_results[\"x\"][:, idx])\n", - " plt.plot(time_points, state_values, label=jax_model.state_ids[idx])\n", + " for ix in range(results[\"x\"].shape[2]):\n", + " time_points = np.array(results[\"ts\"][ic, :])\n", + " state_values = np.array(results[\"x\"][ic, :, ix])\n", + " plt.plot(time_points, state_values, label=jax_model.state_ids[ix])\n", "\n", " # Add labels, legend, and grid\n", " plt.xlabel(\"Time\")\n", @@ -171,41 +201,53 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ], - "id": "72f1ed397105e14a" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all.", - "id": "4fa97c33719c2277" + "id": "4fa97c33719c2277", + "metadata": {}, + "source": [ + "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "7950774a3e989042", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.640281Z", + "start_time": "2025-01-29T15:49:13.637222Z" + } + }, + "outputs": [], "source": [ "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", "results" - ], - "id": "7950774a3e989042" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "98b8516a75ce4d12", + "metadata": {}, "source": [ "## Updating Parameters\n", "\n", "As next step, we will update the parameter values used for simulation. However, if we attempt to directly modify the values in `JAXModel.parameters`, we encounter a `FrozenInstanceError`." - ], - "id": "98b8516a75ce4d12" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "3d278a3d21e709d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.690093Z", + "start_time": "2025-01-29T15:49:13.666663Z" + } + }, + "outputs": [], "source": [ "from dataclasses import FrozenInstanceError\n", "import jax\n", @@ -223,24 +265,29 @@ " jax_problem.parameters += noise\n", "except FrozenInstanceError as e:\n", " print(\"Error:\", e)" - ], - "id": "3d278a3d21e709d" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "4cc3d595de4a4085", + "metadata": {}, "source": [ "The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n", "\n", "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one." - ], - "id": "4cc3d595de4a4085" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "e47748376059628b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.810758Z", + "start_time": "2025-01-29T15:49:13.712463Z" + } + }, + "outputs": [], "source": [ "# Update the parameters and create a new JAXProblem instance\n", "jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n", @@ -250,105 +297,151 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ], - "id": "e47748376059628b" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "660baf605a4e8339", + "metadata": {}, "source": [ "## Computing Gradients\n", "\n", "Similar to updating attributes, computing gradients in the JAX ecosystem can feel a bit unconventional if you’re not familiar with the JAX ecosysmt. JAX offers [powerful automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static." - ], - "id": "660baf605a4e8339" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "7033d09cc81b7f69", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:13.824702Z", + "start_time": "2025-01-29T15:49:13.821212Z" + } + }, + "outputs": [], "source": [ "try:\n", " # Attempt to compute the gradient of the run_simulations function\n", " jax.grad(run_simulations, has_aux=True)(jax_problem)\n", "except TypeError as e:\n", " print(\"Error:\", e)" - ], - "id": "7033d09cc81b7f69" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`.", - "id": "dc9bc07cde00a926" + "id": "dc9bc07cde00a926", + "metadata": {}, + "source": [ + "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "a6704182200e6438", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:20.085633Z", + "start_time": "2025-01-29T15:49:13.853364Z" + } + }, + "outputs": [], "source": [ "import equinox as eqx\n", "\n", "# Compute the gradient using equinox's filter_grad, preserving auxiliary outputs\n", "grad, _ = eqx.filter_grad(run_simulations, has_aux=True)(jax_problem)" - ], - "id": "a6704182200e6438" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`.", - "id": "851c3ec94cb5d086" + "id": "851c3ec94cb5d086", + "metadata": {}, + "source": [ + "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad.parameters", - "id": "c00c1581d7173d7a" + "id": "c00c1581d7173d7a", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:20.096400Z", + "start_time": "2025-01-29T15:49:20.093962Z" + } + }, + "outputs": [], + "source": [ + "grad.parameters" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`.", - "id": "375b835fecc5a022" + "id": "375b835fecc5a022", + "metadata": {}, + "source": [ + "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad", - "id": "f7c17f7459d0151f" + "id": "f7c17f7459d0151f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:20.123274Z", + "start_time": "2025-01-29T15:49:20.120144Z" + } + }, + "outputs": [], + "source": [ + "grad" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out.", - "id": "8eb7cc3db510c826" + "id": "8eb7cc3db510c826", + "metadata": {}, + "source": [ + "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad._measurements[simulation_condition]", - "id": "3badd4402cf6b8c6" + "id": "3badd4402cf6b8c6", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:20.151355Z", + "start_time": "2025-01-29T15:49:20.148297Z" + } + }, + "outputs": [], + "source": [ + "grad._my" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation.", - "id": "58eb04393a1463d" + "id": "58eb04393a1463d", + "metadata": {}, + "source": [ + "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation." + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "1a91aff44b93157", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:21.966714Z", + "start_time": "2025-01-29T15:49:20.188760Z" + } + }, + "outputs": [], "source": [ "import jax.numpy as jnp\n", "import diffrax\n", @@ -356,11 +449,14 @@ "\n", "# Define the simulation condition\n", "simulation_condition = (\"model1_data1\",)\n", + "ic = jax_problem.simulation_conditions.index(simulation_condition)\n", "\n", "# Load condition-specific data\n", - "ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n", - " simulation_condition\n", - "]\n", + "ts_dyn = jax_problem._ts_dyn[ic, :]\n", + "ts_posteq = jax_problem._ts_posteq[ic, :]\n", + "my = jax_problem._my[ic, :]\n", + "iys = jax_problem._iys[ic, :]\n", + "iy_trafos = jax_problem._iy_trafos[ic, :]\n", "\n", "# Load parameters for the specified condition\n", "p = jax_problem.load_parameters(simulation_condition[0])\n", @@ -388,24 +484,29 @@ "# Compute the gradient with respect to `ts_dyn`\n", "g = grad_ts_dyn(ts_dyn)\n", "g" - ], - "id": "1a91aff44b93157" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "9f870da7754e139c", + "metadata": {}, "source": [ "## Compilation & Profiling\n", "\n", "To maximize performance with JAX, code should be just-in-time (JIT) compiled. This can be achieved using the `jax.jit` or `equinox.filter_jit` decorators. While JIT compilation introduces some overhead during the first function call, it significantly improves performance for subsequent calls. To demonstrate this, we will first clear the JIT cache and then profile the execution." - ], - "id": "9f870da7754e139c" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "58ebdc110ea7457e", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:22.363492Z", + "start_time": "2025-01-29T15:49:22.028899Z" + } + }, + "outputs": [], "source": [ "from time import time\n", "\n", @@ -414,14 +515,19 @@ "\n", "# Define a JIT-compiled gradient function with auxiliary outputs\n", "gradfun = eqx.filter_jit(eqx.filter_grad(run_simulations, has_aux=True))" - ], - "id": "58ebdc110ea7457e" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "e1242075f7e0faf", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:30.839352Z", + "start_time": "2025-01-29T15:49:22.371391Z" + } + }, + "outputs": [], "source": [ "# Measure the time taken for the first function call (including compilation)\n", "start = time()\n", @@ -432,14 +538,19 @@ "start = time()\n", "gradfun(jax_problem)\n", "print(f\"Gradient compilation time: {time() - start:.2f} seconds\")" - ], - "id": "e1242075f7e0faf" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "27181f367ccb1817", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:32.125487Z", + "start_time": "2025-01-29T15:49:30.847973Z" + } + }, + "outputs": [], "source": [ "%%timeit\n", "run_simulations(\n", @@ -452,14 +563,19 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ], - "id": "27181f367ccb1817" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "5b8d3a6162a3ae55", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:37.566080Z", + "start_time": "2025-01-29T15:49:32.193598Z" + } + }, + "outputs": [], "source": [ "%%timeit \n", "gradfun(\n", @@ -472,14 +588,19 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ], - "id": "5b8d3a6162a3ae55" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "d733a450635a749b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-29T15:49:52.877239Z", + "start_time": "2025-01-29T15:49:37.633290Z" + } + }, + "outputs": [], "source": [ "from amici.petab import simulate_petab\n", "import amici\n", @@ -494,23 +615,22 @@ "# Configure the solver with appropriate tolerances\n", "solver = amici_model.getSolver()\n", "solver.setAbsoluteTolerance(1e-8)\n", - "solver.setRelativeTolerance(1e-8)\n", + "solver.setRelativeTolerance(1e-16)\n", "\n", "# Prepare the parameters for the simulation\n", "problem_parameters = dict(\n", " zip(jax_problem.parameter_ids, jax_problem.parameters)\n", ")" - ], - "id": "d733a450635a749b" + ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "413ed7c60b2cf4be", "metadata": { "ExecuteTime": { - "end_time": "2024-11-19T09:51:55.259985Z", - "start_time": "2024-11-19T09:51:55.257937Z" + "end_time": "2025-01-29T15:49:52.891165Z", + "start_time": "2025-01-29T15:49:52.889250Z" } }, "outputs": [], @@ -521,23 +641,15 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "768fa60e439ca8b4", "metadata": { "ExecuteTime": { - "end_time": "2024-11-19T09:51:57.417608Z", - "start_time": "2024-11-19T09:51:55.273367Z" + "end_time": "2025-01-29T15:50:06.598838Z", + "start_time": "2025-01-29T15:49:52.902527Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "26.1 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], + "outputs": [], "source": [ "%%timeit \n", "simulate_petab(\n", @@ -552,12 +664,12 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "b8382b0b2b68f49e", "metadata": { "ExecuteTime": { - "end_time": "2024-11-19T09:51:57.497361Z", - "start_time": "2024-11-19T09:51:57.494502Z" + "end_time": "2025-01-29T15:50:06.660478Z", + "start_time": "2025-01-29T15:50:06.658434Z" } }, "outputs": [], @@ -569,23 +681,15 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "3bae1fab8c416122", "metadata": { "ExecuteTime": { - "end_time": "2024-11-19T09:51:59.897459Z", - "start_time": "2024-11-19T09:51:57.511889Z" + "end_time": "2025-01-29T15:50:22.127188Z", + "start_time": "2025-01-29T15:50:06.673328Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "29.1 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], + "outputs": [], "source": [ "%%timeit \n", "simulate_petab(\n", @@ -600,12 +704,12 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "id": "71e0358227e1dc74", "metadata": { "ExecuteTime": { - "end_time": "2024-11-19T09:51:59.972149Z", - "start_time": "2024-11-19T09:51:59.969006Z" + "end_time": "2025-01-29T15:50:22.195899Z", + "start_time": "2025-01-29T15:50:22.193851Z" } }, "outputs": [], @@ -617,23 +721,15 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "id": "e3cc7971002b6d06", "metadata": { "ExecuteTime": { - "end_time": "2024-11-19T09:52:03.266074Z", - "start_time": "2024-11-19T09:51:59.992465Z" + "end_time": "2025-01-29T15:50:24.178434Z", + "start_time": "2025-01-29T15:50:22.207474Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "39.3 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], + "outputs": [], "source": [ "%%timeit \n", "simulate_petab(\n", diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 8b2c09fcc6..616431dd94 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -479,6 +479,7 @@ def simulate_condition( x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), + ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]: r""" @@ -510,10 +511,15 @@ def simulate_condition( :param adjoint: adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs). + :param steady_state_event: + event function for steady state. See :func:`diffrax.steady_state_event` for details. :param max_steps: maximum number of solver steps :param ret: which output to return. See :class:`ReturnValue` for available options. + :param ts_mask: + mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of + the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. :return: output according to `ret` and statistics """ @@ -522,6 +528,9 @@ def simulate_condition( else: x = self._x0(p) + if not ts_mask.shape[0]: + ts_mask = jnp.ones_like(my, dtype=jnp.bool_) + # Re-initialization if x_reinit.shape[0]: x = jnp.where(mask_reinit, x_reinit, x) @@ -566,9 +575,11 @@ def simulate_condition( ) ts = jnp.concatenate((ts_dyn, ts_posteq), axis=0) + x = jnp.concatenate((x_dyn, x_posteq), axis=0) nllhs = self._nllhs(ts, x, p, tcl, my, iys) + nllhs = jnp.where(ts_mask, nllhs, 0.0) llh = -jnp.sum(nllhs) stats = dict( @@ -608,12 +619,13 @@ def simulate_condition( ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos) m_obj = obs_trafo(my, iy_trafos) if ret == ReturnValue.chi2: - output = jnp.sum( - jnp.square(ys_obj - m_obj) - / jnp.square(self._sigmays(ts, x, p, tcl, iys)) - ) + sigma_obj = self._sigmays(ts, x, p, tcl, iys) + chi2 = jnp.square((ys_obj - m_obj) / sigma_obj) + chi2 = jnp.where(ts_mask, chi2, 0.0) + output = jnp.sum(chi2) else: output = ys_obj - m_obj + output = jnp.where(ts_mask, output, 0.0) else: raise NotImplementedError(f"Return value {ret} not implemented.") diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 43498ce536..c47a00e1e3 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -80,18 +80,15 @@ class JAXProblem(eqx.Module): parameters: jnp.ndarray model: JAXModel + simulation_conditions: tuple[tuple[str, ...], ...] _parameter_mappings: dict[str, ParameterMappingForCondition] - _measurements: dict[ - tuple[str, ...], - tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - ], - ] - _petab_measurement_indices: dict[tuple[str, ...], tuple[int, ...]] + _ts_dyn: np.ndarray + _ts_posteq: np.ndarray + _my: np.ndarray + _iys: np.ndarray + _iy_trafos: np.ndarray + _ts_masks: np.ndarray + _petab_measurement_indices: np.ndarray _petab_problem: petab.Problem def __init__(self, model: JAXModel, petab_problem: petab.Problem): @@ -105,11 +102,19 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): """ self.model = model scs = petab_problem.get_simulation_conditions_from_measurement_df() + self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) self._petab_problem = petab_problem self._parameter_mappings = self._get_parameter_mappings(scs) - self._measurements, self._petab_measurement_indices = ( - self._get_measurements(scs) - ) + ( + self._ts_dyn, + self._ts_posteq, + self._my, + self._iys, + self._iy_trafos, + self._ts_masks, + self._petab_measurement_indices, + ) = self._get_measurements(scs) + self.parameters = self._get_nominal_parameter_values() def save(self, directory: Path): @@ -180,17 +185,13 @@ def _get_parameter_mappings( def _get_measurements( self, simulation_conditions: pd.DataFrame ) -> tuple[ - dict[ - tuple[str, ...], - tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - ], - ], - dict[tuple[str, ...], tuple[int, ...]], + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, ]: """ Get measurements for the model based on the provided simulation conditions. @@ -199,11 +200,17 @@ def _get_measurements( Simulation conditions to create parameter mappings for. Same format as returned by :meth:`petab.Problem.get_simulation_conditions_from_measurement_df`. :return: - Dictionary mapping simulation conditions to measurements (tuple of pre-equilibrium, dynamic, - post-equilibrium time points; measurements and observable indices). + tuple of padded + - dynamic time points + - post-equilibrium time points + - measurements + - observable indices + - observable transformations indices + - measurement masks + - data indices (index in petab measurement dataframe). """ measurements = dict() - indices = dict() + petab_indices = dict() for _, simulation_condition in simulation_conditions.iterrows(): query = " & ".join( [f"{k} == '{v}'" for k, v in simulation_condition.items()] @@ -249,8 +256,87 @@ def _get_measurements( iys, iy_trafos, ) - indices[tuple(simulation_condition)] = tuple(index.tolist()) - return measurements, indices + petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) + + # compute maximum lengths + n_ts_dyn = max( + len(ts_dyn) for ts_dyn, _, _, _, _ in measurements.values() + ) + n_ts_posteq = max( + len(ts_posteq) for _, ts_posteq, _, _, _ in measurements.values() + ) + + # pad with last value and stack + ts_dyn = np.stack( + [ + np.pad(x, (0, n_ts_dyn - len(x)), mode="edge") + for x, _, _, _, _ in measurements.values() + ] + ) + ts_posteq = np.stack( + [ + np.pad(x, (0, n_ts_posteq - len(x)), mode="edge") + for _, x, _, _, _ in measurements.values() + ] + ) + + def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): + return np.concatenate( + ( + np.pad(x_dyn, (0, n_ts_dyn - len(x_dyn)), mode="edge"), + np.pad(x_peq, (0, n_ts_posteq - len(x_peq)), mode="edge"), + ) + ) + + my = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, x, _, _ in measurements.values() + ] + ) + iys = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, _, x, _ in measurements.values() + ] + ) + iy_trafos = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, _, _, x in measurements.values() + ] + ) + ts_masks = np.stack( + [ + np.concatenate( + ( + np.pad(np.ones_like(tdyn), (0, n_ts_dyn - len(tdyn))), + np.pad( + np.ones_like(tpeq), (0, n_ts_posteq - len(tpeq)) + ), + ) + ) + for tdyn, tpeq, _, _, _ in measurements.values() + ] + ).astype(bool) + petab_indices = np.stack( + [ + pad_measurement( + idx[: len(tdyn)], idx[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for (tdyn, tpeq, _, _, _), idx in zip( + measurements.values(), petab_indices.values() + ) + ] + ) + + return ts_dyn, ts_posteq, my, iys, iy_trafos, ts_masks, petab_indices def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: simulation_conditions = ( @@ -462,9 +548,53 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": """ return eqx.tree_at(lambda p: p.parameters, self, p) + def _prepare_conditions( + self, conditions: Iterable[str] + ) -> tuple[ + jt.Float[jt.Array, "np"], # noqa: F821 + jt.Bool[jt.Array, "nx"], # noqa: F821 + jt.Float[jt.Array, "nx"], # noqa: F821 + ]: + """ + Prepare conditions for simulation. + + :param conditions: + Simulation conditions to prepare. + :return: + Tuple of parameter arrays, reinitialisation masks and reinitialisation values. + """ + p_array = jnp.stack([self.load_parameters(sc) for sc in conditions]) + + mask_reinit_array = jnp.stack( + [ + self.load_reinitialisation(sc, p)[0] + for sc, p in zip(conditions, p_array) + ] + ) + x_reinit_array = jnp.stack( + [ + self.load_reinitialisation(sc, p)[1] + for sc, p in zip(conditions, p_array) + ] + ) + return p_array, mask_reinit_array, x_reinit_array + + @eqx.filter_vmap( + in_axes={ + "max_steps": None, + "self": None, + }, # only list arguments here where eqx.is_array(0) is not the right thing + ) def run_simulation( self, - simulation_condition: tuple[str, ...], + p: jt.Float[jt.Array, "np"], # noqa: F821, F722 + ts_dyn: np.ndarray, + ts_posteq: np.ndarray, + my: np.ndarray, + iys: np.ndarray, + iy_trafos: np.ndarray, + mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 + x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, steady_state_event: Callable[ @@ -472,33 +602,47 @@ def run_simulation( ], max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 + ts_mask: np.ndarray = np.array([]), ret: ReturnValue = ReturnValue.llh, ) -> tuple[jnp.float_, dict]: """ Run a simulation for a given simulation condition. - :param simulation_condition: - Simulation condition to run simulation for. + :param p: + Parameters for the simulation condition + :param ts_dyn: + (Padded) dynamic time points + :param ts_posteq: + (Padded) post-equilibrium time points + :param my: + (Padded) measurements + :param iys: + (Padded) observable indices + :param iy_trafos: + (Padded) observable transformations indices + :param mask_reinit: + Mask for states that need reinitialisation + :param x_reinit: + Reinitialisation values for states :param solver: ODE solver to use for simulation :param controller: Step size controller to use for simulation + :param steady_state_event: + Steady state event function to use for post-equilibration. Allows customisation of the steady state + condition, see :func:`diffrax.steady_state_event` for details. :param max_steps: Maximum number of steps to take during simulation :param x_preeq: - Pre-equilibration state if available + Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states will + be initialised to the model default values. + :param ts_mask: + padding mask, see :meth:`JAXModel.simulate_condition` for details. :param ret: which output to return. See :class:`ReturnValue` for available options. :return: Tuple of output value and simulation statistics """ - ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ - simulation_condition - ] - p = self.load_parameters(simulation_condition[0]) - mask_reinit, x_reinit = self.load_reinitialisation( - simulation_condition[0], p - ) return self.model.simulate_condition( p=p, ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), @@ -509,6 +653,7 @@ def run_simulation( x_preeq=x_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, + ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), solver=solver, controller=controller, max_steps=max_steps, @@ -519,9 +664,73 @@ def run_simulation( ret=ret, ) + def run_simulations( + self, + simulation_conditions: list[str], + preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], + max_steps: jnp.int_, + ret: ReturnValue = ReturnValue.llh, + ): + """ + Run simulations for a list of simulation conditions. + + :param simulation_conditions: + List of simulation conditions to run simulations for. + :param preeq_array: + Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation + conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty. + :param solver: + ODE solver to use for simulation. + :param controller: + Step size controller to use for simulation. + :param steady_state_event: + Steady state event function to use for post-equilibration. Allows customisation of the steady state + condition, see :func:`diffrax.steady_state_event` for details. + :param max_steps: + Maximum number of steps to take during simulation. + :param ret: + which output to return. See :class:`ReturnValue` for available options. + :return: + Output value and condition specific results and statistics. Results and statistics are returned as a dict + with arrays with the leading dimension corresponding to the simulation conditions. + """ + p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions( + simulation_conditions + ) + return self.run_simulation( + p_array, + self._ts_dyn, + self._ts_posteq, + self._my, + self._iys, + self._iy_trafos, + mask_reinit_array, + x_reinit_array, + solver, + controller, + steady_state_event, + max_steps, + preeq_array, + self._ts_masks, + ret, + ) + + @eqx.filter_vmap( + in_axes={ + "max_steps": None, + "self": None, + }, # only list arguments here where eqx.is_array(0) is not the right thing + ) def run_preequilibration( self, - simulation_condition: str, + p: jt.Float[jt.Array, "np"], # noqa: F821, F722 + mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 + x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, steady_state_event: Callable[ @@ -532,21 +741,24 @@ def run_preequilibration( """ Run a pre-equilibration simulation for a given simulation condition. - :param simulation_condition: - Simulation condition to run simulation for. + :param p: + Parameters for the simulation condition + :param mask_reinit: + Mask for states that need reinitialisation + :param x_reinit: + Reinitialisation values for states :param solver: ODE solver to use for simulation :param controller: Step size controller to use for simulation + :param steady_state_event: + Steady state event function to use for pre-equilibration. Allows customisation of the steady state + condition, see :func:`diffrax.steady_state_event` for details. :param max_steps: Maximum number of steps to take during simulation :return: Pre-equilibration state """ - p = self.load_parameters(simulation_condition) - mask_reinit, x_reinit = self.load_reinitialisation( - simulation_condition, p - ) return self.model.preequilibrate_condition( p=p, mask_reinit=mask_reinit, @@ -557,6 +769,29 @@ def run_preequilibration( steady_state_event=steady_state_event, ) + def run_preequilibrations( + self, + simulation_conditions: list[str], + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], + max_steps: jnp.int_, + ): + p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions( + simulation_conditions + ) + return self.run_preequilibration( + p_array, + mask_reinit_array, + x_reinit_array, + solver, + controller, + steady_state_event, + max_steps, + ) + def run_simulations( problem: JAXProblem, @@ -577,7 +812,9 @@ def run_simulations( :param problem: Problem to run simulations for. :param simulation_conditions: - Simulation conditions to run simulations for. + Simulation conditions to run simulations for. This is a series of tuples, where each tuple contains the + simulation condition or the pre-equilibration condition followed by the simulation condition. Default is to run + simulations for all conditions. :param solver: ODE solver to use for simulation. :param controller: @@ -598,36 +835,62 @@ def run_simulations( if simulation_conditions is None: simulation_conditions = problem.get_all_simulation_conditions() - preeqs = { - sc: problem.run_preequilibration( - sc, solver, controller, steady_state_event, max_steps - ) - # only run preequilibration once per condition - for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1} + dynamic_conditions = [sc[0] for sc in simulation_conditions] + preequilibration_conditions = list( + {sc[1] for sc in simulation_conditions if len(sc) > 1} + ) + + conditions = { + "dynamic_conditions": dynamic_conditions, + "preequilibration_conditions": preequilibration_conditions, + "simulation_conditions": simulation_conditions, } - results = { - sc: problem.run_simulation( - sc, + if preequilibration_conditions: + preeqs, preresults = problem.run_preequilibrations( + preequilibration_conditions, solver, controller, steady_state_event, max_steps, - preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]), - ret=ret, ) - for sc in simulation_conditions - } - stats = { - sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1] - for sc, res in results.items() - } - if ret in (ReturnValue.llh, ReturnValue.chi2): - output = sum(r for r, _ in results.values()) else: - output = {sc: res[0] for sc, res in results.items()} + preresults = { + "stats_preeq": None, + } + + if dynamic_conditions: + preeq_array = jnp.stack( + [ + preeqs[preequilibration_conditions.index(sc[1]), :] + if len(sc) > 1 + else jnp.array([]) + for sc in simulation_conditions + ] + ) + output, results = problem.run_simulations( + dynamic_conditions, + preeq_array, + solver, + controller, + steady_state_event, + max_steps, + ret, + ) + else: + output = jnp.array(0.0) + results = { + "llh": jnp.array([]), + "stats_dyn": None, + "stats_posteq": None, + "ts": jnp.array([]), + "x": jnp.array([]), + } + + if ret in (ReturnValue.llh, ReturnValue.chi2): + output = jnp.sum(output) - return output, stats + return output, results | preresults | conditions def petab_simulate( @@ -652,6 +915,9 @@ def petab_simulate( Step size controller to use for simulation. :param max_steps: Maximum number of steps to take during simulation. + :param steady_state_event: + Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state + condition, see :func:`diffrax.steady_state_event` for details. :return: petab simulation dataframe. """ @@ -664,20 +930,25 @@ def petab_simulate( ret=ReturnValue.y, ) dfs = [] - for sc, ys in y.items(): + for ic, sc in enumerate(r["dynamic_conditions"]): obs = [ problem.model.observable_ids[io] - for io in problem._measurements[sc][3] + for io in problem._iys[ic, problem._ts_masks[ic, :]] ] - t = jnp.concat(problem._measurements[sc][:2]) + t = jnp.concat( + ( + problem._ts_dyn[ic, :], + problem._ts_posteq[ic, :], + ) + ) df_sc = pd.DataFrame( { - petab.SIMULATION: ys, - petab.TIME: t, + petab.SIMULATION: y[ic, problem._ts_masks[ic, :]], + petab.TIME: t[problem._ts_masks[ic, :]], petab.OBSERVABLE_ID: obs, - petab.SIMULATION_CONDITION_ID: [sc[0]] * len(t), + petab.SIMULATION_CONDITION_ID: [sc] * len(t), }, - index=problem._petab_measurement_indices[sc], + index=problem._petab_measurement_indices[ic, :], ) if ( petab.OBSERVABLE_PARAMETERS @@ -685,19 +956,23 @@ def petab_simulate( ): df_sc[petab.OBSERVABLE_PARAMETERS] = ( problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'" + f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" )[petab.OBSERVABLE_PARAMETERS] ) if petab.NOISE_PARAMETERS in problem._petab_problem.measurement_df: df_sc[petab.NOISE_PARAMETERS] = ( problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'" + f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" )[petab.NOISE_PARAMETERS] ) if ( petab.PREEQUILIBRATION_CONDITION_ID in problem._petab_problem.measurement_df ): - df_sc[petab.PREEQUILIBRATION_CONDITION_ID] = sc[1] + df_sc[petab.PREEQUILIBRATION_CONDITION_ID] = ( + problem._petab_problem.measurement_df.query( + f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" + )[petab.PREEQUILIBRATION_CONDITION_ID] + ) dfs.append(df_sc) return pd.concat(dfs).sort_index() diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index d9f836b0b4..2f3fbb433a 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -274,16 +274,6 @@ def test_jax_llh(benchmark_problem): problem_id, petab_problem, amici_model = benchmark_problem - if problem_id in ( - "Bachmann_MSB2011", - "Isensee_JCB2018", - "Lucarelli_CellSystems2018", - "SalazarCavazos_MBoC2020", - "Smith_BMCSystBiol2013", - ): - # confirmed to work (no gradients) 27/10/2024 but experienced high local runtime (M2 MBA, >30s) - pytest.skip("Excluded from JAX check due to excessive runtime") - amici_solver = amici_model.getSolver() cur_settings = settings[problem_id] amici_solver.setAbsoluteTolerance(1e-8)