From a9abbee0ad9459a0cb4465a8f9c9ab3ceaab9f35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Wed, 29 Jan 2025 09:57:03 +0000 Subject: [PATCH] fix vectorisation --- python/sdist/amici/jax/model.py | 12 +- python/sdist/amici/jax/petab.py | 256 ++++++++++++++---- .../benchmark-models/test_petab_benchmark.py | 10 - 3 files changed, 206 insertions(+), 72 deletions(-) diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 8b2c09fcc6..268b595170 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -479,6 +479,7 @@ def simulate_condition( x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), + ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]: r""" @@ -522,6 +523,9 @@ def simulate_condition( else: x = self._x0(p) + if not ts_mask.shape[0]: + ts_mask = jnp.zeros_like(my, dtype=jnp.bool_) + # Re-initialization if x_reinit.shape[0]: x = jnp.where(mask_reinit, x_reinit, x) @@ -566,9 +570,11 @@ def simulate_condition( ) ts = jnp.concatenate((ts_dyn, ts_posteq), axis=0) + x = jnp.concatenate((x_dyn, x_posteq), axis=0) nllhs = self._nllhs(ts, x, p, tcl, my, iys) + nllhs = jnp.where(ts_mask, nllhs, 0.0) llh = -jnp.sum(nllhs) stats = dict( @@ -608,10 +614,8 @@ def simulate_condition( ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos) m_obj = obs_trafo(my, iy_trafos) if ret == ReturnValue.chi2: - output = jnp.sum( - jnp.square(ys_obj - m_obj) - / jnp.square(self._sigmays(ts, x, p, tcl, iys)) - ) + sigma_obj = self._sigmays(ts, x, p, tcl, iys) + output = jnp.sum(jnp.square((ys_obj - m_obj) / sigma_obj)) else: output = ys_obj - m_obj else: diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 43498ce536..cfc3a3d8ca 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -81,17 +81,13 @@ class JAXProblem(eqx.Module): parameters: jnp.ndarray model: JAXModel _parameter_mappings: dict[str, ParameterMappingForCondition] - _measurements: dict[ - tuple[str, ...], - tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - ], - ] - _petab_measurement_indices: dict[tuple[str, ...], tuple[int, ...]] + _ts_dyn: np.ndarray + _ts_posteq: np.ndarray + _my: np.ndarray + _iys: np.ndarray + _iy_trafos: np.ndarray + _ts_masks: np.ndarray + _petab_measurement_indices: np.ndarray _petab_problem: petab.Problem def __init__(self, model: JAXModel, petab_problem: petab.Problem): @@ -107,9 +103,16 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): scs = petab_problem.get_simulation_conditions_from_measurement_df() self._petab_problem = petab_problem self._parameter_mappings = self._get_parameter_mappings(scs) - self._measurements, self._petab_measurement_indices = ( - self._get_measurements(scs) - ) + ( + self._ts_dyn, + self._ts_posteq, + self._my, + self._iys, + self._iy_trafos, + self._ts_masks, + self._petab_measurement_indices, + ) = self._get_measurements(scs) + self.parameters = self._get_nominal_parameter_values() def save(self, directory: Path): @@ -180,17 +183,13 @@ def _get_parameter_mappings( def _get_measurements( self, simulation_conditions: pd.DataFrame ) -> tuple[ - dict[ - tuple[str, ...], - tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - ], - ], - dict[tuple[str, ...], tuple[int, ...]], + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, ]: """ Get measurements for the model based on the provided simulation conditions. @@ -199,11 +198,17 @@ def _get_measurements( Simulation conditions to create parameter mappings for. Same format as returned by :meth:`petab.Problem.get_simulation_conditions_from_measurement_df`. :return: - Dictionary mapping simulation conditions to measurements (tuple of pre-equilibrium, dynamic, - post-equilibrium time points; measurements and observable indices). + tuple of padded + - dynamic time points + - post-equilibrium time points + - measurements + - observable indices + - observable transformations indices + - measurement masks + - data indices (index in petab measurement dataframe). """ measurements = dict() - indices = dict() + petab_indices = dict() for _, simulation_condition in simulation_conditions.iterrows(): query = " & ".join( [f"{k} == '{v}'" for k, v in simulation_condition.items()] @@ -249,8 +254,89 @@ def _get_measurements( iys, iy_trafos, ) - indices[tuple(simulation_condition)] = tuple(index.tolist()) - return measurements, indices + petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) + + # compute maximum lengths + n_ts_dyn = max( + len(ts_dyn) for ts_dyn, _, _, _, _ in measurements.values() + ) + n_ts_posteq = max( + len(ts_posteq) for _, ts_posteq, _, _, _ in measurements.values() + ) + + # pad with last value and stack + ts_dyn = np.stack( + [ + np.pad(x, (0, n_ts_dyn - len(x)), mode="edge") + for x, _, _, _, _ in measurements.values() + ] + ) + ts_posteq = np.stack( + [ + np.pad(x, (0, n_ts_posteq - len(x)), mode="edge") + for _, x, _, _, _ in measurements.values() + ] + ) + + def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): + return np.concatenate( + ( + np.pad(x_dyn, (0, n_ts_dyn - len(x_dyn)), mode="edge"), + np.pad(x_peq, (0, n_ts_posteq - len(x_peq)), mode="edge"), + ) + ) + + my = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, x, _, _ in measurements.values() + ] + ) + iys = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, _, x, _ in measurements.values() + ] + ) + iy_trafos = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, _, _, x in measurements.values() + ] + ) + ts_masks = np.stack( + [ + np.concatenate( + ( + np.pad(np.ones_like(tdyn), (0, n_ts_dyn - len(tdyn))), + np.pad( + np.ones_like(tpeq), (0, n_ts_posteq - len(tpeq)) + ), + ) + ) + for tdyn, tpeq, _, _, _ in measurements.values() + ] + ).astype(bool) + petab_indices = np.stack( + [ + pad_measurement( + idx[: len(tdyn)], idx[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for (tdyn, tpeq, _, _, _), idx in zip( + measurements.values(), petab_indices.values() + ) + ] + ) + + # create mask for my + + return ts_dyn, ts_posteq, my, iys, iy_trafos, ts_masks, petab_indices def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: simulation_conditions = ( @@ -464,7 +550,14 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": def run_simulation( self, - simulation_condition: tuple[str, ...], + p: jt.Float[jt.Array, "np"], # noqa: F821, F722 + ts_dyn: np.ndarray, + ts_posteq: np.ndarray, + my: np.ndarray, + iys: np.ndarray, + iy_trafos: np.ndarray, + mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 + x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, steady_state_event: Callable[ @@ -472,13 +565,12 @@ def run_simulation( ], max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 + ts_mask: np.ndarray = np.array([]), ret: ReturnValue = ReturnValue.llh, ) -> tuple[jnp.float_, dict]: """ Run a simulation for a given simulation condition. - :param simulation_condition: - Simulation condition to run simulation for. :param solver: ODE solver to use for simulation :param controller: @@ -492,13 +584,6 @@ def run_simulation( :return: Tuple of output value and simulation statistics """ - ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ - simulation_condition - ] - p = self.load_parameters(simulation_condition[0]) - mask_reinit, x_reinit = self.load_reinitialisation( - simulation_condition[0], p - ) return self.model.simulate_condition( p=p, ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), @@ -509,6 +594,7 @@ def run_simulation( x_preeq=x_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, + ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), solver=solver, controller=controller, max_steps=max_steps, @@ -598,6 +684,22 @@ def run_simulations( if simulation_conditions is None: simulation_conditions = problem.get_all_simulation_conditions() + p_array = jnp.stack( + [problem.load_parameters(sc[0]) for sc in simulation_conditions] + ) + mask_reinit_array = jnp.stack( + [ + problem.load_reinitialisation(sc[0], p)[0] + for sc, p in zip(simulation_conditions, p_array) + ] + ) + x_reinit_array = jnp.stack( + [ + problem.load_reinitialisation(sc[0], p)[1] + for sc, p in zip(simulation_conditions, p_array) + ] + ) + preeqs = { sc: problem.run_preequilibration( sc, solver, controller, steady_state_event, max_steps @@ -606,26 +708,64 @@ def run_simulations( for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1} } - results = { - sc: problem.run_simulation( - sc, - solver, - controller, - steady_state_event, - max_steps, - preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]), - ret=ret, - ) - for sc in simulation_conditions - } - stats = { - sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1] - for sc, res in results.items() - } + preeq_array = jnp.stack( + [ + preeqs[sc[1]][0] if len(sc) > 1 else jnp.array([]) + for sc in simulation_conditions + ] + ) + + parallel_run_simulation = eqx.filter_vmap( + JAXProblem.run_simulation, + in_axes=( + None, # problem + 0, # p + 0, # ts_dyn + 0, # ts_posteq + 0, # my + 0, # iys + 0, # iy_trafos + 0, # mask_reinit + 0, # x_reinit + None, # solver + None, # controller + None, # steady_state_event + None, # max_steps + 0, # preeq_array + 0, # ts_masks + None, # ret + ), + ) + + results = parallel_run_simulation( + problem, + p_array, + problem._ts_dyn, + problem._ts_posteq, + problem._my, + problem._iys, + problem._iy_trafos, + mask_reinit_array, + x_reinit_array, + solver, + controller, + steady_state_event, + max_steps, + preeq_array, + problem._ts_masks, + ret, + ) + if ret in (ReturnValue.llh, ReturnValue.chi2): - output = sum(r for r, _ in results.values()) + output = jnp.sum(results[0]) else: - output = {sc: res[0] for sc, res in results.items()} + output = results[0] + + stats = ( + results[1]["stats_dyn"] | results[1]["stats_posteq"] + if results[1]["stats_posteq"] + else results[1]["stats_dyn"] + ) return output, stats diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index d9f836b0b4..2f3fbb433a 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -274,16 +274,6 @@ def test_jax_llh(benchmark_problem): problem_id, petab_problem, amici_model = benchmark_problem - if problem_id in ( - "Bachmann_MSB2011", - "Isensee_JCB2018", - "Lucarelli_CellSystems2018", - "SalazarCavazos_MBoC2020", - "Smith_BMCSystBiol2013", - ): - # confirmed to work (no gradients) 27/10/2024 but experienced high local runtime (M2 MBA, >30s) - pytest.skip("Excluded from JAX check due to excessive runtime") - amici_solver = amici_model.getSolver() cur_settings = settings[problem_id] amici_solver.setAbsoluteTolerance(1e-8)