From 806ca130bc4f9837403d93ac1989f2b1bfad2906 Mon Sep 17 00:00:00 2001
From: Pooja Babu <75320801+pnbabu@users.noreply.github.com>
Date: Wed, 27 Nov 2024 10:03:49 +0100
Subject: [PATCH] Integrate the entire dynamics for higher-order ODEs (#1139)
---
.../nestml_language_concepts.rst | 16 ++++++
.../co_co_integrate_odes_params_correct.py | 3 ++
.../AnalyticIntegrationStep_begin.jinja2 | 2 +-
pynestml/utils/ast_utils.py | 14 +++++
pynestml/utils/messages.py | 6 +++
...alpha_function_2nd_order_ode_neuron.nestml | 44 +++++++++++++++
tests/nest_tests/test_integrate_odes.py | 54 +++++++++++++++++--
7 files changed, 133 insertions(+), 6 deletions(-)
create mode 100644 tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml
diff --git a/doc/nestml_language/nestml_language_concepts.rst b/doc/nestml_language/nestml_language_concepts.rst
index 5c8af3738..f2c658d05 100644
--- a/doc/nestml_language/nestml_language_concepts.rst
+++ b/doc/nestml_language/nestml_language_concepts.rst
@@ -1121,6 +1121,22 @@ Integrating the ODEs needs to be triggered explicitly inside the ``update`` bloc
The ``integrate_odes()`` function numerically integrates the differential equations defined in the ``equations`` block. Integrating the ODEs from one timestep to the next has to be explicitly carried out in the model by calling the ``integrate_odes()`` function. If no parameters are given, all ODEs in the model are integrated. Integration can be limited to a given set of ODEs by giving their left-hand side state variables as parameters to the function, for example ``integrate_odes(V_m, I_ahp)`` if ODEs exist for the variables ``V_m`` and ``I_ahp``. In this example, these variables are integrated simultaneously (as one single system of equations). This is different from calling ``integrate_odes(V_m)`` and then ``integrate_odes(I_ahp)`` in that the second call would use the already-updated values from the first call. Variables not included in the call to ``integrate_odes()`` are assumed to remain constant (both inside the numeric solver stepping function as well as from before to after the call).
+In case of higher-order ODEs of the form ``F(x'', x', x) = 0``, the solution ``x(t)`` is obtained by just providing the variable ``x`` to the ``integrate_odes`` function. For example,
+
+.. code-block:: nestml
+
+ state:
+ x real = 0
+ x' ms**-1 = 0 * ms**-1
+
+ equations:
+ x'' = - 2 * x' / ms - x / ms**2
+
+ update:
+ integrate_odes(x)
+
+Here, ``integrate_odes(x)`` integrates the entire dynamics of ``x(t)``, in this case, ``x`` and ``x'``.
+
Note that the dynamical equations that correspond to convolutions are always updated, regardless of whether ``integrate_odes()`` is called. The state variables affected by incoming events are updated at the end of each timestep, that is, within one timestep, the state as observed by statements in the ``update`` block will be those at :math:`t^-`, i.e. "just before" it has been updated due to the events. See also :ref:`Integrating spiking input` and :ref:`Integration order`.
ODEs that can be solved analytically are integrated to machine precision from one timestep to the next using the propagators obtained from `ODE-toolbox `_. In case a numerical solver is used (such as Runge-Kutta or forward Euler), the same ODEs are also evaluated numerically by the numerical solver to allow more precise values for analytically solvable ODEs *within* a timestep. In this way, the long-term dynamics obeys the analytic (more exact) equations, while the short-term (within one timestep) dynamics is evaluated to the precision of the numerical integrator.
diff --git a/pynestml/cocos/co_co_integrate_odes_params_correct.py b/pynestml/cocos/co_co_integrate_odes_params_correct.py
index a587a830e..5365957cc 100644
--- a/pynestml/cocos/co_co_integrate_odes_params_correct.py
+++ b/pynestml/cocos/co_co_integrate_odes_params_correct.py
@@ -52,3 +52,6 @@ def visit_function_call(self, node):
if symbol_var is None or not symbol_var.is_state():
code, message = Messages.get_integrate_odes_wrong_arg(str(arg))
Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR)
+ elif symbol_var.is_state() and arg.get_variable().get_differential_order() > 0:
+ code, message = Messages.get_integrate_odes_arg_higher_order(str(arg))
+ Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR)
diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2
index fa3fa85e0..63ace4f1a 100644
--- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2
+++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2
@@ -3,7 +3,7 @@
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if uses_analytic_solver %}
-{%- for variable_name in analytic_state_variables_: %}
+{%- for variable_name in analytic_state_variables_ %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py
index d66318130..e6772bf56 100644
--- a/pynestml/utils/ast_utils.py
+++ b/pynestml/utils/ast_utils.py
@@ -138,7 +138,21 @@ def filter_variables_list(cls, variables_list, variables_to_filter_by):
for var in variables_list:
if var in variables_to_filter_by:
ret.append(var)
+ # Add higher order variables of var if not already in the filter list
+ ret.extend(cls.get_higher_order_variables(var, variables_list, variables_to_filter_by))
+ return ret
+ @classmethod
+ def get_higher_order_variables(cls, var, variables_list, variables_to_filter_by) -> List[str]:
+ """
+ Returns a list of higher order state variables of ``var`` from the ``variables_list`` that are not already present in ``variables_to_filter_by``.
+ """
+ ret = []
+ for v in variables_list:
+ order = v.count('__d')
+ if order > 0:
+ if v.split("__d")[0] == var and v not in variables_to_filter_by:
+ ret.append(v)
return ret
@classmethod
diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py
index ce2913c7b..9c6157fb4 100644
--- a/pynestml/utils/messages.py
+++ b/pynestml/utils/messages.py
@@ -140,6 +140,7 @@ class MessageCode(Enum):
EXPONENT_MUST_BE_INTEGER = 114
EMIT_SPIKE_OUTPUT_PORT_TYPE_DIFFERS = 115
CONTINUOUS_OUTPUT_PORT_MAY_NOT_HAVE_ATTRIBUTES = 116
+ INTEGRATE_ODES_ARG_HIGHER_ORDER = 117
class Messages:
@@ -1300,6 +1301,11 @@ def get_integrate_odes_wrong_arg(cls, arg: str) -> Tuple[MessageCode, str]:
message = "Parameter provided to integrate_odes() function is not a state variable: '" + arg + "'"
return MessageCode.INTEGRATE_ODES_WRONG_ARG, message
+ @classmethod
+ def get_integrate_odes_arg_higher_order(cls, arg: str) -> Tuple[MessageCode, str]:
+ message = "Parameter provided to integrate_odes() function is a state variable of higher order: '" + arg + "'"
+ return MessageCode.INTEGRATE_ODES_ARG_HIGHER_ORDER, message
+
@classmethod
def get_mechs_dictionary_info(cls, chan_info, syns_info, conc_info, con_in_info) -> Tuple[MessageCode, str]:
message = ""
diff --git a/tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml b/tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml
new file mode 100644
index 000000000..fa0d1ef88
--- /dev/null
+++ b/tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml
@@ -0,0 +1,44 @@
+"""
+alpha_function_2nd_order_ode_neuron.nestml
+##########################################
+
+Tests that for a system of higher-oder ODEs of the form F(x'',x',x)=0, integrate_odes(x) includes the integration of all the higher-order variables involved of the system.
+
+Copyright statement
++++++++++++++++++++
+
+This file is part of NEST.
+
+Copyright (C) 2004 The NEST Initiative
+
+NEST is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 2 of the License, or
+(at your option) any later version.
+
+NEST is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with NEST. If not, see .
+"""
+model alpha_function_2nd_order_ode_neuron:
+ state:
+ x real = 0
+ x' ms**-1 = 0 * ms**-1
+ y real = 0
+
+ input:
+ fX <- spike
+
+ equations:
+ x'' = - 2 * x' / ms - x / ms**2
+ y' = (-y + 42) / s
+
+ update:
+ integrate_odes(x, y)
+
+ onReceive(fX):
+ x' += e*fX * s / ms
diff --git a/tests/nest_tests/test_integrate_odes.py b/tests/nest_tests/test_integrate_odes.py
index 6ddb699b4..ae4002062 100644
--- a/tests/nest_tests/test_integrate_odes.py
+++ b/tests/nest_tests/test_integrate_odes.py
@@ -54,7 +54,9 @@ def setUp(self):
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
- os.path.join("resources", "integrate_odes_nonlinear_test.nestml")))],
+ os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))),
+ os.path.realpath(os.path.join(os.path.dirname(__file__),
+ os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),],
logging_level="INFO",
suffix="_nestml")
@@ -131,11 +133,9 @@ def test_integrate_odes(self):
pass
# create the network
- spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
- nest.Connect(neuron, spikedet)
# simulate
nest.Simulate(sim_time)
@@ -182,11 +182,9 @@ def test_integrate_odes_nonlinear(self):
pass
# create the network
- spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_nonlinear_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
- nest.Connect(neuron, spikedet)
# simulate
nest.Simulate(sim_time)
@@ -232,3 +230,49 @@ def test_integrate_odes_params2(self):
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2
+
+ def test_integrate_odes_higher_order(self):
+ r"""
+ Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x.
+ """
+ resolution = 0.1
+ simtime = 15.
+ nest.set_verbosity("M_ALL")
+ nest.ResetKernel()
+ nest.SetKernelStatus({"resolution": resolution})
+ try:
+ nest.Install("nestmlmodule")
+ except Exception:
+ # ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
+ pass
+
+ n = nest.Create("alpha_function_2nd_order_ode_neuron_nestml")
+ sgX = nest.Create("spike_generator", params={"spike_times": [10.]})
+ nest.Connect(sgX, n, syn_spec={"weight": 1., "delay": resolution})
+
+ mm = nest.Create("multimeter", params={"interval": resolution, "record_from": ["x", "y"]})
+ nest.Connect(mm, n)
+
+ nest.Simulate(simtime)
+ times = mm.get()["events"]["times"]
+ x_actual = mm.get()["events"]["x"]
+ y_actual = mm.get()["events"]["y"]
+
+ if TEST_PLOTS:
+ fig, ax = plt.subplots(nrows=2)
+ ax1, ax2 = ax
+
+ ax2.plot(times, x_actual, label="x")
+ ax1.plot(times, y_actual, label="y")
+
+ for _ax in ax:
+ _ax.grid(which="major", axis="both")
+ _ax.grid(which="minor", axis="x", linestyle=":", alpha=.4)
+ _ax.set_xlim(0., simtime)
+ _ax.legend()
+
+ fig.savefig("/tmp/test_integrate_odes_higher_order.png", dpi=300)
+
+ # verify
+ np.testing.assert_allclose(x_actual[-1], 0.10737970490959549)
+ np.testing.assert_allclose(y_actual[-1], 0.6211608596446752)