Skip to content

Commit

Permalink
Merge pull request #3642 from alejoe91/fix-node-pipeline
Browse files Browse the repository at this point in the history
Fix node pipeline when multiple retrievers
  • Loading branch information
alejoe91 authored Jan 31, 2025
2 parents 67b2135 + 988b2a0 commit f5a50c8
Showing 1 changed file with 59 additions and 22 deletions.
81 changes: 59 additions & 22 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(
self._dtype = spike_peak_dtype

self.include_spikes_in_margin = include_spikes_in_margin
if include_spikes_in_margin is not None:
if include_spikes_in_margin:
self._dtype = spike_peak_dtype + [("in_margin", "bool")]

self.peaks = sorting_to_peaks(sorting, extremum_channel_inds, self._dtype)
Expand Down Expand Up @@ -228,12 +228,6 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
# if self.include_spikes_in_margin:
# i0, i1 = np.searchsorted(
# peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
# )
# else:
# i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
i0, i1 = peak_slice

local_peaks = peaks_in_segment[i0:i1]
Expand Down Expand Up @@ -435,21 +429,59 @@ def compute(self, traces, peaks):
return sparse_wfs


def find_parent_of_type(list_of_parents, parent_type):
    """
    Find a single parent of a given type(s) in a list of parents.

    If multiple parents of the given type are found, the first parent is returned.

    Parameters
    ----------
    list_of_parents : list of PipelineNode or None
        List of parents to search through.
    parent_type : type | tuple of types
        The type (or tuple of types, as accepted by ``isinstance``) of parent
        to search for.

    Returns
    -------
    parent : PipelineNode or None
        The first parent of the given type. Returns None if no parent of the
        given type is found, or if `list_of_parents` is None.
    """
    if list_of_parents is None:
        return None

    # Return the first match; short-circuits instead of building the full
    # list of matching parents.
    for parent in list_of_parents:
        if isinstance(parent, parent_type):
            return parent
    return None


def find_parents_of_type(list_of_parents, parent_type):
    """
    Find all parents of a given type(s) in a list of parents.

    Parameters
    ----------
    list_of_parents : list of PipelineNode or None
        List of parents to search through.
    parent_type : type | tuple of types
        The type(s) of parents to search for, as accepted by ``isinstance``.

    Returns
    -------
    parents : list of PipelineNode
        List of parents of the given type(s), in their original order.
        Returns an empty list if no parents of the given type(s) are found,
        or if `list_of_parents` is None.
    """
    if list_of_parents is None:
        return []

    # Keep the original order of `list_of_parents`; callers may rely on it.
    return [parent for parent in list_of_parents if isinstance(parent, parent_type)]


def check_graph(nodes):
Expand All @@ -471,7 +503,7 @@ def check_graph(nodes):
assert parent in nodes, f"Node {node} has parent {parent} that was not passed in nodes"
assert (
nodes.index(parent) < i
), f"Node are ordered incorrectly: {node} before {parent} in the pipeline definition."
), f"Nodes are ordered incorrectly: {node} before {parent} in the pipeline definition."

return nodes

Expand Down Expand Up @@ -607,12 +639,16 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

recording_segment = recording._recording_segments[segment_index]
node0 = nodes[0]

if isinstance(node0, (SpikeRetriever, PeakRetriever)):
# in this case PeakSource could have no peaks and so no need to load traces just skip
peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
load_trace_and_compute = i0 < i1
retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
# get peak slices once for all retrievers
peak_slice_by_retriever = {}
for retriever in retrievers:
peak_slice = i0, i1 = retriever.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
peak_slice_by_retriever[retriever] = peak_slice

if len(peak_slice_by_retriever) > 0:
# in this case the retrievers could have no peaks, so we test if any spikes are in the chunk
load_trace_and_compute = any(i0 < i1 for i0, i1 in peak_slice_by_retriever.values())
else:
# PeakDetector always need traces
load_trace_and_compute = True
Expand Down Expand Up @@ -646,7 +682,8 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
elif isinstance(node, (PeakRetriever, SpikeRetriever)):
peak_slice = peak_slice_by_retriever[node]
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
Expand Down

0 comments on commit f5a50c8

Please sign in to comment.