Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Jan 31, 2025
1 parent 7302481 commit 4673e29
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 37 deletions.
57 changes: 34 additions & 23 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""PEtab wrappers for JAX models.""" ""

import copy
import shutil
from numbers import Number
from collections.abc import Sized, Iterable
Expand Down Expand Up @@ -173,13 +174,22 @@ def _get_parameter_mappings(
Dictionary mapping simulation conditions to parameter mappings.
"""
scs = list(set(simulation_conditions.values.flatten()))
petab_problem = copy.deepcopy(self._petab_problem)

Check warning on line 177 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L177

Added line #L177 was not covered by tests
# remove observable and noise parameters from measurement dataframe as we are mapping them elsewhere
petab_problem.measurement_df.drop(

Check warning on line 179 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L179

Added line #L179 was not covered by tests
columns=[petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS],
inplace=True,
errors="ignore",
)
mappings = create_parameter_mapping(
petab_problem=self._petab_problem,
petab_problem=petab_problem,
simulation_conditions=[
{petab.SIMULATION_CONDITION_ID: sc} for sc in scs
],
scaled_parameters=False,
allow_timepoint_specific_numeric_noise_parameters=True,
)
# fill in dummy variables
for mapping in mappings:
for sim_var, value in mapping.map_sim_var.items():
if isinstance(value, Number) and not np.isfinite(value):
Expand Down Expand Up @@ -222,7 +232,9 @@ def _get_measurements(
for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]:
n_pars[col] = 0
if col in self._petab_problem.measurement_df:
if self._petab_problem.measurement_df[col].dtype == np.float64:
if np.issubdtype(

Check warning on line 235 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L231-L235

Added lines #L231 - L235 were not covered by tests
self._petab_problem.measurement_df[col].dtype, np.number
):
n_pars[col] = 1 - int(

Check warning on line 238 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L238

Added line #L238 was not covered by tests
self._petab_problem.measurement_df[col].isna().all()
)
Expand Down Expand Up @@ -281,13 +293,8 @@ def _get_measurements(
if col not in m or m[col].isna().all():
mat = jnp.ones((len(m), n_pars[col]))

Check warning on line 294 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L291-L294

Added lines #L291 - L294 were not covered by tests

elif m[col].dtype == np.float64:
mat = np.pad(
jnp.array(m[col].values),
((0, 0), (0, n_pars[col])),
mode="edge",
)

elif np.issubdtype(m[col].dtype, np.number):
mat = np.expand_dims(m[col].values, axis=1)

Check warning on line 297 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L296-L297

Added lines #L296 - L297 were not covered by tests
else:
split_vals = m[col].str.split(";")
list_vals = split_vals.apply(

Check warning on line 300 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L299-L300

Added lines #L299 - L300 were not covered by tests
Expand Down Expand Up @@ -649,9 +656,9 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":

def _prepare_conditions(
self,
conditions: list[tuple[str, ...]],
op_array: np.ndarray,
np_array: np.ndarray,
conditions: list[str],
op_array: np.ndarray | None,
np_array: np.ndarray | None,
) -> tuple[
jt.Float[jt.Array, "np"], # noqa: F821
jt.Bool[jt.Array, "nx"], # noqa: F821
Expand All @@ -670,11 +677,9 @@ def _prepare_conditions(
Tuple of parameter arrays, reinitialisation masks and reinitialisation values, observable parameters and
noise parameters.
"""
p_array = jnp.stack([self.load_parameters(sc[0]) for sc in conditions])
p_array = jnp.stack([self.load_parameters(sc) for sc in conditions])

def map_parameter(x, p):
if isinstance(x, Number):
return x
if x in self.model.parameter_ids:
return p[self.model.parameter_ids.index(x)]
if x in self.parameter_ids:
Expand All @@ -684,9 +689,15 @@ def map_parameter(x, p):
x, petab.PARAMETER_SCALE
],
)
return float(x)
if x in self._petab_problem.parameter_df.index:
return self._petab_problem.parameter_df.loc[

Check warning on line 693 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L692-L693

Added lines #L692 - L693 were not covered by tests
x, petab.NOMINAL_VALUE
]
if isinstance(x, str):
return float(x)
return x

Check warning on line 698 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L696-L698

Added lines #L696 - L698 were not covered by tests

if op_array.size:
if op_array is not None and op_array.size:
op_array = jnp.stack(

Check warning on line 701 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L700-L701

Added lines #L700 - L701 were not covered by tests
[
jnp.array(
Expand All @@ -696,7 +707,7 @@ def map_parameter(x, p):
]
)

if np_array.size:
if np_array is not None and np_array.size:
np_array = jnp.stack(

Check warning on line 711 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L710-L711

Added lines #L710 - L711 were not covered by tests
[
jnp.array(
Expand All @@ -708,13 +719,13 @@ def map_parameter(x, p):

mask_reinit_array = jnp.stack(
[
self.load_reinitialisation(sc[0], p)[0]
self.load_reinitialisation(sc, p)[0]
for sc, p in zip(conditions, p_array)
]
)
x_reinit_array = jnp.stack(
[
self.load_reinitialisation(sc[0], p)[1]
self.load_reinitialisation(sc, p)[1]
for sc, p in zip(conditions, p_array)
]
)
Expand Down Expand Up @@ -815,7 +826,7 @@ def run_simulation(

def run_simulations(
self,
simulation_conditions: list[tuple[str, ...]],
simulation_conditions: list[str],
preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
Expand Down Expand Up @@ -933,7 +944,7 @@ def run_preequilibrations(
max_steps: jnp.int_,
):
p_array, mask_reinit_array, x_reinit_array, _, _ = (

Check warning on line 946 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L946

Added line #L946 was not covered by tests
self._prepare_conditions(simulation_conditions)
self._prepare_conditions(simulation_conditions, None, None)
)
return self.run_preequilibration(
p_array,
Expand Down Expand Up @@ -1022,7 +1033,7 @@ def run_simulations(
]
)
output, results = problem.run_simulations(
simulation_conditions,
dynamic_conditions,
preeq_array,
solver,
controller,
Expand Down
40 changes: 26 additions & 14 deletions python/sdist/amici/petab/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _workaround_initial_states(


def _workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults
observables, sigmas, sbml_model, output_parameter_defaults, jax=False
):
# TODO: adding extra output parameters is currently not supported,
# so we add any output parameters to the SBML model.
Expand All @@ -167,7 +167,25 @@ def _workaround_observable_parameters(
)
for free_sym in free_syms:
sym = str(free_sym)
if (
if jax and (m := re.match(r"(noiseParameter\d+)_(\w+)", sym)):
# group1 is the noise parameter, group2 is the observable, don't add to sbml but replace with generic
# noise parameter
sigmas[m.group(2)] = str(

Check warning on line 173 in python/sdist/amici/petab/sbml_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/sbml_import.py#L173

Added line #L173 was not covered by tests
sp.sympify(sigmas[m.group(2)], locals=_clash).subs(
free_sym, sp.Symbol(m.group(1))
)
)
elif jax and (
m := re.match(r"(observableParameter\d+)_(\w+)", sym)
):
# group1 is the observable parameter, group2 is the observable, don't add to sbml but replace with generic
# observable parameter
observables[m.group(2)]["formula"] = str(

Check warning on line 183 in python/sdist/amici/petab/sbml_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/sbml_import.py#L183

Added line #L183 was not covered by tests
sp.sympify(
observables[m.group(2)]["formula"], locals=_clash
).subs(free_sym, sp.Symbol(m.group(1)))
)
elif (
sbml_model.getElementBySId(sym) is None
and sym != "time"
and sym not in observables
Expand Down Expand Up @@ -319,7 +337,8 @@ def import_model_sbml(
)
)
if (
petab_problem.measurement_df is not None
not jax
and petab_problem.measurement_df is not None
and petab.lint.measurement_table_has_timepoint_specific_mappings(
petab_problem.measurement_df,
allow_scalar_numeric_noise_parameters=allow_n_noise_pars,
Expand Down Expand Up @@ -347,24 +366,17 @@ def import_model_sbml(
f"({len(sigmas)}) do not match."
)

_workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults, jax=jax
)

if not jax:
_workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults
)
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
)
else:
sigmas = {
obs: re.sub(f"(noiseParameter[0-9]+)_{obs}", r"\1", sigma)
for obs, sigma in sigmas.items()
}
for obs, obs_def in observables.items():
obs_def["formula"] = re.sub(
f"(observableParameter[0-9]+)_{obs}", r"\1", obs_def["formula"]
)
fixed_parameters = []

fixed_parameters.extend(
Expand Down

0 comments on commit 4673e29

Please sign in to comment.