Skip to content

Commit

Permalink
add attributes to spiking input ports
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Jan 16, 2025
1 parent 022ac91 commit 4320b0c
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 17 deletions.
4 changes: 2 additions & 2 deletions models/neurons/iaf_psc_delta_neuron.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ model iaf_psc_delta_neuron:

equations:
kernel K_delta = delta(t)
V_m' = -(V_m - E_L) / tau_m + convolve(K_delta, spikes.weight) / s + (I_e + I_stim) / C_m # XXX: TODO: instead of the convolution, this should just read ``... + spikes.weight + ...``. This is a known issue (see https://github.com/nest/nestml/pull/1050).
V_m' = -(V_m - E_L) / tau_m + convolve(K_delta, spike_in_port.weight) / s + (I_e + I_stim) / C_m # XXX: TODO: instead of the convolution, this should just read ``... + spike_in_port.weight + ...``. This is a known issue (see https://github.com/nest/nestml/pull/1050).
refr_t' = -1e3 * ms/s # refractoriness is implemented as an ODE, representing a timer counting back down to zero. XXX: TODO: This should simply read ``refr_t' = -1 / s`` (see https://github.com/nest/nestml/issues/984)

parameters:
Expand All @@ -64,7 +64,7 @@ model iaf_psc_delta_neuron:
I_e pA = 0 pA

input:
spikes <- spike(weight mV)
spike_in_port <- spike(weight mV)
I_stim pA <- continuous

output:
Expand Down
6 changes: 5 additions & 1 deletion pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class NESTCodeGenerator(CodeGenerator):
- **continuous_state_buffering_method**: Which method to use for buffering state variables between neuron and synapse pairs. When a synapse has a "continuous" input port, connected to a postsynaptic neuron, either the value is obtained taking the synaptic (dendritic, that is, synapse-soma) delay into account, requiring a buffer to store the value at each timepoint (``continuous_state_buffering_method = "continuous_time_buffer"); or the value is obtained at the times of the somatic spikes of the postsynaptic neuron, ignoring the synaptic delay (``continuous_state_buffering_method == "post_spike_based"``). The former is more physically accurate but requires a large buffer and can require a long time to simulate. The latter ignores the dendritic delay but is much more computationally efficient.
- **delay_variable**: A mapping identifying, for each synapse (the name of which is given as a key), the variable or parameter in the model that corresponds with the NEST ``Connection`` class delay property.
- **weight_variable**: Like ``delay_variable``, but for synaptic weight.
- **linear_time_invariant_spiking_input_ports**: A list of spiking input ports which can be treated as linear and time-invariant; this implies that, for the given port(s), the weight of all spikes received within a timestep can be added together, improving memory consumption and runtime performance. Use with caution, for example, this is not compatible with using an input port as one processing inhibitory vs. excitatory spikes depending on the sign of the weight of the spike event.
- **redirect_build_output**: An optional boolean key for redirecting the build output. Setting the key to ``True``, two files will be created for redirecting the ``stdout`` and the ``stderr`. The ``target_path`` will be used as the default location for creating the two files.
- **build_output_dir**: An optional string key representing the new path where the files corresponding to the output of the build phase will be created. This key requires that the ``redirect_build_output`` is set to ``True``.
Expand Down Expand Up @@ -150,7 +151,8 @@ class NESTCodeGenerator(CodeGenerator):
"numeric_solver": "rk45",
"continuous_state_buffering_method": "continuous_time_buffer",
"delay_variable": {},
"weight_variable": {}
"weight_variable": {},
"linear_time_invariant_spiking_input_ports": []
}

def __init__(self, options: Optional[Mapping[str, Any]] = None):
Expand Down Expand Up @@ -531,6 +533,8 @@ def _get_model_namespace(self, astnode: ASTModel) -> Dict:
if "continuous_post_ports" in dir(astnode):
namespace["continuous_post_ports"] = astnode.continuous_post_ports

namespace["linear_time_invariant_spiking_input_ports"] = self.get_option("linear_time_invariant_spiking_input_ports")

return namespace

def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,37 @@ namespace nest
{%- for i in range(size) %}
{%- if inputPort.get_parameters() %}
{%- for parameter in inputPort.get_parameters() %}
{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
, spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_( nest::RingBuffer() )
, spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
, spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_( nest::ListRingBuffer() )
// , spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfor %}
{%- else %}
{%- for parameter in inputPort.get_parameters() %}
{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_( nest::RingBuffer() )
, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_( nest::ListRingBuffer() )
//, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- endif %}
{%- endfor %}
{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
, spike_input_{{ inputPort.name }}_( nest::RingBuffer() )
, spike_input_{{ inputPort.name }}_grid_sum_( 0. )
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
, spike_input_{{ inputPort.name }}_( nest::ListRingBuffer() )
{%- endif %}
, spike_input_{{ inputPort.name }}_spike_input_received_( nest::RingBuffer() )
//, spike_input_{{ inputPort.name }}_spike_input_received_grid_sum_( 0. )
{%- endif %}
{%- endfor %}
{%- endif %}
Expand All @@ -236,21 +254,43 @@ namespace nest
{%- for i in range(size) %}
{%- if inputPort.get_parameters() %}
{%- for parameter in inputPort.get_parameters() %}

{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
, spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_( nest::RingBuffer() )
, spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
, spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_( nest::ListRingBuffer() )
//, spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- endif %}
{%- endfor %}
{%- else %}
?????????????
{%- endif %}
{%- endfor %}
{%- else %}
{%- for parameter in inputPort.get_parameters() %}



{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_( nest::RingBuffer() )
, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_( nest::ListRingBuffer() )
//, spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_grid_sum_( 0. )
{%- endif %}
{%- endfor %}
{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
, spike_input_{{ inputPort.name }}_( nest::RingBuffer() )
, spike_input_{{ inputPort.name }}_grid_sum_( 0. )
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
, spike_input_{{ inputPort.name }}_( nest::ListRingBuffer() )
{%- endif %}
, spike_input_{{ inputPort.name }}_spike_input_received_( nest::RingBuffer() )
//, spike_input_{{ inputPort.name }}_spike_input_received_grid_sum_( 0. )
{%- endif %}
{%- endfor %}
{%- endif %}
Expand Down Expand Up @@ -773,19 +813,62 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const

{%- for inputPortSymbol in neuron.get_spike_input_ports() %}
{%- set inputPort = utils.get_input_port_by_name(astnode.get_input_blocks(), inputPortSymbol.name.split(".")[0]) %}



{%- if inputPortSymbol.has_vector_parameter() %}
{%- set size = utils.get_numeric_vector_size(inputPortSymbol) %}
{%- for i in range(size) %}
{%- if inputPort.get_parameters() %}
{%- for parameter in inputPort.get_parameters() %}
const double __spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }} = std::accumulate(B_.spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_.get_list(lag).begin(), spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_.get_list(lag).end(), 0.0);


{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
B_.spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_grid_sum_ = B_.spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_.get_value(lag);
const double __spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }} = B_.spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_grid_sum_;
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
const double __spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }} = std::accumulate(B_.spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_.get_list(lag).begin(), B_.spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_.get_list(lag).end(), 0.0);
{%- endif %}




{%- endfor %}
{%- endif %}
{%- endfor %}
{%- else %}
{%- for parameter in inputPort.get_parameters() %}


{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
B_.spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_grid_sum_ = B_.spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_.get_value(lag);
const double __spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }} = B_.spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_grid_sum_;
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
const double __spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }} = std::accumulate(B_.spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_.get_list(lag).begin(), B_.spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_.get_list(lag).end(), 0.0);
{%- endif %}





{%- endfor %}

{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
B_.spike_input_{{ inputPort.name }}_grid_sum_ = B_.spike_input_{{ inputPort.name }}_.get_value(lag);
const double __spike_input_{{ inputPort.name }} = B_.spike_input_{{ inputPort.name }}_grid_sum_;
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
const double __spike_input_{{ inputPort.name }} = std::accumulate(B_.spike_input_{{ inputPort.name }}_.get_list(lag).begin(), B_.spike_input_{{ inputPort.name }}_.get_list(lag).end(), 0.0);
{%- endif %}




{%- endif %}
{%- endfor %}

Expand Down Expand Up @@ -948,6 +1031,10 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
{%- for inputPortSymbol in neuron.get_spike_input_ports() %}
{%- set inputPort = utils.get_input_port_by_name(astnode.get_input_blocks(), inputPortSymbol.name.split(".")[0]) %}


{%- if inputPortSymbol.name not in linear_time_invariant_spiking_input_ports %}


{%- if inputPortSymbol.has_vector_parameter() %}
{%- set size = utils.get_numeric_vector_size(inputPortSymbol) %}
{%- for i in range(size) %}
Expand All @@ -963,10 +1050,13 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
std::list< double >& __spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_list = B_.spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_.get_list(lag);
__spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_list.clear();
{%- endfor %}
{%- endif %}

std::list< double >& __spike_input_{{ inputPort.name }}_list = B_.spike_input_{{ inputPort.name }}_.get_list(lag);
__spike_input_{{ inputPort.name }}_list.clear();

{%- endif %}
{%- endif %}

{%- endfor %}
{%- endif %}

Expand Down Expand Up @@ -1235,19 +1325,40 @@ void {{ neuronName }}::handle(nest::SpikeEvent &e)
{
{%- if spike_in_port.get_parameters() %}
{%- for attribute in spike_in_port.get_parameters() %}


{%- if spike_in_port_name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
B_.spike_input_{{ spike_in_port_name }}__DOT__{{ attribute.name }}_.add_value(
e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin() ),
e.get_weight() * e.get_multiplicity() );
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
B_.spike_input_{{ spike_in_port_name }}__DOT__{{ attribute.name }}_.append_value(
e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin() ),
e.get_weight() * e.get_multiplicity() );
{%- endif %}


{%- endfor %}
{%- endif %}

// add an unweighted spike to the general "train of delta pulses" input buffer
//std::cout << "\tappending spike at offset = " << e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin()) << "; buffer size = " << B_.spike_input_{{ spike_in_port_name }}_.size() << "; nest::kernel().connection_manager.get_min_delay() = " << nest::kernel().connection_manager.get_min_delay() << "\n";
//std::cout << "\tappending spike at offset = " << e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin()) << " to B_.spike_input_{{ spike_in_port_name }}_, before length = " << B_.spike_input_{{ spike_in_port_name }}_.get_list(e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin())).size() << "\n";
// B_.spike_input_{{ spike_in_port_name }}_.resize();

{%- if spike_in_port_name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
B_.spike_input_{{ spike_in_port_name }}_.add_value(
e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin() ),
e.get_multiplicity() );
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
B_.spike_input_{{ spike_in_port_name }}_.append_value(
e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin() ),
e.get_multiplicity() );
{%- endif %}

//std::cout << "\tappending spike to B_.spike_input_{{ spike_in_port_name }}_, after length = " << B_.spike_input_{{ spike_in_port_name }}_.get_list(e.get_rel_delivery_steps( nest::kernel().simulation_manager.get_slice_origin())).size() << "\n";

Expand Down Expand Up @@ -1299,6 +1410,39 @@ void
const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function
auto get_t = [origin, lag](){ return nest::Time( nest::Time::step( origin.get_steps() + lag + 1) ).get_ms(); };

{%- if inputPortSymbol.name in linear_time_invariant_spiking_input_ports %}
{#- linear, time-invariant input port: all spike events for a specific buffer slot can be added together into a single number #}
/**
* Grab the actual spike event data from the buffers (for the current timepoint ``origin + lag``)
**/
{%- if inputPortSymbol.has_vector_parameter() %}
{%- set size = utils.get_numeric_vector_size(inputPortSymbol) %}
{%- for i in range(size) %}
{%- if inputPort.get_parameters() %}
{%- for parameter in inputPort.get_parameters() %}
const double __spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }} = B_.spike_input_{{ inputPort.name }}_VEC_IDX_{{ i }}__DOT__{{ parameter.get_name() }}_grid_sum_;
{%- endfor %}
{%- endif %}
{%- endfor %}
{%- else %}
{%- for parameter in inputPort.get_parameters() %}
const double __spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }} = B_.spike_input_{{ inputPort.name }}__DOT__{{ parameter.get_name() }}_grid_sum_;
{%- endfor %}
{%- endif %}
const double __spike_input_{{ inputPort.name }} = B_.spike_input_{{ inputPort.name }}_grid_sum_;


/**
* Begin NESTML generated code for the onReceive() block statements
**/

{{ printer._expression_printer._simple_expression_printer._variable_printer.set_cpp_variable_suffix(" ") }} {# prevent printing origin #}
{% filter indent(4, True) -%}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endfilter %}
{{ printer._expression_printer._simple_expression_printer._variable_printer.set_cpp_variable_suffix("") }}
{%- else %}
{#- generic input port: use lists of spike events for each buffer slot #}
// grab the lists of spike events from the buffers for the current timepoint
{%- if inputPortSymbol.has_vector_parameter() %}
{%- set size = utils.get_numeric_vector_size(inputPortSymbol) %}
Expand Down Expand Up @@ -1398,6 +1542,7 @@ std::cout << "\tclearing spike buffers....\n";
std::cout << "\tafter clearing " << __spike_input_{{ inputPort.name }}_list.size() << " spikes\n";
std::cout << "\tafter clearing (orig list) " << B_.spike_input_{{ inputPort.name }}_.get_list(lag).size() << " spikes\n";
*/
{%- endif %}
}

{% endfor %}
Expand Down
Loading

0 comments on commit 4320b0c

Please sign in to comment.