Skip to content

Commit

Permalink
Integrate the entire dynamics for higher-order ODEs (#1139)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu authored Nov 27, 2024
1 parent 31576af commit 806ca13
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 6 deletions.
16 changes: 16 additions & 0 deletions doc/nestml_language/nestml_language_concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,22 @@ Integrating the ODEs needs to be triggered explicitly inside the ``update`` bloc

The ``integrate_odes()`` function numerically integrates the differential equations defined in the ``equations`` block. Integrating the ODEs from one timestep to the next has to be explicitly carried out in the model by calling the ``integrate_odes()`` function. If no parameters are given, all ODEs in the model are integrated. Integration can be limited to a given set of ODEs by giving their left-hand side state variables as parameters to the function, for example ``integrate_odes(V_m, I_ahp)`` if ODEs exist for the variables ``V_m`` and ``I_ahp``. In this example, these variables are integrated simultaneously (as one single system of equations). This is different from calling ``integrate_odes(V_m)`` and then ``integrate_odes(I_ahp)`` in that the second call would use the already-updated values from the first call. Variables not included in the call to ``integrate_odes()`` are assumed to remain constant (both inside the numeric solver stepping function as well as from before to after the call).

In case of higher-order ODEs of the form ``F(x'', x', x) = 0``, the solution ``x(t)`` is obtained by just providing the variable ``x`` to the ``integrate_odes`` function. For example,

.. code-block:: nestml
state:
x real = 0
x' ms**-1 = 0 * ms**-1
equations:
x'' = - 2 * x' / ms - x / ms**2
update:
integrate_odes(x)
Here, ``integrate_odes(x)`` integrates the entire dynamics of ``x(t)``, in this case, ``x`` and ``x'``.

Note that the dynamical equations that correspond to convolutions are always updated, regardless of whether ``integrate_odes()`` is called. The state variables affected by incoming events are updated at the end of each timestep, that is, within one timestep, the state as observed by statements in the ``update`` block will be those at :math:`t^-`, i.e. "just before" it has been updated due to the events. See also :ref:`Integrating spiking input` and :ref:`Integration order`.

ODEs that can be solved analytically are integrated to machine precision from one timestep to the next using the propagators obtained from `ODE-toolbox <https://ode-toolbox.readthedocs.io/>`_. In case a numerical solver is used (such as Runge-Kutta or forward Euler), the same ODEs are also evaluated numerically by the numerical solver to allow more precise values for analytically solvable ODEs *within* a timestep. In this way, the long-term dynamics obeys the analytic (more exact) equations, while the short-term (within one timestep) dynamics is evaluated to the precision of the numerical integrator.
Expand Down
3 changes: 3 additions & 0 deletions pynestml/cocos/co_co_integrate_odes_params_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def visit_function_call(self, node):
if symbol_var is None or not symbol_var.is_state():
code, message = Messages.get_integrate_odes_wrong_arg(str(arg))
Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR)
elif symbol_var.is_state() and arg.get_variable().get_differential_order() > 0:
code, message = Messages.get_integrate_odes_arg_higher_order(str(arg))
Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables_: %}
{%- for variable_name in analytic_state_variables_ %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
Expand Down
14 changes: 14 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,21 @@ def filter_variables_list(cls, variables_list, variables_to_filter_by):
for var in variables_list:
if var in variables_to_filter_by:
ret.append(var)
# Add higher order variables of var if not already in the filter list
ret.extend(cls.get_higher_order_variables(var, variables_list, variables_to_filter_by))
return ret

@classmethod
def get_higher_order_variables(cls, var, variables_list, variables_to_filter_by) -> List[str]:
"""
Returns a list of higher order state variables of ``var`` from the ``variables_list`` that are not already present in ``variables_to_filter_by``.
"""
ret = []
for v in variables_list:
order = v.count('__d')
if order > 0:
if v.split("__d")[0] == var and v not in variables_to_filter_by:
ret.append(v)
return ret

@classmethod
Expand Down
6 changes: 6 additions & 0 deletions pynestml/utils/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class MessageCode(Enum):
EXPONENT_MUST_BE_INTEGER = 114
EMIT_SPIKE_OUTPUT_PORT_TYPE_DIFFERS = 115
CONTINUOUS_OUTPUT_PORT_MAY_NOT_HAVE_ATTRIBUTES = 116
INTEGRATE_ODES_ARG_HIGHER_ORDER = 117


class Messages:
Expand Down Expand Up @@ -1300,6 +1301,11 @@ def get_integrate_odes_wrong_arg(cls, arg: str) -> Tuple[MessageCode, str]:
message = "Parameter provided to integrate_odes() function is not a state variable: '" + arg + "'"
return MessageCode.INTEGRATE_ODES_WRONG_ARG, message

@classmethod
def get_integrate_odes_arg_higher_order(cls, arg: str) -> Tuple[MessageCode, str]:
message = "Parameter provided to integrate_odes() function is a state variable of higher order: '" + arg + "'"
return MessageCode.INTEGRATE_ODES_ARG_HIGHER_ORDER, message

@classmethod
def get_mechs_dictionary_info(cls, chan_info, syns_info, conc_info, con_in_info) -> Tuple[MessageCode, str]:
message = ""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
alpha_function_2nd_order_ode_neuron.nestml
##########################################

Tests that for a system of higher-oder ODEs of the form F(x'',x',x)=0, integrate_odes(x) includes the integration of all the higher-order variables involved of the system.

Copyright statement
+++++++++++++++++++

This file is part of NEST.

Copyright (C) 2004 The NEST Initiative

NEST is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 2 of the License, or
(at your option) any later version.

NEST is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
model alpha_function_2nd_order_ode_neuron:
state:
x real = 0
x' ms**-1 = 0 * ms**-1
y real = 0

input:
fX <- spike

equations:
x'' = - 2 * x' / ms - x / ms**2
y' = (-y + 42) / s

update:
integrate_odes(x, y)

onReceive(fX):
x' += e*fX * s / ms
54 changes: 49 additions & 5 deletions tests/nest_tests/test_integrate_odes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def setUp(self):
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_nonlinear_test.nestml")))],
os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),],
logging_level="INFO",
suffix="_nestml")

Expand Down Expand Up @@ -131,11 +133,9 @@ def test_integrate_odes(self):
pass

# create the network
spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
nest.Connect(neuron, spikedet)

# simulate
nest.Simulate(sim_time)
Expand Down Expand Up @@ -182,11 +182,9 @@ def test_integrate_odes_nonlinear(self):
pass

# create the network
spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_nonlinear_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
nest.Connect(neuron, spikedet)

# simulate
nest.Simulate(sim_time)
Expand Down Expand Up @@ -232,3 +230,49 @@ def test_integrate_odes_params2(self):
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")

assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2

def test_integrate_odes_higher_order(self):
r"""
Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x.
"""
resolution = 0.1
simtime = 15.
nest.set_verbosity("M_ALL")
nest.ResetKernel()
nest.SetKernelStatus({"resolution": resolution})
try:
nest.Install("nestmlmodule")
except Exception:
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass

n = nest.Create("alpha_function_2nd_order_ode_neuron_nestml")
sgX = nest.Create("spike_generator", params={"spike_times": [10.]})
nest.Connect(sgX, n, syn_spec={"weight": 1., "delay": resolution})

mm = nest.Create("multimeter", params={"interval": resolution, "record_from": ["x", "y"]})
nest.Connect(mm, n)

nest.Simulate(simtime)
times = mm.get()["events"]["times"]
x_actual = mm.get()["events"]["x"]
y_actual = mm.get()["events"]["y"]

if TEST_PLOTS:
fig, ax = plt.subplots(nrows=2)
ax1, ax2 = ax

ax2.plot(times, x_actual, label="x")
ax1.plot(times, y_actual, label="y")

for _ax in ax:
_ax.grid(which="major", axis="both")
_ax.grid(which="minor", axis="x", linestyle=":", alpha=.4)
_ax.set_xlim(0., simtime)
_ax.legend()

fig.savefig("/tmp/test_integrate_odes_higher_order.png", dpi=300)

# verify
np.testing.assert_allclose(x_actual[-1], 0.10737970490959549)
np.testing.assert_allclose(y_actual[-1], 0.6211608596446752)

0 comments on commit 806ca13

Please sign in to comment.