From deb8f2d2d83d351a425e3021ea3c8989dd4c568d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 31 Jan 2025 15:25:05 +0000 Subject: [PATCH] vectorisation go brrrrr --- python/sdist/amici/jax/petab.py | 284 +++++++++++++++++++------------- 1 file changed, 165 insertions(+), 119 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 7de1553403..82b9ca8b39 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -89,8 +89,12 @@ class JAXProblem(eqx.Module): _iys: np.ndarray _iy_trafos: np.ndarray _ts_masks: np.ndarray - _ops: np.ndarray - _nps: np.ndarray + _op_numeric: np.ndarray + _op_mask: np.ndarray + _op_indices: np.ndarray + _np_numeric: np.ndarray + _np_mask: np.ndarray + _np_indices: np.ndarray _petab_measurement_indices: np.ndarray _petab_problem: petab.Problem @@ -116,8 +120,12 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): self._iy_trafos, self._ts_masks, self._petab_measurement_indices, - self._ops, - self._nps, + self._op_numeric, + self._op_mask, + self._op_indices, + self._np_numeric, + self._np_mask, + self._np_indices, ) = self._get_measurements(scs) self.parameters = self._get_nominal_parameter_values() @@ -208,6 +216,10 @@ def _get_measurements( np.ndarray, np.ndarray, np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, ]: """ Get measurements for the model based on the provided simulation conditions. @@ -224,6 +236,12 @@ def _get_measurements( - observable transformations indices - measurement masks - data indices (index in petab measurement dataframe). + - numeric values for observable parameter overrides + - non-numeric mask for observable parameter overrides + - parameter indices (problem parameters) for observable parameter overrides + - numeric values for noise parameter overrides + - non-numeric mask for noise parameter overrides + - parameter indices (problem parameters) for noise parameter overrides """ measurements = dict() petab_indices = dict() @@ -288,21 +306,41 @@ def _get_measurements( else: iy_trafos = np.zeros_like(iys) - parameter_overrides = dict() + parameter_overrides_par_indices = dict() + parameter_overrides_numeric_vals = dict() + parameter_overrides_mask = dict() + + def get_parameter_override(x): + if ( + x in self._petab_problem.parameter_df.index + and not self._petab_problem.parameter_df.loc[ + x, petab.ESTIMATE + ] + ): + return self._petab_problem.parameter_df.loc[ + x, petab.NOMINAL_VALUE + ] + return x + for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: if col not in m or m[col].isna().all(): - mat = jnp.ones((len(m), n_pars[col])) - + mat_numeric = jnp.ones((len(m), n_pars[col])) + par_mask = np.zeros_like(mat_numeric, dtype=bool) + par_index = np.zeros_like(mat_numeric, dtype=int) elif np.issubdtype(m[col].dtype, np.number): - mat = np.expand_dims(m[col].values, axis=1) + mat_numeric = np.expand_dims(m[col].values, axis=1) + par_mask = np.zeros_like(mat_numeric, dtype=bool) + par_index = np.zeros_like(mat_numeric, dtype=int) else: split_vals = m[col].str.split(";") list_vals = split_vals.apply( - lambda x: x + lambda x: [get_parameter_override(y) for y in x] if isinstance(x, list) else [] if pd.isna(x) - else [float(x)] + else [ + x + ] # every string gets transformed to lists, so this is already a float ) vals = list_vals.apply( lambda x: np.pad( @@ -313,44 +351,63 @@ def _get_measurements( ) ) mat = np.stack(vals) - - parameter_overrides[col] = mat + # deconstruct such that we can reconstruct mapped parameter overrides via vectorized operations + # mat = np.where(par_mask, map(lambda ip: p.at[ip], par_index), mat_numeric) + par_index = np.vectorize( + lambda x: self.parameter_ids.index(x) + if x in self.parameter_ids + else -1 + )(mat) + # map out numeric values + par_mask = par_index != -1 + # remove non-numeric values + mat[par_mask] = 0.0 + mat_numeric = mat.astype(float) + # replace dummy index with some valid index + par_index[~par_mask] = 0 + + parameter_overrides_numeric_vals[col] = mat_numeric + parameter_overrides_mask[col] = par_mask + parameter_overrides_par_indices[col] = par_index measurements[tuple(simulation_condition)] = ( - ts_dyn, - ts_posteq, - my, - iys, - iy_trafos, - parameter_overrides[petab.OBSERVABLE_PARAMETERS], - parameter_overrides[petab.NOISE_PARAMETERS], + ts_dyn, # 0 + ts_posteq, # 1 + my, # 2 + iys, # 3 + iy_trafos, # 4 + parameter_overrides_numeric_vals[ + petab.OBSERVABLE_PARAMETERS + ], # 5 + parameter_overrides_mask[petab.OBSERVABLE_PARAMETERS], # 6 + parameter_overrides_par_indices[ + petab.OBSERVABLE_PARAMETERS + ], # 7 + parameter_overrides_numeric_vals[petab.NOISE_PARAMETERS], # 8 + parameter_overrides_mask[petab.NOISE_PARAMETERS], # 9 + parameter_overrides_par_indices[petab.NOISE_PARAMETERS], # 10 ) petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) # compute maximum lengths - n_ts_dyn = max( - len(ts_dyn) for ts_dyn, _, _, _, _, _, _ in measurements.values() - ) - n_ts_posteq = max( - len(ts_posteq) - for _, ts_posteq, _, _, _, _, _ in measurements.values() - ) + n_ts_dyn = max(len(mv[0]) for mv in measurements.values()) + n_ts_posteq = max(len(mv[1]) for mv in measurements.values()) # pad with last value and stack ts_dyn = np.stack( [ - np.pad(x, (0, n_ts_dyn - len(x)), mode="edge") - for x, _, _, _, _, _, _ in measurements.values() + np.pad(mv[0], (0, n_ts_dyn - len(mv[0])), mode="edge") + for mv in measurements.values() ] ) ts_posteq = np.stack( [ - np.pad(x, (0, n_ts_posteq - len(x)), mode="edge") - for _, x, _, _, _, _, _ in measurements.values() + np.pad(mv[1], (0, n_ts_posteq - len(mv[1])), mode="edge") + for mv in measurements.values() ] ) - def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): + def pad_measurement(x_dyn, x_peq): # only pad first axis pad_width_dyn = tuple( [(0, n_ts_dyn - len(x_dyn))] + [(0, 0)] * (x_dyn.ndim - 1) @@ -365,68 +422,48 @@ def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): ) ) - my = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, x, _, _, _, _ in measurements.values() - ] - ) - iys = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, _, x, _, _, _ in measurements.values() - ] - ) - iy_trafos = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, _, _, x, _, _ in measurements.values() - ] - ) - ops = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, _, _, _, x, _ in measurements.values() - ] - ) - nps = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, _, _, _, _, x in measurements.values() - ] - ) + def pad_and_stack(output_index: int): + return np.stack( + [ + pad_measurement( + mv[output_index][: len(mv[0])], + mv[output_index][len(mv[0]) :], + ) + for mv in measurements.values() + ] + ) + + my = pad_and_stack(2) + iys = pad_and_stack(3) + iy_trafos = pad_and_stack(4) + op_numeric = pad_and_stack(5) + op_mask = pad_and_stack(6) + op_indices = pad_and_stack(7) + np_numeric = pad_and_stack(8) + np_mask = pad_and_stack(9) + np_indices = pad_and_stack(10) ts_masks = np.stack( [ np.concatenate( ( - np.pad(np.ones_like(tdyn), (0, n_ts_dyn - len(tdyn))), np.pad( - np.ones_like(tpeq), (0, n_ts_posteq - len(tpeq)) + np.ones_like(mv[0]), (0, n_ts_dyn - len(mv[0])) + ), + np.pad( + np.ones_like(mv[1]), (0, n_ts_posteq - len(mv[1])) ), ) ) - for tdyn, tpeq, _, _, _, _, _ in measurements.values() + for mv in measurements.values() ] ).astype(bool) petab_indices = np.stack( [ pad_measurement( - np.array(idx[: len(tdyn)]), - np.array(idx[len(tdyn) :]), - n_ts_dyn, - n_ts_posteq, + np.array(idx[: len(mv[0])]), + np.array(idx[len(mv[0]) :]), ) - for (tdyn, tpeq, _, _, _, _, _), idx in zip( + for mv, idx in zip( measurements.values(), petab_indices.values() ) ] @@ -440,8 +477,12 @@ def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): iy_trafos, ts_masks, petab_indices, - ops, - nps, + op_numeric, + op_mask, + op_indices, + np_numeric, + np_mask, + np_indices, ) def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: @@ -657,14 +698,18 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": def _prepare_conditions( self, conditions: list[str], - op_array: np.ndarray | None, - np_array: np.ndarray | None, + op_numeric: np.ndarray | None = None, + op_mask: np.ndarray | None = None, + op_indices: np.ndarray | None = None, + np_numeric: np.ndarray | None = None, + np_mask: np.ndarray | None = None, + np_indices: np.ndarray | None = None, ) -> tuple[ - jt.Float[jt.Array, "np"], # noqa: F821 + jt.Float[jt.Array, "nc np"], # noqa: F821, F722 jt.Bool[jt.Array, "nx"], # noqa: F821 jt.Float[jt.Array, "nx"], # noqa: F821 - jt.Float[jt.Array, "nt *nop"], # noqa: F821 - jt.Float[jt.Array, "nt *nnp"], # noqa: F821 + jt.Float[jt.Array, "nc nt nop"], # noqa: F821, F722 + jt.Float[jt.Array, "nc nt nnp"], # noqa: F821, F722 ]: """ Prepare conditions for simulation. @@ -678,44 +723,39 @@ def _prepare_conditions( noise parameters. """ p_array = jnp.stack([self.load_parameters(sc) for sc in conditions]) - - def map_parameter(x, p): - if x in self.model.parameter_ids: - return p[self.model.parameter_ids.index(x)] - if x in self.parameter_ids: - return jax_unscale( - self.get_petab_parameter_by_id(x), + unscaled_parameters = jnp.stack( + [ + jax_unscale( + self.parameters[ip], self._petab_problem.parameter_df.loc[ - x, petab.PARAMETER_SCALE + p_id, petab.PARAMETER_SCALE ], ) - if x in self._petab_problem.parameter_df.index: - return self._petab_problem.parameter_df.loc[ - x, petab.NOMINAL_VALUE - ] - if isinstance(x, str): - return float(x) - return x + for ip, p_id in enumerate(self.parameter_ids) + ] + ) - if op_array is not None and op_array.size: - op_array = jnp.stack( - [ - jnp.array( - [map_parameter(x, p) for x in op_array[ic, :].ravel()] - ).reshape(op_array.shape[1:]) - for ic, p in enumerate(p_array) - ] + if op_numeric is not None and op_numeric.size: + op_array = jnp.where( + op_mask, + jax.vmap( + jax.vmap(jax.vmap(lambda ip: unscaled_parameters[ip])) + )(op_indices), + op_numeric, ) - - if np_array is not None and np_array.size: - np_array = jnp.stack( - [ - jnp.array( - [map_parameter(x, p) for x in np_array[ic, :].ravel()] - ).reshape(np_array.shape[1:]) - for ic, p in enumerate(p_array) - ] + else: + op_array = jnp.zeros((*self._ts_masks.shape[:2], 0)) + + if np_numeric is not None and np_numeric.size: + np_array = jnp.where( + np_mask, + jax.vmap( + jax.vmap(jax.vmap(lambda ip: unscaled_parameters[ip])) + )(np_indices), + np_numeric, ) + else: + np_array = jnp.zeros((*self._ts_masks.shape[:2], 0)) mask_reinit_array = jnp.stack( [ @@ -861,7 +901,13 @@ def run_simulations( """ p_array, mask_reinit_array, x_reinit_array, op_array, np_array = ( self._prepare_conditions( - simulation_conditions, self._ops, self._nps + simulation_conditions, + self._op_numeric, + self._op_mask, + self._op_indices, + self._np_numeric, + self._np_mask, + self._np_indices, ) ) return self.run_simulation(