Skip to content

Commit

Permalink
first working implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Jan 30, 2025
1 parent 64de63a commit 7302481
Show file tree
Hide file tree
Showing 10 changed files with 394 additions and 80 deletions.
2 changes: 2 additions & 0 deletions python/sdist/amici/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ class SymbolId(str, enum.Enum):
SIGMAZ = "sigmaz"
LLHZ = "llhz"
LLHRZ = "llhrz"
NOISE_PARAMETER = "noise_parameter"
OBSERVABLE_PARAMETER = "observable_parameter"
16 changes: 16 additions & 0 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
LogLikelihoodY,
LogLikelihoodZ,
LogLikelihoodRZ,
NoiseParameter,
ObservableParameter,
Expression,
ConservationLaw,
Event,
Expand Down Expand Up @@ -226,6 +228,8 @@ def __init__(
self._log_likelihood_ys: list[LogLikelihoodY] = []
self._log_likelihood_zs: list[LogLikelihoodZ] = []
self._log_likelihood_rzs: list[LogLikelihoodRZ] = []
self._noise_parameters: list[NoiseParameter] = []
self._observable_parameters: list[ObservableParameter] = []
self._expressions: list[Expression] = []
self._conservation_laws: list[ConservationLaw] = []
self._events: list[Event] = []
Expand Down Expand Up @@ -273,6 +277,8 @@ def __init__(
"sigmay": self.sigma_ys,
"sigmaz": self.sigma_zs,
"h": self.events,
"np": self.noise_parameters,
"op": self.observable_parameters,
}
self._value_prototype: dict[str, Callable] = {
"p": self.parameters,
Expand Down Expand Up @@ -385,6 +391,14 @@ def log_likelihood_rzs(self) -> list[LogLikelihoodRZ]:
"""Get all event observable regularization log likelihoods."""
return self._log_likelihood_rzs

def noise_parameters(self) -> list[NoiseParameter]:
"""Get all noise parameters."""
return self._noise_parameters

def observable_parameters(self) -> list[ObservableParameter]:
"""Get all observable parameters."""
return self._observable_parameters

def is_ode(self) -> bool:
"""Check if model is ODE model."""
return len(self._algebraic_equations) == 0
Expand Down Expand Up @@ -565,6 +579,8 @@ def add_component(
ConservationLaw,
Event,
EventObservable,
NoiseParameter,
ObservableParameter,
}:
raise ValueError(f"Invalid component type {type(component)}")

Expand Down
42 changes: 42 additions & 0 deletions python/sdist/amici/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,46 @@ def __init__(
super().__init__(identifier, name, value)


class NoiseParameter(ModelQuantity):
"""
A NoiseParameter is an input variable for the computation of ``sigma`` that can be specified in a data-point
specific manner, abbreviated by ``np``. Only used for jax models.
"""

def __init__(self, identifier: sp.Symbol, name: str):
"""
Create a new Expression instance.
:param identifier:
unique identifier of the NoiseParameter
:param name:
individual name of the NoiseParameter (does not need to be
unique)
"""
super().__init__(identifier, name, 0.0)


class ObservableParameter(ModelQuantity):
"""
A NoiseParameter is an input variable for the computation of ``y`` that can be specified in a data-point specific
manner, abbreviated by ``op``. Only used for jax models.
"""

def __init__(self, identifier: sp.Symbol, name: str):
"""
Create a new Expression instance.
:param identifier:
unique identifier of the ObservableParameter
:param name:
individual name of the ObservableParameter (does not need to be
unique)
"""
super().__init__(identifier, name, 0.0)


class LogLikelihood(ModelQuantity):
"""
A LogLikelihood defines the distance between measurements and
Expand Down Expand Up @@ -751,4 +791,6 @@ def get_trigger_time(self) -> sp.Float:
SymbolId.LLHRZ: LogLikelihoodRZ,
SymbolId.EXPRESSION: Expression,
SymbolId.EVENT: Event,
SymbolId.NOISE_PARAMETER: NoiseParameter,
SymbolId.OBSERVABLE_PARAMETER: ObservableParameter,
}
12 changes: 7 additions & 5 deletions python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,30 @@ def _tcl(self, x, p):

return TPL_TOTAL_CL_RET

def _y(self, t, x, p, tcl):
def _y(self, t, x, p, tcl, op):
TPL_X_SYMS = x
TPL_P_SYMS = p
TPL_W_SYMS = self._w(t, x, p, tcl)
TPL_OP_SYMS = op

TPL_Y_EQ

return TPL_Y_RET

def _sigmay(self, y, p):
def _sigmay(self, y, p, np):
TPL_P_SYMS = p

TPL_Y_SYMS = y
TPL_NP_SYMS = np

TPL_SIGMAY_EQ

return TPL_SIGMAY_RET

def _nllh(self, t, x, p, tcl, my, iy):
y = self._y(t, x, p, tcl)
def _nllh(self, t, x, p, tcl, my, iy, op, np):
y = self._y(t, x, p, tcl, op)
TPL_Y_SYMS = y
TPL_SIGMAY_SYMS = self._sigmay(y, p)
TPL_SIGMAY_SYMS = self._sigmay(y, p, np)

TPL_JY_EQ

Expand Down
Loading

0 comments on commit 7302481

Please sign in to comment.