Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed GID assignment to transmitters involved in multiple connectivity sets #15

Merged
merged 10 commits into from
Sep 16, 2024
103 changes: 64 additions & 39 deletions bsb_neuron/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def create_neurons(self, simulation):
if (len(ps)) != 0:
self._create_population(simdata, cell_model, ps, offset)
offset += len(ps)
else:
simdata.populations[cell_model] = NeuronPopulation(cell_model, [])

def create_connections(self, simulation):
simdata = self.simdata[simulation]
Expand Down Expand Up @@ -170,54 +172,77 @@ def _allocate_transmitters(self, simulation):
simdata.transmap = self._map_transceivers(simulation, simdata)

def _map_transceivers(self, simulation, simdata):
blocks = []
offset = 0
transmap = {}

for cm, cs in simulation.get_connectivity_sets().items():
# For each connectivity set, determine how many unique transmitters they will place.
pre, _ = cs.load_connections().as_globals().all()
all_cm_transmitters = np.unique(pre[:, :2], axis=0)
# Now look up which transmitters are on our chunks
pre_t, _ = cs.load_connections().from_(simdata.chunks).as_globals().all()
our_cm_transmitters = np.unique(pre_t[:, :2], axis=0)
# Look up the local ids of those transmitters
pre_lc, _ = cs.load_connections().from_(simdata.chunks).all()
local_cm_transmitters = np.unique(pre_lc[:, :2], axis=0)

# Find the common indexes between all the transmitters, and the
# transmitters on our chunk.
dtype = ", ".join([str(all_cm_transmitters.dtype)] * 2)
_, _, idx_tm = np.intersect1d(
our_cm_transmitters.view(dtype),
all_cm_transmitters.view(dtype),
assume_unique=True,
return_indices=True,
)
pre_types = set(cs.pre_type for cs in simulation.get_connectivity_sets().values())
for pre_type in sorted(pre_types, key=lambda pre_type: pre_type.name):
data = []
for cm, cs in simulation.get_connectivity_sets().items():
if cs.pre_type != pre_type:
continue
pre, _ = cs.load_connections().as_globals().all()
data.append(pre[:, :2])

data = self.better_concat(data)
# Save all transmitters of the same pre_type across connectivity sets
all_cm_transmitters = np.unique(data, axis=0)
for cm, cs in simulation.get_connectivity_sets().items():
if cs.pre_type != pre_type:
continue

# Now look up which transmitters are on our chunks
pre_t, _ = cs.load_connections().from_(simdata.chunks).as_globals().all()
our_cm_transmitters = np.unique(pre_t[:, :2], axis=0)
# Look up the local ids of those transmitters
pre_lc, _ = cs.load_connections().from_(simdata.chunks).all()
local_cm_transmitters = np.unique(pre_lc[:, :2], axis=0)

# Find the common indexes between all the transmitters, and the
# transmitters on our chunk.
dtype = ", ".join([str(all_cm_transmitters.dtype)] * 2)
_, _, idx_tm = np.intersect1d(
our_cm_transmitters.view(dtype),
all_cm_transmitters.view(dtype),
assume_unique=True,
return_indices=True,
)

# Look up which transmitters have receivers on our chunks
pre_gc, _ = cs.load_connections().incoming().to(simdata.chunks).all()
local_cm_receivers = np.unique(pre_gc[:, :2], axis=0)
_, _, idx_rcv = np.intersect1d(
local_cm_receivers.view(dtype),
all_cm_transmitters.view(dtype),
assume_unique=True,
return_indices=True,
)
# Look up which transmitters have receivers on our chunks
pre_gc, _ = cs.load_connections().incoming().to(simdata.chunks).all()
local_cm_receivers = np.unique(pre_gc[:, :2], axis=0)
_, _, idx_rcv = np.intersect1d(
local_cm_receivers.view(dtype),
all_cm_transmitters.view(dtype),
assume_unique=True,
return_indices=True,
)

# Store a map of the local chunk transmitters to their GIDs
transmap[cm] = {
"transmitters": dict(
zip(map(tuple, local_cm_transmitters), map(int, idx_tm + offset))
),
"receivers": dict(
zip(map(tuple, local_cm_receivers), map(int, idx_rcv + offset))
),
}

# Store a map of the local chunk transmitters to their GIDs
transmap[cm] = {
"transmitters": dict(
zip(map(tuple, local_cm_transmitters), map(int, idx_tm + offset))
),
"receivers": dict(
zip(map(tuple, local_cm_receivers), map(int, idx_rcv + offset))
),
}
# Offset by the total amount of transmitter GIDs used by this ConnSet.
offset += len(all_cm_transmitters)
return transmap

def better_concat(self, items):
if not items:
raise RuntimeError("Can not concat 0 items")
l = sum(len(x) for x in items)
r = np.empty((l, items[0].shape[1]), dtype=items[0].dtype)
ptr = 0
for x in items:
r[ptr : ptr + len(x)] = x
ptr += len(x)
return r

def _create_population(self, simdata, cell_model, ps, offset):
data = []
for var in (
Expand Down
Loading