Skip to content

Commit

Permalink
Change model registration for NEST target (#1002)
Browse files Browse the repository at this point in the history
* Change the way models are registered in NEST
* adjust compartmental_model_test.py to take into account the new initialization paradigm introduced in NEST 3.6

Co-authored-by: WillemWybo <willem.a.m.wybo@gmail.com>
  • Loading branch information
pnbabu and WillemWybo authored Apr 4, 2024
1 parent 4e02bf3 commit ca6e490
Show file tree
Hide file tree
Showing 52 changed files with 4,766 additions and 1,171 deletions.
237 changes: 201 additions & 36 deletions doc/tutorials/active_dendrite/nestml_active_dendrite_tutorial.ipynb

Large diffs are not rendered by default.

145 changes: 116 additions & 29 deletions doc/tutorials/izhikevich/nestml_izhikevich_tutorial.ipynb

Large diffs are not rendered by default.

332 changes: 243 additions & 89 deletions doc/tutorials/ornstein_uhlenbeck_noise/nestml_ou_noise_tutorial.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

1,105 changes: 963 additions & 142 deletions doc/tutorials/stdp_dopa_synapse/stdp_dopa_synapse.ipynb

Large diffs are not rendered by default.

1,150 changes: 1,061 additions & 89 deletions doc/tutorials/stdp_windows/stdp_windows.ipynb

Large diffs are not rendered by default.

820 changes: 733 additions & 87 deletions doc/tutorials/triplet_stdp_synapse/triplet_stdp_synapse.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions pynestml/codegeneration/nest_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __add_library_to_sli(lib_path):
lib_path = os.path.abspath(lib_path)

system = platform.system()
lib_key = ""

if system == "Linux":
lib_key = "LD_LIBRARY_PATH"
Expand All @@ -51,8 +50,7 @@ def __add_library_to_sli(lib_path):
if lib_key in os.environ:
current = os.environ[lib_key].split(os.pathsep)
if lib_path not in current:
current.append(lib_path)
os.environ[lib_key] += os.pathsep.join(current)
os.environ[lib_key] = os.pathsep.join([os.environ[lib_key], lib_path])
else:
os.environ[lib_key] = lib_path

Expand Down
3 changes: 3 additions & 0 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def _get_module_namespace(self, neurons: List[ASTNeuron], synapses: List[ASTSyna
"synapses": synapses,
"moduleName": FrontendConfiguration.get_module_name(),
"now": datetime.datetime.utcnow()}
# NEST version
if self.option_exists("nest_version"):
namespace["nest_version"] = self.get_option("nest_version")
return namespace

def analyse_transform_neurons(self, neurons: List[ASTNeuron]) -> None:
Expand Down
66 changes: 29 additions & 37 deletions pynestml/codegeneration/nest_code_generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List, Optional

import re
Expand Down Expand Up @@ -63,12 +62,12 @@ def print_symbol_origin(cls, variable_symbol: VariableSymbol, variable: ASTVaria
def generate_code_for(cls,
nestml_neuron_model: str,
nestml_synapse_model: Optional[str] = None,
module_name: Optional[str] = None,
target_path: str = "target",
post_ports: Optional[List[str]] = None,
mod_ports: Optional[List[str]] = None,
uniq_id: Optional[str] = None,
logging_level: str = "WARNING"):
"""Generate code for a given neuron and synapse model, passed as a string.
NEST cannot yet unload or reload modules. This function implements a workaround using UUIDs to generate unique names.
The neuron and synapse models can be passed directly as strings in NESTML syntax, or as filenames, in which case the NESTML model is loaded from the given filename.
Returns
Expand All @@ -81,59 +80,52 @@ def generate_code_for(cls,
# generate temporary install directory
install_path = tempfile.mkdtemp(prefix="nestml_target_")

# generate unique ID
if uniq_id is None:
uniq_id = str(uuid.uuid4().hex)

# read neuron model from file?
if not "\n" in nestml_neuron_model and ".nestml" in nestml_neuron_model:
if "\n" not in nestml_neuron_model and ".nestml" in nestml_neuron_model:
with open(nestml_neuron_model, "r") as nestml_model_file:
nestml_neuron_model = nestml_model_file.read()

# update neuron model name inside the file
neuron_model_name_orig = re.findall(r"neuron\ [^:\s]*:", nestml_neuron_model)[0][7:-1]
neuron_model_name_uniq = neuron_model_name_orig + uniq_id
nestml_model = re.sub(r"neuron\ [^:\s]*:",
"neuron " + neuron_model_name_uniq + ":", nestml_neuron_model)
neuron_uniq_fn = neuron_model_name_uniq + ".nestml"
with open(neuron_uniq_fn, "w") as f:
print(nestml_model, file=f)
neuron_model_name = re.findall(r"neuron [^:\s]*:", nestml_neuron_model)[0][7:-1]
neuron_fn = neuron_model_name + ".nestml"
with open(neuron_fn, "w") as f:
print(nestml_neuron_model, file=f)

input_fns = [neuron_fn]
codegen_opts = {"neuron_parent_class": "StructuralPlasticityNode",
"neuron_parent_class_include": "structural_plasticity_node.h"}
mangled_neuron_name = neuron_model_name + "_nestml"

if nestml_synapse_model:
# read synapse model from file?
if not "\n" in nestml_synapse_model and ".nestml" in nestml_synapse_model:
if "\n" not in nestml_synapse_model and ".nestml" in nestml_synapse_model:
with open(nestml_synapse_model, "r") as nestml_model_file:
nestml_synapse_model = nestml_model_file.read()

# update synapse model name inside the file
synapse_model_name_orig = re.findall(r"synapse\ [^:\s]*:", nestml_synapse_model)[0][8:-1]
synapse_model_name_uniq = synapse_model_name_orig + uniq_id
nestml_model = re.sub(r"synapse\ [^:\s]*:",
"synapse " + synapse_model_name_uniq + ":", nestml_synapse_model)
synapse_uniq_fn = synapse_model_name_uniq + ".nestml"
with open(synapse_uniq_fn, "w") as f:
print(nestml_model, file=f)

# generate the code for neuron and optionally synapse
module_name = "nestml_" + uniq_id + "_module"
input_fns = [neuron_uniq_fn]
codegen_opts = {"neuron_parent_class": "StructuralPlasticityNode",
"neuron_parent_class_include": "structural_plasticity_node.h"}

mangled_neuron_name = neuron_model_name_uniq + "_nestml"
if nestml_synapse_model:
input_fns += [synapse_uniq_fn]
codegen_opts["neuron_synapse_pairs"] = [{"neuron": neuron_model_name_uniq,
"synapse": synapse_model_name_uniq,
synapse_model_name = re.findall(r"synapse [^:\s]*:", nestml_synapse_model)[0][8:-1]
synapse_fn = synapse_model_name + ".nestml"
with open(synapse_fn, "w") as f:
print(nestml_synapse_model, file=f)

input_fns += [synapse_fn]
codegen_opts["neuron_synapse_pairs"] = [{"neuron": neuron_model_name,
"synapse": synapse_model_name,
"post_ports": post_ports,
"vt_ports": mod_ports}]
mangled_neuron_name = neuron_model_name_uniq + "_nestml__with_" + synapse_model_name_uniq + "_nestml"
mangled_synapse_name = synapse_model_name_uniq + "_nestml__with_" + neuron_model_name_uniq + "_nestml"
mangled_neuron_name = neuron_model_name + "_nestml__with_" + synapse_model_name + "_nestml"
mangled_synapse_name = synapse_model_name + "_nestml__with_" + neuron_model_name + "_nestml"

if not module_name:
# generate unique ID
uniq_id = str(uuid.uuid4().hex)
module_name = "nestml_" + uniq_id + "_module"

generate_nest_target(input_path=input_fns,
install_path=install_path,
logging_level=logging_level,
module_name=module_name,
target_path=target_path,
suffix="_nestml",
codegen_opts=codegen_opts)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def _get_module_namespace(self, neurons: List[ASTNeuron]) -> Dict:
:return: a context dictionary for rendering templates
"""
namespace = {"neurons": neurons,
"nest_version": self.get_option("nest_version"),
"moduleName": FrontendConfiguration.get_module_name(),
"now": datetime.datetime.utcnow()}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.
// Includes from nestkernel:
#include "exceptions.h"
#include "kernel_manager.h"
#include "nest_impl.h"
#include "universal_data_logger_impl.h"

// Includes from sli:
Expand All @@ -68,6 +69,15 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.

#include "{{neuronName}}.h"

{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %}
void
register_{{ neuronName }}( const std::string& name )
{
nest::register_node_model< {{ neuronName }} >( name );
}
{%- endif %}

// ---------------------------------------------------------------------------
// Recordables map
// ---------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,13 @@ size_t access_counter )

Receives: {% if has_spike_input %}Spike, {% endif %}{% if has_continuous_input %}Current,{% endif %} DataLoggingRequest
*/

// Register the neuron model
{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %}
void register_{{ neuronName }}( const std::string& name );
{%- endif %}

class {{neuronName}} : public nest::{{neuron_parent_class}}
{
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.

namespace nest
{
{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %}
// Register the synapse model
void register_{{ synapseName }}( const std::string& name );
{%- endif %}

namespace {{names_namespace}}
{
Expand Down Expand Up @@ -218,7 +223,6 @@ public:
{%- endif %}
};


template < typename targetidentifierT >
class {{synapseName}} : public Connection< targetidentifierT >
{
Expand Down Expand Up @@ -654,8 +658,13 @@ public:
}

{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %}
{%- if not (nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %}
bool
send( Event& e, const size_t tid, const {{synapseName}}CommonSynapseProperties& cp )
{%- else %}
void
send( Event& e, const size_t tid, const {{synapseName}}CommonSynapseProperties& cp )
{%- endif %}
{%- else %}
void
send( Event& e, const thread tid, const {{synapseName}}CommonSynapseProperties& cp )
Expand Down Expand Up @@ -845,6 +854,11 @@ public:
**/

t_lastspike_ = __t_spike;

{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %}
return true;
{%- endif %}
}

void get_status( DictionaryDatum& d ) const;
Expand All @@ -860,6 +874,15 @@ public:
{%- endif %}
};

{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %}
void
register_{{ synapseName }}( const std::string& name )
{
nest::register_connection_model< {{ synapseName }} >( name );
}
{%- endif %}

{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %}
template < typename targetidentifierT >
constexpr ConnectionModelProperties {{synapseName}}< targetidentifierT >::properties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,120 +19,11 @@
* along with NEST. If not, see <http://www.gnu.org/licenses/>.
*
*/#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
/*
* {{moduleName}}.cpp
*
* This file is part of NEST.
*
* Copyright (C) 2004 The NEST Initiative
*
* NEST is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* NEST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with NEST. If not, see <http://www.gnu.org/licenses/>.
*
* {{now}}
*/

// Includes from nestkernel:
#include "connection_manager_impl.h"
#include "connector_model_impl.h"
#include "dynamicloader.h"
#include "exceptions.h"
#include "genericmodel_impl.h"
#include "kernel_manager.h"
#include "model.h"
#include "model_manager_impl.h"
#include "nestmodule.h"
#include "target_identifier.h"

// Includes from sli:
#include "booldatum.h"
#include "integerdatum.h"
#include "sliexceptions.h"
#include "tokenarray.h"

// include headers with your own stuff
#include "{{moduleName}}.h"

{% for neuron in neurons %}
#include "{{neuron.get_name()}}.h"
{% endfor %}
{% for synapse in synapses %}
#include "{{synapse.get_name()}}.h"
{% endfor %}
// -- Interface to dynamic module loader ---------------------------------------

/*
* There are three scenarios, in which MyModule can be loaded by NEST:
*
* 1) When loading your module with `Install`, the dynamic module loader must
* be able to find your module. You make the module known to the loader by
* defining an instance of your module class in global scope. (LTX_MODULE is
* defined) This instance must have the name
*
* <modulename>_LTX_mod
*
* The dynamicloader can then load modulename and search for symbol "mod" in it.
*
* 2) When you link the library dynamically with NEST during compilation, a new
* object has to be created. In the constructor the DynamicLoaderModule will
* register your module. (LINKED_MODULE is defined)
*
* 3) When you link the library statically with NEST during compilation, the
* registration will take place in the file `static_modules.h`, which is
* generated by cmake.
*/
#if defined(LTX_MODULE) | defined(LINKED_MODULE)
{{moduleName}} {{moduleName}}_LTX_mod;
#endif

// -- DynModule functions ------------------------------------------------------

{{moduleName}}::{{moduleName}}()
{
#ifdef LINKED_MODULE
// register this module at the dynamic loader
// this is needed to allow for linking in this module at compile time
// all registered modules will be initialized by the main app's dynamic loader
nest::DynamicLoaderModule::registerLinkedModule( this );
#endif
}

{{moduleName}}::~{{moduleName}}()
{
}

const std::string
{{moduleName}}::name() const
{
return std::string("{{moduleName}}"); // Return name of the module
}

//-------------------------------------------------------------------------------------
void
{{moduleName}}::init( SLIInterpreter* i )
{
{%- if neurons %}
// register neurons
{%- for neuron in neurons %}
nest::kernel().model_manager.register_node_model<{{neuron.get_name()}}>("{{neuron.get_name()}}");
{%- endfor %}
{%- if nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6") %}
{%- include "common/ModuleClass.jinja2" %}
{%- else %}
{%- include "common/ModuleClassMaster.jinja2" %}
{%- endif %}
{%- if synapses %}

// register synapses
{%- for synapse in synapses %}
nest::register_connection_model< nest::{{synapse.get_name()}} >( "{{synapse.get_name()}}" );
{%- endfor %}
{%- endif %}
} // {{moduleName}}::init()
Loading

0 comments on commit ca6e490

Please sign in to comment.