Skip to content

Commit

Permalink
Add explicit output parameters to spiking output ports (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
clinssen authored Nov 6, 2024
1 parent 5feb3df commit ac9b6e5
Show file tree
Hide file tree
Showing 43 changed files with 846 additions and 440 deletions.
24 changes: 24 additions & 0 deletions doc/nestml_language/nestml_language_concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,30 @@ Each model can only send a single type of event. The type of the event has to be
Calling the ``emit_spike()`` function in the ``update`` block results in firing a spike to all target neurons and devices time stamped with the simulation time at the end of the time interval ``t + timestep()``.

Event attributes
~~~~~~~~~~~~~~~~

Each spiking output event can be parameterised by one or more attributes. For example, a synapse could assign a weight (as a real number) and delay (in milliseconds) to its spike events by including these values in the call to ``emit_spike()``:

.. code-block:: nestml
parameters:
weight real = 10.
update:
emit_spike(weight, 1 ms)
If spike event attributes are used, their names and types must be given as part of the output port specification, for example:

.. code-block:: nestml
output:
spike(weight real, delay ms)
The names are only used externally, so that other models can refer to the correct attribute (such as a downstream neuron that is receiving the spike through its input port). It is thus allowed to have a state variable called ``weight`` and an output port attribute by the same name; the output port attribute name does not refer to names declared inside the model.

Specific code generators may support a specific set of attributes; please check the documentation of each individual code generator for more details.


Equations
---------
Expand Down
8 changes: 8 additions & 0 deletions doc/running/running_nest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ For a full example, please see `iaf_psc_exp_multisynapse_vectors.nestml <https:/
Generating code
---------------

Output event attributes
~~~~~~~~~~~~~~~~~~~~~~~

In neuron models, no spike event attributes are supported.

In synapse models, precisely two spike event attributes are supported: a synaptic weight (as a real number) and a synaptic (dendritic) delay (in milliseconds).


Generating code for plastic synapses
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/stdp_dopa_synapse/stdp_dopa_synapse.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@
" mod_spikes <- spike\n",
"\n",
" output:\n",
" spike\n",
" spike(weight real, delay ms)\n",
"\n",
" onReceive(mod_spikes):\n",
" n += A_vt / tau_n\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@
" I_post_dend pA <- continuous\n",
"\n",
" output:\n",
" spike\n",
" spike(weight real, delay ms)\n",
"\n",
" onReceive(post_spikes):\n",
" # potentiate synapse\n",
Expand Down
6 changes: 3 additions & 3 deletions doc/tutorials/stdp_windows/stdp_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
" post_spikes <- spike\n",
"\n",
" output:\n",
" spike\n",
" spike(weight real, delay ms)\n",
"\n",
" onReceive(post_spikes):\n",
" # potentiate synapse\n",
Expand Down Expand Up @@ -3214,7 +3214,7 @@
" post_spikes <- spike\n",
"\n",
" output:\n",
" spike\n",
" spike(weight real, delay ms)\n",
"\n",
" onReceive(post_spikes):\n",
" post_nn_trace = 1\n",
Expand Down Expand Up @@ -4805,7 +4805,7 @@
" post_spikes <- spike\n",
"\n",
" output:\n",
" spike\n",
" spike(weight real, delay ms)\n",
"\n",
" onReceive(post_spikes, priority=2):\n",
" w += lambda * (pre_trace + post_trace)\n",
Expand Down
4 changes: 2 additions & 2 deletions doc/tutorials/triplet_stdp_synapse/triplet_stdp_synapse.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
" post_spikes <- spike\n",
"\n",
" output:\n",
" spike\n",
" spike(weight real, delay ms)\n",
"\n",
" onReceive(post_spikes):\n",
" # increment post trace values\n",
Expand Down Expand Up @@ -1120,7 +1120,7 @@
" post_spikes <- spike\n",
"\n",
" output:\n",
" spike\n",
" spike(weight real, delay ms)\n",
"\n",
" onReceive(post_spikes):\n",
" # increment post trace values\n",
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/neuromodulated_stdp_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ model neuromodulated_stdp_synapse:
mod_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(mod_spikes):
n += 1. / tau_n
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/noisy_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ model noisy_synapse:
pre_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(pre_spikes):
# temporary variable for the "weight" that will be transmitted
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/static_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ model static_synapse:
pre_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(pre_spikes):
emit_spike(w, d)
2 changes: 1 addition & 1 deletion models/synapses/stdp_nn_pre_centered_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ model stdp_nn_pre_centered_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace = 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_nn_restr_symm_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ model stdp_nn_restr_symm_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace = 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_nn_symm_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ model stdp_nn_symm_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace = 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ model stdp_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace += 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_triplet_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ model stdp_triplet_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
# potentiate synapse
Expand Down
52 changes: 43 additions & 9 deletions pynestml/cocos/co_co_output_port_defined_if_emit_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from pynestml.cocos.co_co import CoCo
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_model import ASTModel
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.utils.ast_utils import ASTUtils
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
from pynestml.visitors.ast_visitor import ASTVisitor
Expand Down Expand Up @@ -60,22 +62,54 @@ def visit_function_call(self, node: ASTFunctionCall):
"""
assert self.neuron is not None
func_name = node.get_name()
if func_name == 'emit_spike':
if func_name == PredefinedFunctions.EMIT_SPIKE:
output_blocks = self.neuron.get_output_blocks()
if not output_blocks:

# exactly one output block should be defined
if len(output_blocks) == 0:
code, message = Messages.get_block_not_defined_correctly('output', missing=True)
Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR,
code=code, message=message)
return

spike_output_exists = False
for output_block in output_blocks:
if output_block.is_spike():
spike_output_exists = True
break
if len(output_blocks) > 1:
code, message = Messages.get_block_not_defined_correctly('output', missing=False)
Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR,
code=code, message=message)
return

assert len(output_blocks) == 1

if not spike_output_exists:
if not output_blocks[0].is_spike():
code, message = Messages.get_emit_spike_function_but_no_output_port()
Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
error_position=node.get_source_position())
error_position=output_blocks[0].get_source_position())
return

# check types
if len(node.get_args()) != len(output_blocks[0].get_attributes()):
code, message = Messages.get_output_port_type_differs()
Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
error_position=output_blocks[0].get_source_position())
return

for emit_spike_arg, output_block_attr in zip(node.get_args(), output_blocks[0].get_attributes()):
emit_spike_arg_type_sym = emit_spike_arg.type
output_block_attr_type_sym = output_block_attr.get_data_type().get_type_symbol()

if emit_spike_arg_type_sym.equals(output_block_attr_type_sym):
continue

if emit_spike_arg_type_sym.is_castable_to(output_block_attr_type_sym):
# types are not equal, but castable
code, message = Messages.get_implicit_cast_rhs_to_lhs(output_block_attr_type_sym.print_symbol(),
emit_spike_arg_type_sym.print_symbol())
Logger.log_message(error_position=output_blocks[0].get_source_position(),
code=code, message=message, log_level=LoggingLevel.WARNING)
continue
else:
# types are not equal and not castable
code, message = Messages.get_output_port_type_differs()
Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
error_position=output_blocks[0].get_source_position())
return
8 changes: 8 additions & 0 deletions pynestml/codegeneration/printers/nestml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,14 @@ def print_output_block(self, node: ASTOutputBlock) -> str:
ret += print_n_spaces(self.indent) + "output:\n"
ret += print_n_spaces(self.indent + 4)
ret += "spike" if node.is_spike() else "continuous"
if node.get_attributes():
ret += "("
for i, attr in enumerate(node.get_attributes()):
ret += self.print(attr)
if i < len(node.get_attributes()) - 1:
ret += ", "

ret += ")"
ret += print_sl_comment(node.in_comment)
ret += "\n"
return ret
Expand Down
Loading

0 comments on commit ac9b6e5

Please sign in to comment.