From 7302481c32efdbd16e90b35089943ec106067661 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 30 Jan 2025 23:57:57 +0000 Subject: [PATCH] first working implementation --- python/sdist/amici/constants.py | 2 + python/sdist/amici/de_model.py | 16 ++ python/sdist/amici/de_model_components.py | 42 ++++ python/sdist/amici/jax/jax.template.py | 12 +- python/sdist/amici/jax/model.py | 113 +++++++--- python/sdist/amici/jax/ode_export.py | 13 +- python/sdist/amici/jax/petab.py | 205 +++++++++++++++--- python/sdist/amici/petab/sbml_import.py | 16 +- python/sdist/amici/sbml_import.py | 31 +++ .../benchmark-models/test_petab_benchmark.py | 24 +- 10 files changed, 394 insertions(+), 80 deletions(-) diff --git a/python/sdist/amici/constants.py b/python/sdist/amici/constants.py index 74b365889c..346dc1c9ab 100644 --- a/python/sdist/amici/constants.py +++ b/python/sdist/amici/constants.py @@ -34,3 +34,5 @@ class SymbolId(str, enum.Enum): SIGMAZ = "sigmaz" LLHZ = "llhz" LLHRZ = "llhrz" + NOISE_PARAMETER = "noise_parameter" + OBSERVABLE_PARAMETER = "observable_parameter" diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 8ad2e7a998..44475a5556 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -35,6 +35,8 @@ LogLikelihoodY, LogLikelihoodZ, LogLikelihoodRZ, + NoiseParameter, + ObservableParameter, Expression, ConservationLaw, Event, @@ -226,6 +228,8 @@ def __init__( self._log_likelihood_ys: list[LogLikelihoodY] = [] self._log_likelihood_zs: list[LogLikelihoodZ] = [] self._log_likelihood_rzs: list[LogLikelihoodRZ] = [] + self._noise_parameters: list[NoiseParameter] = [] + self._observable_parameters: list[ObservableParameter] = [] self._expressions: list[Expression] = [] self._conservation_laws: list[ConservationLaw] = [] self._events: list[Event] = [] @@ -273,6 +277,8 @@ def __init__( "sigmay": self.sigma_ys, "sigmaz": self.sigma_zs, "h": self.events, + "np": self.noise_parameters, + "op": self.observable_parameters, } self._value_prototype: dict[str, Callable] = { "p": self.parameters, @@ -385,6 +391,14 @@ def log_likelihood_rzs(self) -> list[LogLikelihoodRZ]: """Get all event observable regularization log likelihoods.""" return self._log_likelihood_rzs + def noise_parameters(self) -> list[NoiseParameter]: + """Get all noise parameters.""" + return self._noise_parameters + + def observable_parameters(self) -> list[ObservableParameter]: + """Get all observable parameters.""" + return self._observable_parameters + def is_ode(self) -> bool: """Check if model is ODE model.""" return len(self._algebraic_equations) == 0 @@ -565,6 +579,8 @@ def add_component( ConservationLaw, Event, EventObservable, + NoiseParameter, + ObservableParameter, }: raise ValueError(f"Invalid component type {type(component)}") diff --git a/python/sdist/amici/de_model_components.py b/python/sdist/amici/de_model_components.py index bc93f44b87..30624dbc9e 100644 --- a/python/sdist/amici/de_model_components.py +++ b/python/sdist/amici/de_model_components.py @@ -607,6 +607,46 @@ def __init__( super().__init__(identifier, name, value) +class NoiseParameter(ModelQuantity): + """ + A NoiseParameter is an input variable for the computation of ``sigma`` that can be specified in a data-point + specific manner, abbreviated by ``np``. Only used for jax models. + """ + + def __init__(self, identifier: sp.Symbol, name: str): + """ + Create a new Expression instance. + + :param identifier: + unique identifier of the NoiseParameter + + :param name: + individual name of the NoiseParameter (does not need to be + unique) + """ + super().__init__(identifier, name, 0.0) + + +class ObservableParameter(ModelQuantity): + """ + A NoiseParameter is an input variable for the computation of ``y`` that can be specified in a data-point specific + manner, abbreviated by ``op``. Only used for jax models. + """ + + def __init__(self, identifier: sp.Symbol, name: str): + """ + Create a new Expression instance. + + :param identifier: + unique identifier of the ObservableParameter + + :param name: + individual name of the ObservableParameter (does not need to be + unique) + """ + super().__init__(identifier, name, 0.0) + + class LogLikelihood(ModelQuantity): """ A LogLikelihood defines the distance between measurements and @@ -751,4 +791,6 @@ def get_trigger_time(self) -> sp.Float: SymbolId.LLHRZ: LogLikelihoodRZ, SymbolId.EXPRESSION: Expression, SymbolId.EVENT: Event, + SymbolId.NOISE_PARAMETER: NoiseParameter, + SymbolId.OBSERVABLE_PARAMETER: ObservableParameter, } diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index 5d5521d222..f78561fd55 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -64,28 +64,30 @@ def _tcl(self, x, p): return TPL_TOTAL_CL_RET - def _y(self, t, x, p, tcl): + def _y(self, t, x, p, tcl, op): TPL_X_SYMS = x TPL_P_SYMS = p TPL_W_SYMS = self._w(t, x, p, tcl) + TPL_OP_SYMS = op TPL_Y_EQ return TPL_Y_RET - def _sigmay(self, y, p): + def _sigmay(self, y, p, np): TPL_P_SYMS = p TPL_Y_SYMS = y + TPL_NP_SYMS = np TPL_SIGMAY_EQ return TPL_SIGMAY_RET - def _nllh(self, t, x, p, tcl, my, iy): - y = self._y(t, x, p, tcl) + def _nllh(self, t, x, p, tcl, my, iy, op, np): + y = self._y(t, x, p, tcl, op) TPL_Y_SYMS = y - TPL_SIGMAY_SYMS = self._sigmay(y, p) + TPL_SIGMAY_SYMS = self._sigmay(y, p, np) TPL_JY_EQ diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 616431dd94..da5b2f9e56 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -43,7 +43,7 @@ class JAXModel(eqx.Module): Path to the JAX model file. """ - MODEL_API_VERSION = "0.0.2" + MODEL_API_VERSION = "0.0.3" api_version: str jax_py_file: Path @@ -77,7 +77,7 @@ def _w( self, t: jt.Float[jt.Array, ""], x: jt.Float[jt.Array, "nxs"], - pk: jt.Float[jt.Array, "np"], + p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], ) -> jt.Float[jt.Array, "nw"]: """ @@ -85,7 +85,7 @@ def _w( :param t: time point :param x: state vector - :param pk: parameters + :param p: parameters :param tcl: total values for conservation laws :return: Expression values. @@ -93,11 +93,11 @@ def _w( ... @abstractmethod - def _x0(self, pk: jt.Float[jt.Array, "np"]) -> jt.Float[jt.Array, "nx"]: + def _x0(self, p: jt.Float[jt.Array, "np"]) -> jt.Float[jt.Array, "nx"]: """ Compute the initial state vector. - :param pk: parameters + :param p: parameters """ ... @@ -133,14 +133,14 @@ def _x_rdata( @abstractmethod def _tcl( - self, x: jt.Float[jt.Array, "nx"], pk: jt.Float[jt.Array, "np"] + self, x: jt.Float[jt.Array, "nx"], p: jt.Float[jt.Array, "np"] ) -> jt.Float[jt.Array, "ncl"]: """ Compute the total values for conservation laws. :param x: state vector - :param pk: + :param p: parameters :return: total values for conservation laws @@ -152,8 +152,9 @@ def _y( self, t: jt.Float[jt.Scalar, ""], x: jt.Float[jt.Array, "nxs"], - pk: jt.Float[jt.Array, "np"], + p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], + op: jt.Float[jt.Array, "ny"], ) -> jt.Float[jt.Array, "ny"]: """ Compute the observables. @@ -162,10 +163,12 @@ def _y( time point :param x: state vector - :param pk: + :param p: parameters :param tcl: total values for conservation laws + :param op: + observables parameters :return: observables """ @@ -173,15 +176,20 @@ def _y( @abstractmethod def _sigmay( - self, y: jt.Float[jt.Array, "ny"], pk: jt.Float[jt.Array, "np"] + self, + y: jt.Float[jt.Array, "ny"], + p: jt.Float[jt.Array, "np"], + np: jt.Float[jt.Array, "ny"], ) -> jt.Float[jt.Array, "ny"]: """ Compute the standard deviations of the observables. :param y: observables - :param pk: + :param p: parameters + :param np: + noise parameters :return: standard deviations of the observables """ @@ -192,10 +200,12 @@ def _nllh( self, t: jt.Float[jt.Scalar, ""], x: jt.Float[jt.Array, "nxs"], - pk: jt.Float[jt.Array, "np"], + p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], my: jt.Float[jt.Array, ""], iy: jt.Int[jt.Array, ""], + op: jt.Float[jt.Array, "ny"], + np: jt.Float[jt.Array, "ny"], ) -> jt.Float[jt.Scalar, ""]: """ Compute the negative log-likelihood of the observable for the specified observable index. @@ -204,7 +214,7 @@ def _nllh( time point :param x: state vector - :param pk: + :param p: parameters :param tcl: total values for conservation laws @@ -212,6 +222,10 @@ def _nllh( observed data :param iy: observable index + :param op: + observables parameters + :param np: + noise parameters :return: log-likelihood of the observable """ @@ -377,6 +391,8 @@ def _nllhs( tcl: jt.Float[jt.Array, "ncl"], mys: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], + nps: jt.Float[jt.Array, "nt *nnp"], ) -> jt.Float[jt.Array, "nt"]: """ Compute the negative log-likelihood for each observable. @@ -393,11 +409,15 @@ def _nllhs( observed data :param iys: observable indices + :param ops: + observables parameters + :param nps: + noise parameters :return: negative log-likelihoods of the observables """ - return jax.vmap(self._nllh, in_axes=(0, 0, None, None, 0, 0))( - ts, xs, p, tcl, mys, iys + return jax.vmap(self._nllh, in_axes=(0, 0, None, None, 0, 0, 0, 0))( + ts, xs, p, tcl, mys, iys, ops, nps ) def _ys( @@ -407,6 +427,7 @@ def _ys( p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], iys: jt.Float[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], ) -> jt.Int[jt.Array, "nt"]: """ Compute the observables. @@ -421,13 +442,17 @@ def _ys( total values for conservation laws :param iys: observable indices + :param ops: + observables parameters :return: observables """ return jax.vmap( - lambda t, x, p, tcl, iy: self._y(t, x, p, tcl).at[iy].get(), - in_axes=(0, 0, None, None, 0), - )(ts, xs, p, tcl, iys) + lambda t, x, p, tcl, iy, op: self._y(t, x, p, tcl, op) + .at[iy] + .get(), + in_axes=(0, 0, None, None, 0, 0), + )(ts, xs, p, tcl, iys, ops) def _sigmays( self, @@ -436,6 +461,8 @@ def _sigmays( p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], iys: jt.Int[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], + nps: jt.Float[jt.Array, "nt *nnp"], ): """ Compute the standard deviations of the observables. @@ -450,15 +477,21 @@ def _sigmays( total values for conservation laws :param iys: observable indices + :param ops: + observables parameters + :param nps: + noise parameters :return: standard deviations of the observables """ return jax.vmap( - lambda t, x, p, tcl, iy: self._sigmay(self._y(t, x, p, tcl), p) + lambda t, x, p, tcl, iy, op, np: self._sigmay( + self._y(t, x, p, tcl, op), p, np + ) .at[iy] .get(), - in_axes=(0, 0, None, None, 0), - )(ts, xs, p, tcl, iys) + in_axes=(0, 0, None, None, 0, 0, 0), + )(ts, xs, p, tcl, iys, ops, nps) @eqx.filter_jit def simulate_condition( @@ -469,6 +502,8 @@ def simulate_condition( my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], iy_trafos: jt.Int[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], + nps: jt.Float[jt.Array, "nt *nnp"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, @@ -497,13 +532,12 @@ def simulate_condition( observed data :param iys: indices of the observables according to ordering in :ivar observable_ids: - :param x_preeq: - initial state vector for pre-equilibration. If not provided, the initial state vector is computed using - :meth:`_x0`. - :param mask_reinit: - mask for re-initialization. If `True`, the corresponding state variable is re-initialized. - :param x_reinit: - re-initialized state vector. If not provided, the state vector is not re-initialized. + :param iy_trafos: + indices of transformations for observables + :param ops: + observables parameters + :param nps: + noise parameters :param solver: ODE solver :param controller: @@ -515,13 +549,20 @@ def simulate_condition( event function for steady state. See :func:`diffrax.steady_state_event` for details. :param max_steps: maximum number of solver steps - :param ret: - which output to return. See :class:`ReturnValue` for available options. + :param x_preeq: + initial state vector for pre-equilibration. If not provided, the initial state vector is computed using + :meth:`_x0`. + :param mask_reinit: + mask for re-initialization. If `True`, the corresponding state variable is re-initialized. + :param x_reinit: + re-initialized state vector. If not provided, the state vector is not re-initialized. :param ts_mask: mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. + :param ret: + which output to return. See :class:`ReturnValue` for available options. :return: - output according to `ret` and statistics + output according to `ret` and general results/statistics """ if x_preeq.shape[0]: x = x_preeq @@ -578,7 +619,7 @@ def simulate_condition( x = jnp.concatenate((x_dyn, x_posteq), axis=0) - nllhs = self._nllhs(ts, x, p, tcl, my, iys) + nllhs = self._nllhs(ts, x, p, tcl, my, iys, ops, nps) nllhs = jnp.where(ts_mask, nllhs, 0.0) llh = -jnp.sum(nllhs) @@ -598,9 +639,9 @@ def simulate_condition( elif ret == ReturnValue.x_solver: output = x elif ret == ReturnValue.y: - output = self._ys(ts, x, p, tcl, iys) + output = self._ys(ts, x, p, tcl, iys, ops) elif ret == ReturnValue.sigmay: - output = self._sigmays(ts, x, p, tcl, iys) + output = self._sigmays(ts, x, p, tcl, iys, ops, nps) elif ret == ReturnValue.x0: output = self._x_rdata(x[0, :], tcl) elif ret == ReturnValue.x0_solver: @@ -616,10 +657,10 @@ def simulate_condition( .at[iy_trafo] .get(), ) - ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos) + ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys, ops), iy_trafos) m_obj = obs_trafo(my, iy_trafos) if ret == ReturnValue.chi2: - sigma_obj = self._sigmays(ts, x, p, tcl, iys) + sigma_obj = self._sigmays(ts, x, p, tcl, iys, ops, nps) chi2 = jnp.square((ys_obj - m_obj) / sigma_obj) chi2 = jnp.where(ts_mask, chi2, 0.0) output = jnp.sum(chi2) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 4329195441..0ad7e48ed9 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -194,7 +194,18 @@ def _generate_jax_code(self) -> None: "x_rdata", "total_cl", ) - sym_names = ("p", "x", "tcl", "w", "my", "y", "sigmay", "x_rdata") + sym_names = ( + "p", + "np", + "op", + "x", + "tcl", + "w", + "my", + "y", + "sigmay", + "x_rdata", + ) indent = 8 diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index c47a00e1e3..287a7f0d8c 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -2,7 +2,7 @@ import shutil from numbers import Number -from collections.abc import Iterable +from collections.abc import Sized, Iterable from pathlib import Path from collections.abc import Callable @@ -88,6 +88,8 @@ class JAXProblem(eqx.Module): _iys: np.ndarray _iy_trafos: np.ndarray _ts_masks: np.ndarray + _ops: np.ndarray + _nps: np.ndarray _petab_measurement_indices: np.ndarray _petab_problem: petab.Problem @@ -113,6 +115,8 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): self._iy_trafos, self._ts_masks, self._petab_measurement_indices, + self._ops, + self._nps, ) = self._get_measurements(scs) self.parameters = self._get_nominal_parameter_values() @@ -192,6 +196,8 @@ def _get_measurements( np.ndarray, np.ndarray, np.ndarray, + np.ndarray, + np.ndarray, ]: """ Get measurements for the model based on the provided simulation conditions. @@ -211,6 +217,27 @@ def _get_measurements( """ measurements = dict() petab_indices = dict() + + n_pars = dict() + for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: + n_pars[col] = 0 + if col in self._petab_problem.measurement_df: + if self._petab_problem.measurement_df[col].dtype == np.float64: + n_pars[col] = 1 - int( + self._petab_problem.measurement_df[col].isna().all() + ) + else: + n_pars[col] = ( + self._petab_problem.measurement_df[col] + .str.split(";") + .apply( + lambda x: len(x) + if isinstance(x, Sized) + else 1 - int(pd.isna(x)) + ) + .max() + ) + for _, simulation_condition in simulation_conditions.iterrows(): query = " & ".join( [f"{k} == '{v}'" for k, v in simulation_condition.items()] @@ -249,42 +276,85 @@ def _get_measurements( else: iy_trafos = np.zeros_like(iys) + parameter_overrides = dict() + for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: + if col not in m or m[col].isna().all(): + mat = jnp.ones((len(m), n_pars[col])) + + elif m[col].dtype == np.float64: + mat = np.pad( + jnp.array(m[col].values), + ((0, 0), (0, n_pars[col])), + mode="edge", + ) + + else: + split_vals = m[col].str.split(";") + list_vals = split_vals.apply( + lambda x: x + if isinstance(x, list) + else [] + if pd.isna(x) + else [float(x)] + ) + vals = list_vals.apply( + lambda x: np.pad( + x, + (0, n_pars[col] - len(x)), + mode="constant", + constant_values=1.0, + ) + ) + mat = np.stack(vals) + + parameter_overrides[col] = mat + measurements[tuple(simulation_condition)] = ( ts_dyn, ts_posteq, my, iys, iy_trafos, + parameter_overrides[petab.OBSERVABLE_PARAMETERS], + parameter_overrides[petab.NOISE_PARAMETERS], ) petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) # compute maximum lengths n_ts_dyn = max( - len(ts_dyn) for ts_dyn, _, _, _, _ in measurements.values() + len(ts_dyn) for ts_dyn, _, _, _, _, _, _ in measurements.values() ) n_ts_posteq = max( - len(ts_posteq) for _, ts_posteq, _, _, _ in measurements.values() + len(ts_posteq) + for _, ts_posteq, _, _, _, _, _ in measurements.values() ) # pad with last value and stack ts_dyn = np.stack( [ np.pad(x, (0, n_ts_dyn - len(x)), mode="edge") - for x, _, _, _, _ in measurements.values() + for x, _, _, _, _, _, _ in measurements.values() ] ) ts_posteq = np.stack( [ np.pad(x, (0, n_ts_posteq - len(x)), mode="edge") - for _, x, _, _, _ in measurements.values() + for _, x, _, _, _, _, _ in measurements.values() ] ) def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): + # only pad first axis + pad_width_dyn = tuple( + [(0, n_ts_dyn - len(x_dyn))] + [(0, 0)] * (x_dyn.ndim - 1) + ) + pad_width_peq = tuple( + [(0, n_ts_posteq - len(x_peq))] + [(0, 0)] * (x_peq.ndim - 1) + ) return np.concatenate( ( - np.pad(x_dyn, (0, n_ts_dyn - len(x_dyn)), mode="edge"), - np.pad(x_peq, (0, n_ts_posteq - len(x_peq)), mode="edge"), + np.pad(x_dyn, pad_width_dyn, mode="edge"), + np.pad(x_peq, pad_width_peq, mode="edge"), ) ) @@ -293,7 +363,7 @@ def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): pad_measurement( x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq ) - for tdyn, tpeq, x, _, _ in measurements.values() + for tdyn, tpeq, x, _, _, _, _ in measurements.values() ] ) iys = np.stack( @@ -301,7 +371,7 @@ def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): pad_measurement( x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq ) - for tdyn, tpeq, _, x, _ in measurements.values() + for tdyn, tpeq, _, x, _, _, _ in measurements.values() ] ) iy_trafos = np.stack( @@ -309,7 +379,23 @@ def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): pad_measurement( x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq ) - for tdyn, tpeq, _, _, x in measurements.values() + for tdyn, tpeq, _, _, x, _, _ in measurements.values() + ] + ) + ops = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, _, _, _, x, _ in measurements.values() + ] + ) + nps = np.stack( + [ + pad_measurement( + x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq + ) + for tdyn, tpeq, _, _, _, _, x in measurements.values() ] ) ts_masks = np.stack( @@ -322,21 +408,34 @@ def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): ), ) ) - for tdyn, tpeq, _, _, _ in measurements.values() + for tdyn, tpeq, _, _, _, _, _ in measurements.values() ] ).astype(bool) petab_indices = np.stack( [ pad_measurement( - idx[: len(tdyn)], idx[len(tdyn) :], n_ts_dyn, n_ts_posteq + np.array(idx[: len(tdyn)]), + np.array(idx[len(tdyn) :]), + n_ts_dyn, + n_ts_posteq, ) - for (tdyn, tpeq, _, _, _), idx in zip( + for (tdyn, tpeq, _, _, _, _, _), idx in zip( measurements.values(), petab_indices.values() ) ] ) - return ts_dyn, ts_posteq, my, iys, iy_trafos, ts_masks, petab_indices + return ( + ts_dyn, + ts_posteq, + my, + iys, + iy_trafos, + ts_masks, + petab_indices, + ops, + nps, + ) def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: simulation_conditions = ( @@ -549,35 +648,77 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": return eqx.tree_at(lambda p: p.parameters, self, p) def _prepare_conditions( - self, conditions: Iterable[str] + self, + conditions: list[tuple[str, ...]], + op_array: np.ndarray, + np_array: np.ndarray, ) -> tuple[ jt.Float[jt.Array, "np"], # noqa: F821 jt.Bool[jt.Array, "nx"], # noqa: F821 jt.Float[jt.Array, "nx"], # noqa: F821 + jt.Float[jt.Array, "nt *nop"], # noqa: F821 + jt.Float[jt.Array, "nt *nnp"], # noqa: F821 ]: """ Prepare conditions for simulation. :param conditions: Simulation conditions to prepare. + :param ts_mask: + Time point mask to use for padding. :return: - Tuple of parameter arrays, reinitialisation masks and reinitialisation values. + Tuple of parameter arrays, reinitialisation masks and reinitialisation values, observable parameters and + noise parameters. """ - p_array = jnp.stack([self.load_parameters(sc) for sc in conditions]) + p_array = jnp.stack([self.load_parameters(sc[0]) for sc in conditions]) + + def map_parameter(x, p): + if isinstance(x, Number): + return x + if x in self.model.parameter_ids: + return p[self.model.parameter_ids.index(x)] + if x in self.parameter_ids: + return jax_unscale( + self.get_petab_parameter_by_id(x), + self._petab_problem.parameter_df.loc[ + x, petab.PARAMETER_SCALE + ], + ) + return float(x) + + if op_array.size: + op_array = jnp.stack( + [ + jnp.array( + [map_parameter(x, p) for x in op_array[ic, :].ravel()] + ).reshape(op_array.shape[1:]) + for ic, p in enumerate(p_array) + ] + ) + + if np_array.size: + np_array = jnp.stack( + [ + jnp.array( + [map_parameter(x, p) for x in np_array[ic, :].ravel()] + ).reshape(np_array.shape[1:]) + for ic, p in enumerate(p_array) + ] + ) mask_reinit_array = jnp.stack( [ - self.load_reinitialisation(sc, p)[0] + self.load_reinitialisation(sc[0], p)[0] for sc, p in zip(conditions, p_array) ] ) x_reinit_array = jnp.stack( [ - self.load_reinitialisation(sc, p)[1] + self.load_reinitialisation(sc[0], p)[1] for sc, p in zip(conditions, p_array) ] ) - return p_array, mask_reinit_array, x_reinit_array + return p_array, mask_reinit_array, x_reinit_array, op_array, np_array @eqx.filter_vmap( in_axes={ @@ -593,6 +734,8 @@ def run_simulation( my: np.ndarray, iys: np.ndarray, iy_trafos: np.ndarray, + ops: jt.Float[jt.Array, "nt *nop"], # noqa: F821, F722 + nps: jt.Float[jt.Array, "nt *nnp"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, @@ -620,6 +763,10 @@ def run_simulation( (Padded) observable indices :param iy_trafos: (Padded) observable transformations indices + :param ops: + (Padded) observable parameters + :param nps: + (Padded) noise parameters :param mask_reinit: Mask for states that need reinitialisation :param x_reinit: @@ -650,6 +797,8 @@ def run_simulation( my=jax.lax.stop_gradient(jnp.array(my)), iys=jax.lax.stop_gradient(jnp.array(iys)), iy_trafos=jax.lax.stop_gradient(jnp.array(iy_trafos)), + nps=nps, + ops=ops, x_preeq=x_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, @@ -666,7 +815,7 @@ def run_simulation( def run_simulations( self, - simulation_conditions: list[str], + simulation_conditions: list[tuple[str, ...]], preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, @@ -699,8 +848,10 @@ def run_simulations( Output value and condition specific results and statistics. Results and statistics are returned as a dict with arrays with the leading dimension corresponding to the simulation conditions. """ - p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions( - simulation_conditions + p_array, mask_reinit_array, x_reinit_array, op_array, np_array = ( + self._prepare_conditions( + simulation_conditions, self._ops, self._nps + ) ) return self.run_simulation( p_array, @@ -709,6 +860,8 @@ def run_simulations( self._my, self._iys, self._iy_trafos, + op_array, + np_array, mask_reinit_array, x_reinit_array, solver, @@ -779,8 +932,8 @@ def run_preequilibrations( ], max_steps: jnp.int_, ): - p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions( - simulation_conditions + p_array, mask_reinit_array, x_reinit_array, _, _ = ( + self._prepare_conditions(simulation_conditions) ) return self.run_preequilibration( p_array, @@ -869,7 +1022,7 @@ def run_simulations( ] ) output, results = problem.run_simulations( - dynamic_conditions, + simulation_conditions, preeq_array, solver, controller, diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index e605a9cc80..dabc8eda20 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -1,4 +1,6 @@ import logging +import re + import math import os import tempfile @@ -345,16 +347,24 @@ def import_model_sbml( f"({len(sigmas)}) do not match." ) - _workaround_observable_parameters( - observables, sigmas, sbml_model, output_parameter_defaults - ) if not jax: + _workaround_observable_parameters( + observables, sigmas, sbml_model, output_parameter_defaults + ) fixed_parameters = _workaround_initial_states( petab_problem=petab_problem, sbml_model=sbml_model, **kwargs, ) else: + sigmas = { + obs: re.sub(f"(noiseParameter[0-9]+)_{obs}", r"\1", sigma) + for obs, sigma in sigmas.items() + } + for obs, obs_def in observables.items(): + obs_def["formula"] = re.sub( + f"(observableParameter[0-9]+)_{obs}", r"\1", obs_def["formula"] + ) fixed_parameters = [] fixed_parameters.extend( diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 557ad02d0f..eed91f200b 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -1985,6 +1985,37 @@ def _process_observables( self.symbols[SymbolId.OBSERVABLE], "eventObservable" ) + noise_pars = list( + { + name + for sigma in sigmas.values() + for symbol in sp.sympify(sigma).free_symbols + if (name := str(symbol)).startswith("noiseParameter") + } + ) + self.symbols[SymbolId.NOISE_PARAMETER] = { + symbol_with_assumptions(np): {"name": np} + for np in sorted( + noise_pars, key=lambda x: int(x.replace("noiseParameter", "")) + ) + } + + observable_pars = list( + { + name + for obs in observables.values() + for symbol in sp.sympify(obs["formula"]).free_symbols + if (name := str(symbol)).startswith("observableParameter") + } + ) + self.symbols[SymbolId.OBSERVABLE_PARAMETER] = { + symbol_with_assumptions(op): {"name": op} + for op in sorted( + observable_pars, + key=lambda x: int(x.replace("observableParameter", "")), + ) + } + self._process_log_likelihood(sigmas, noise_distributions) @log_execution_time("processing SBML event observables", logger) diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 2f3fbb433a..a34f14dd29 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -5,6 +5,7 @@ for a subset of the benchmark problems. """ +import copy from functools import partial from pathlib import Path @@ -245,17 +246,18 @@ def benchmark_problem(request): the benchmark problem collection.""" problem_id = request.param petab_problem = benchmark_models_petab.get_problem(problem_id) + flat_petab_problem = copy.deepcopy(petab_problem) if measurement_table_has_timepoint_specific_mappings( petab_problem.measurement_df, ): - petab.flatten_timepoint_specific_output_overrides(petab_problem) + petab.flatten_timepoint_specific_output_overrides(flat_petab_problem) # Setup AMICI objects. amici_model = import_petab_problem( - petab_problem, + flat_petab_problem, model_output_dir=benchmark_outdir / problem_id, ) - return problem_id, petab_problem, amici_model + return problem_id, flat_petab_problem, petab_problem, amici_model @pytest.mark.filterwarnings( @@ -272,7 +274,9 @@ def test_jax_llh(benchmark_problem): jax.config.update("jax_enable_x64", True) from beartype import beartype - problem_id, petab_problem, amici_model = benchmark_problem + problem_id, flat_petab_problem, petab_problem, amici_model = ( + benchmark_problem + ) amici_solver = amici_model.getSolver() cur_settings = settings[problem_id] @@ -282,7 +286,7 @@ def test_jax_llh(benchmark_problem): simulate_amici = partial( simulate_petab, - petab_problem=petab_problem, + petab_problem=flat_petab_problem, amici_model=amici_model, solver=amici_solver, scaled_parameters=True, @@ -294,7 +298,7 @@ def test_jax_llh(benchmark_problem): problem_parameters = None if problem_id in problems_for_gradient_check: - point = petab_problem.x_nominal_free_scaled + point = flat_petab_problem.x_nominal_free_scaled for _ in range(20): amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint) amici_solver.setSensitivityOrder(amici.SensitivityOrder.first) @@ -306,7 +310,9 @@ def test_jax_llh(benchmark_problem): ) point += point_noise # avoid small gradients at nominal value - problem_parameters = dict(zip(petab_problem.x_free_ids, point)) + problem_parameters = dict( + zip(flat_petab_problem.x_free_ids, point) + ) r_amici = simulate_amici( problem_parameters=problem_parameters, @@ -372,7 +378,7 @@ def test_nominal_parameters_llh(benchmark_problem): Also check that the simulation time is within the reference range. """ - problem_id, petab_problem, amici_model = benchmark_problem + problem_id, petab_problem, _, amici_model = benchmark_problem if problem_id not in problems_for_llh_check: pytest.skip("Excluded from log-likelihood check.") @@ -526,7 +532,7 @@ def test_nominal_parameters_llh(benchmark_problem): def test_benchmark_gradient( benchmark_problem, scale, sensitivity_method, request ): - problem_id, petab_problem, amici_model = benchmark_problem + problem_id, petab_problem, _, amici_model = benchmark_problem if problem_id not in problems_for_gradient_check: pytest.skip("Excluded from gradient check.")