Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with integrate_odes() for numeric solver #1147

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,17 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %
}
{%- endif %}


{% set numeric_state_variables_to_be_integrated = numeric_state_variables + purely_numeric_state_variables_moved %}
{%- if ast.get_args() | length > 0 %}
{%- set numeric_state_variables_to_be_integrated = utils.filter_variables_list(numeric_state_variables_to_be_integrated, ast.get_args()) %}
{%- endif %}
{%- for variable_name in numeric_state_variables + numeric_state_variables_moved %}
{%- set update_expr = numeric_update_expressions[variable_name] %}
{%- set variable_symbol = variable_symbols[variable_name] %}
{%- if use_gap_junctions %}
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr)|replace("node.B_." + gap_junction_port + "_grid_sum_", "(node.B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in numeric_state_variables_to_be_integrated + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr)|replace("node.B_." + gap_junction_port + "_grid_sum_", "(node.B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
{%- else %}
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in numeric_state_variables_to_be_integrated + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
{%- endif %}
{%- endfor %}

Expand Down
120 changes: 120 additions & 0 deletions tests/nest_tests/resources/aeif_cond_alpha_alt_neuron.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
aeif_cond_alpha - Conductance based exponential integrate-and-fire neuron model
###############################################################################

Description
+++++++++++

aeif_psc_alpha is the adaptive exponential integrate and fire neuron according to Brette and Gerstner (2005), with post-synaptic conductances in the form of a bi-exponential ("alpha") function.

The membrane potential is given by the following differential equation:

.. math::

C_m \frac{dV_m}{dt} =
-g_L(V_m-E_L)+g_L\Delta_T\exp\left(\frac{V_m-V_{th}}{\Delta_T}\right) -
g_e(t)(V_m-E_e) \\
-g_i(t)(V_m-E_i)-w + I_e

and

.. math::

\tau_w \frac{dw}{dt} = a(V_m-E_L) - w

Note that the membrane potential can diverge to positive infinity due to the exponential term. To avoid numerical instabilities, instead of :math:`V_m`, the value :math:`\min(V_m,V_{peak})` is used in the dynamical equations.


References
++++++++++

.. [1] Brette R and Gerstner W (2005). Adaptive exponential
integrate-and-fire model as an effective description of neuronal
activity. Journal of Neurophysiology. 943637-3642
DOI: https://doi.org/10.1152/jn.00686.2005


See also
++++++++

iaf_psc_alpha, aeif_psc_exp
"""
model aeif_cond_alpha_alt_neuron:

state:
V_m mV = E_L # Membrane potential
w pA = 0 pA # Spike-adaptation current
refr_t ms = 0 ms # Refractory period timer
g_exc nS = 0 nS # AHP conductance
g_exc' nS/ms = 0 nS/ms # AHP conductance
g_inh nS = 0 nS # AHP conductance
g_inh' nS/ms = 0 nS/ms # AHP conductance

equations:
inline V_bounded mV = min(V_m, V_peak) # prevent exponential divergence

g_exc'' = -2 * g_exc' / tau_syn_exc - g_exc / tau_syn_exc**2
g_inh'' = -2 * g_inh' / tau_syn_inh - g_inh / tau_syn_inh**2

# Add inlines to simplify the equation definition of V_m
inline exp_arg real = (V_bounded - V_th) / Delta_T
inline I_spike pA = g_L * Delta_T * exp(exp_arg)

V_m' = (-g_L * (V_bounded - E_L) + I_spike - g_exc * (V_bounded - E_exc) - g_inh * (V_bounded - E_inh) - w + I_e + I_stim) / C_m
w' = (a * (V_bounded - E_L) - w) / tau_w

refr_t' = -1e3 * ms/s # refractoriness is implemented as an ODE, representing a timer counting back down to zero. XXX: TODO: This should simply read ``refr_t' = -1 / s`` (see https://github.com/nest/nestml/issues/984)

parameters:
# membrane parameters
C_m pF = 281.0 pF # Membrane Capacitance
refr_T ms = 2 ms # Duration of refractory period
V_reset mV = -60.0 mV # Reset Potential
g_L nS = 30.0 nS # Leak Conductance
E_L mV = -70.6 mV # Leak reversal Potential (aka resting potential)

# spike adaptation parameters
a nS = 4 nS # Subthreshold adaptation
b pA = 80.5 pA # Spike-triggered adaptation
Delta_T mV = 2.0 mV # Slope factor
tau_w ms = 144.0 ms # Adaptation time constant
V_th mV = -50.4 mV # Threshold Potential
V_peak mV = 0 mV # Spike detection threshold

# synaptic parameters
tau_syn_exc ms = 0.2 ms # Synaptic Time Constant Excitatory Synapse
tau_syn_inh ms = 2.0 ms # Synaptic Time Constant for Inhibitory Synapse
E_exc mV = 0 mV # Excitatory reversal Potential
E_inh mV = -85.0 mV # Inhibitory reversal Potential

# constant external input current
I_e pA = 0 pA

input:
exc_spikes <- excitatory spike
inh_spikes <- inhibitory spike
I_stim pA <- continuous

output:
spike

update:
if refr_t > 0 ms:
# neuron is absolute refractory, do not evolve V_m
integrate_odes(g_exc, g_inh, w, refr_t)
else:
# neuron not refractory
integrate_odes(g_exc, g_inh, V_m, w)

onReceive(exc_spikes):
g_exc' += exc_spikes * (e / tau_syn_exc) * nS * s

onReceive(inh_spikes):
g_inh' += inh_spikes * (e / tau_syn_inh) * nS * s

onCondition(refr_t <= 0 ms and V_m >= V_th):
# threshold crossing
refr_t = refr_T # start of the refractory period
V_m = V_reset
w += b
emit_spike()
85 changes: 73 additions & 12 deletions tests/nest_tests/test_integrate_odes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@

try:
import matplotlib

matplotlib.use("Agg")
import matplotlib.ticker
import matplotlib.pyplot as plt

TEST_PLOTS = True
except Exception:
TEST_PLOTS = False
Expand All @@ -50,22 +52,24 @@ def setUp(self):
r"""Generate the model code"""

generate_nest_target(input_path=[os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))),
os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),],
os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "aeif_cond_alpha_alt_neuron.nestml")))],
logging_level="INFO",
suffix="_nestml")

def test_convolutions_always_integrated(self):
r"""Test that synaptic integration continues for iaf_psc_exp, even when neuron is refractory."""

sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
spike_interval = 5. # [ms]
sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
spike_interval = 5. # [ms]

nest.set_verbosity("M_ALL")
nest.ResetKernel()
Expand Down Expand Up @@ -120,8 +124,8 @@ def test_convolutions_always_integrated(self):
def test_integrate_odes(self):
r"""Test the integrate_odes() function, in particular when not all the ODEs are being integrated."""

sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]

nest.set_verbosity("M_ALL")
nest.ResetKernel()
Expand Down Expand Up @@ -169,8 +173,8 @@ def test_integrate_odes(self):
def test_integrate_odes_nonlinear(self):
r"""Test the integrate_odes() function, in particular when not all the ODEs are being integrated, for nonlinear ODEs."""

sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]

nest.set_verbosity("M_ALL")
nest.ResetKernel()
Expand Down Expand Up @@ -205,7 +209,6 @@ def test_integrate_odes_nonlinear(self):
for _ax in ax:
_ax.grid(which="major", axis="both")
_ax.grid(which="minor", axis="x", linestyle=":", alpha=.4)
# _ax.minorticks_on()
_ax.set_xlim(0., sim_time)
_ax.legend()

Expand All @@ -218,15 +221,17 @@ def test_integrate_odes_nonlinear(self):
def test_integrate_odes_params(self):
r"""Test the integrate_odes() function, in particular with respect to the parameter types."""

fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))
fname = os.path.realpath(
os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")

assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2

def test_integrate_odes_params2(self):
r"""Test the integrate_odes() function, in particular with respect to non-existent parameter variables."""

fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml")))
fname = os.path.realpath(
os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml")))
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")

assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2
Expand Down Expand Up @@ -276,3 +281,59 @@ def test_integrate_odes_higher_order(self):
# verify
np.testing.assert_allclose(x_actual[-1], 0.10737970490959549)
np.testing.assert_allclose(y_actual[-1], 0.6211608596446752)

def test_integrate_odes_numeric_higher_order(self):
r"""
Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x with a numeric solver.
"""
resolution = 0.1
simtime = 800.
params_nestml = {"V_peak": 0.0, "a": 4.0, "b": 80.5, "E_L": -70.6,
"g_L": 300.0, 'E_exc': 20.0, 'E_inh': -85.0,
'tau_syn_exc': 40.0, 'tau_syn_inh': 20.0}

params_nest = {"V_peak": 0.0, "a": 4.0, "b": 80.5, "E_L": -70.6,
"g_L": 300.0, 'E_ex': 20.0, 'E_in': -85.0,
'tau_syn_ex': 40.0, 'tau_syn_in': 20.0}

for model in ["aeif_cond_alpha_alt_neuron_nestml", "aeif_cond_alpha"]:
nest.set_verbosity("M_ALL")
nest.ResetKernel()
nest.SetKernelStatus({"resolution": resolution})
try:
nest.Install("nestmlmodule")
except Exception:
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass
n = nest.Create(model)
if "_nestml" in model:
nest.SetStatus(n, params_nestml)
else:
nest.SetStatus(n, params_nest)

spike = nest.Create("spike_generator")
spike_times = [10.0, 400.0]
nest.SetStatus(spike, {"spike_times": spike_times})
nest.Connect(spike, n, syn_spec={"weight": 0.1, "delay": 1.0})
nest.Connect(spike, n, syn_spec={"weight": -0.2, "delay": 100.})

mm = nest.Create("multimeter", params={"record_from": ["V_m"]})
nest.Connect(mm, n)

nest.Simulate(simtime)
times = mm.get()["events"]["times"]
if "_nestml" in model:
v_m_nestml = mm.get()["events"]["V_m"]
else:
v_m_nest = mm.get()["events"]["V_m"]

if TEST_PLOTS:
fig, ax = plt.subplots(nrows=1)

ax.plot(times, v_m_nestml, label="NESTML")
ax.plot(times, v_m_nest, label="NEST")
ax.legend()

fig.savefig("/tmp/test_integrate_odes_numeric_higher_order.png")

np.testing.assert_allclose(v_m_nestml, v_m_nest)
Loading