From 806ca130bc4f9837403d93ac1989f2b1bfad2906 Mon Sep 17 00:00:00 2001 From: Pooja Babu <75320801+pnbabu@users.noreply.github.com> Date: Wed, 27 Nov 2024 10:03:49 +0100 Subject: [PATCH] Integrate the entire dynamics for higher-order ODEs (#1139) --- .../nestml_language_concepts.rst | 16 ++++++ .../co_co_integrate_odes_params_correct.py | 3 ++ .../AnalyticIntegrationStep_begin.jinja2 | 2 +- pynestml/utils/ast_utils.py | 14 +++++ pynestml/utils/messages.py | 6 +++ ...alpha_function_2nd_order_ode_neuron.nestml | 44 +++++++++++++++ tests/nest_tests/test_integrate_odes.py | 54 +++++++++++++++++-- 7 files changed, 133 insertions(+), 6 deletions(-) create mode 100644 tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml diff --git a/doc/nestml_language/nestml_language_concepts.rst b/doc/nestml_language/nestml_language_concepts.rst index 5c8af3738..f2c658d05 100644 --- a/doc/nestml_language/nestml_language_concepts.rst +++ b/doc/nestml_language/nestml_language_concepts.rst @@ -1121,6 +1121,22 @@ Integrating the ODEs needs to be triggered explicitly inside the ``update`` bloc The ``integrate_odes()`` function numerically integrates the differential equations defined in the ``equations`` block. Integrating the ODEs from one timestep to the next has to be explicitly carried out in the model by calling the ``integrate_odes()`` function. If no parameters are given, all ODEs in the model are integrated. Integration can be limited to a given set of ODEs by giving their left-hand side state variables as parameters to the function, for example ``integrate_odes(V_m, I_ahp)`` if ODEs exist for the variables ``V_m`` and ``I_ahp``. In this example, these variables are integrated simultaneously (as one single system of equations). This is different from calling ``integrate_odes(V_m)`` and then ``integrate_odes(I_ahp)`` in that the second call would use the already-updated values from the first call. Variables not included in the call to ``integrate_odes()`` are assumed to remain constant (both inside the numeric solver stepping function as well as from before to after the call). +In case of higher-order ODEs of the form ``F(x'', x', x) = 0``, the solution ``x(t)`` is obtained by just providing the variable ``x`` to the ``integrate_odes`` function. For example, + +.. code-block:: nestml + + state: + x real = 0 + x' ms**-1 = 0 * ms**-1 + + equations: + x'' = - 2 * x' / ms - x / ms**2 + + update: + integrate_odes(x) + +Here, ``integrate_odes(x)`` integrates the entire dynamics of ``x(t)``, in this case, ``x`` and ``x'``. + Note that the dynamical equations that correspond to convolutions are always updated, regardless of whether ``integrate_odes()`` is called. The state variables affected by incoming events are updated at the end of each timestep, that is, within one timestep, the state as observed by statements in the ``update`` block will be those at :math:`t^-`, i.e. "just before" it has been updated due to the events. See also :ref:`Integrating spiking input` and :ref:`Integration order`. ODEs that can be solved analytically are integrated to machine precision from one timestep to the next using the propagators obtained from `ODE-toolbox `_. In case a numerical solver is used (such as Runge-Kutta or forward Euler), the same ODEs are also evaluated numerically by the numerical solver to allow more precise values for analytically solvable ODEs *within* a timestep. In this way, the long-term dynamics obeys the analytic (more exact) equations, while the short-term (within one timestep) dynamics is evaluated to the precision of the numerical integrator. diff --git a/pynestml/cocos/co_co_integrate_odes_params_correct.py b/pynestml/cocos/co_co_integrate_odes_params_correct.py index a587a830e..5365957cc 100644 --- a/pynestml/cocos/co_co_integrate_odes_params_correct.py +++ b/pynestml/cocos/co_co_integrate_odes_params_correct.py @@ -52,3 +52,6 @@ def visit_function_call(self, node): if symbol_var is None or not symbol_var.is_state(): code, message = Messages.get_integrate_odes_wrong_arg(str(arg)) Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR) + elif symbol_var.is_state() and arg.get_variable().get_differential_order() > 0: + code, message = Messages.get_integrate_odes_arg_higher_order(str(arg)) + Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR) diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2 index fa3fa85e0..63ace4f1a 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/AnalyticIntegrationStep_begin.jinja2 @@ -3,7 +3,7 @@ #} {%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} {%- if uses_analytic_solver %} -{%- for variable_name in analytic_state_variables_: %} +{%- for variable_name in analytic_state_variables_ %} {%- set update_expr = update_expressions[variable_name] %} {%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%} {%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%} diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index d66318130..e6772bf56 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -138,7 +138,21 @@ def filter_variables_list(cls, variables_list, variables_to_filter_by): for var in variables_list: if var in variables_to_filter_by: ret.append(var) + # Add higher order variables of var if not already in the filter list + ret.extend(cls.get_higher_order_variables(var, variables_list, variables_to_filter_by)) + return ret + @classmethod + def get_higher_order_variables(cls, var, variables_list, variables_to_filter_by) -> List[str]: + """ + Returns a list of higher order state variables of ``var`` from the ``variables_list`` that are not already present in ``variables_to_filter_by``. + """ + ret = [] + for v in variables_list: + order = v.count('__d') + if order > 0: + if v.split("__d")[0] == var and v not in variables_to_filter_by: + ret.append(v) return ret @classmethod diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py index ce2913c7b..9c6157fb4 100644 --- a/pynestml/utils/messages.py +++ b/pynestml/utils/messages.py @@ -140,6 +140,7 @@ class MessageCode(Enum): EXPONENT_MUST_BE_INTEGER = 114 EMIT_SPIKE_OUTPUT_PORT_TYPE_DIFFERS = 115 CONTINUOUS_OUTPUT_PORT_MAY_NOT_HAVE_ATTRIBUTES = 116 + INTEGRATE_ODES_ARG_HIGHER_ORDER = 117 class Messages: @@ -1300,6 +1301,11 @@ def get_integrate_odes_wrong_arg(cls, arg: str) -> Tuple[MessageCode, str]: message = "Parameter provided to integrate_odes() function is not a state variable: '" + arg + "'" return MessageCode.INTEGRATE_ODES_WRONG_ARG, message + @classmethod + def get_integrate_odes_arg_higher_order(cls, arg: str) -> Tuple[MessageCode, str]: + message = "Parameter provided to integrate_odes() function is a state variable of higher order: '" + arg + "'" + return MessageCode.INTEGRATE_ODES_ARG_HIGHER_ORDER, message + @classmethod def get_mechs_dictionary_info(cls, chan_info, syns_info, conc_info, con_in_info) -> Tuple[MessageCode, str]: message = "" diff --git a/tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml b/tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml new file mode 100644 index 000000000..fa0d1ef88 --- /dev/null +++ b/tests/nest_tests/resources/alpha_function_2nd_order_ode_neuron.nestml @@ -0,0 +1,44 @@ +""" +alpha_function_2nd_order_ode_neuron.nestml +########################################## + +Tests that for a system of higher-oder ODEs of the form F(x'',x',x)=0, integrate_odes(x) includes the integration of all the higher-order variables involved of the system. + +Copyright statement ++++++++++++++++++++ + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . +""" +model alpha_function_2nd_order_ode_neuron: + state: + x real = 0 + x' ms**-1 = 0 * ms**-1 + y real = 0 + + input: + fX <- spike + + equations: + x'' = - 2 * x' / ms - x / ms**2 + y' = (-y + 42) / s + + update: + integrate_odes(x, y) + + onReceive(fX): + x' += e*fX * s / ms diff --git a/tests/nest_tests/test_integrate_odes.py b/tests/nest_tests/test_integrate_odes.py index 6ddb699b4..ae4002062 100644 --- a/tests/nest_tests/test_integrate_odes.py +++ b/tests/nest_tests/test_integrate_odes.py @@ -54,7 +54,9 @@ def setUp(self): os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test.nestml"))), os.path.realpath(os.path.join(os.path.dirname(__file__), - os.path.join("resources", "integrate_odes_nonlinear_test.nestml")))], + os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))), + os.path.realpath(os.path.join(os.path.dirname(__file__), + os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),], logging_level="INFO", suffix="_nestml") @@ -131,11 +133,9 @@ def test_integrate_odes(self): pass # create the network - spikedet = nest.Create("spike_recorder") neuron = nest.Create("integrate_odes_test_nestml") mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]}) nest.Connect(mm, neuron) - nest.Connect(neuron, spikedet) # simulate nest.Simulate(sim_time) @@ -182,11 +182,9 @@ def test_integrate_odes_nonlinear(self): pass # create the network - spikedet = nest.Create("spike_recorder") neuron = nest.Create("integrate_odes_nonlinear_test_nestml") mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]}) nest.Connect(mm, neuron) - nest.Connect(neuron, spikedet) # simulate nest.Simulate(sim_time) @@ -232,3 +230,49 @@ def test_integrate_odes_params2(self): generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2 + + def test_integrate_odes_higher_order(self): + r""" + Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x. + """ + resolution = 0.1 + simtime = 15. + nest.set_verbosity("M_ALL") + nest.ResetKernel() + nest.SetKernelStatus({"resolution": resolution}) + try: + nest.Install("nestmlmodule") + except Exception: + # ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions + pass + + n = nest.Create("alpha_function_2nd_order_ode_neuron_nestml") + sgX = nest.Create("spike_generator", params={"spike_times": [10.]}) + nest.Connect(sgX, n, syn_spec={"weight": 1., "delay": resolution}) + + mm = nest.Create("multimeter", params={"interval": resolution, "record_from": ["x", "y"]}) + nest.Connect(mm, n) + + nest.Simulate(simtime) + times = mm.get()["events"]["times"] + x_actual = mm.get()["events"]["x"] + y_actual = mm.get()["events"]["y"] + + if TEST_PLOTS: + fig, ax = plt.subplots(nrows=2) + ax1, ax2 = ax + + ax2.plot(times, x_actual, label="x") + ax1.plot(times, y_actual, label="y") + + for _ax in ax: + _ax.grid(which="major", axis="both") + _ax.grid(which="minor", axis="x", linestyle=":", alpha=.4) + _ax.set_xlim(0., simtime) + _ax.legend() + + fig.savefig("/tmp/test_integrate_odes_higher_order.png", dpi=300) + + # verify + np.testing.assert_allclose(x_actual[-1], 0.10737970490959549) + np.testing.assert_allclose(y_actual[-1], 0.6211608596446752)