Skip to content

Commit

Permalink
Merge pull request #3488 from samuelgarcia/kilosort_matching_gpl
Browse files Browse the repository at this point in the history
kilosort-matching in si
  • Loading branch information
alejoe91 authored Feb 3, 2025
2 parents 12a1276 + 8e6efd7 commit d9ad31f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/spikeinterface/sortingcomponents/matching/method_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@
"circus-omp-svd": CircusOMPSVDPeeler,
"wobble": WobbleMatch,
}

try:
# Kilosort licence (GPL 3) is forcing us to make and use an external package
from spikeinterface_kilosort_components import KiloSortMatching

matching_methods["kilosort-matching"] = KiloSortMatching
except ImportError:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,5 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
return time_bin_centers_s, time_bin_edges_s



def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from spikeinterface.sortingcomponents.tests.common import make_dataset
from spikeinterface.core import generate_ground_truth_recording


def make_fake_motion(rec):
# make a fake motion object

Expand All @@ -25,7 +26,7 @@ def make_fake_motion(rec):
seg_time_bins = np.arange(0.5, duration - 0.49, 0.5)
seg_disp = np.zeros((seg_time_bins.size, spatial_bins.size))
seg_disp[:, :] = np.linspace(-30, 30, seg_time_bins.size)[:, None]

temporal_bins.append(seg_time_bins)
displacement.append(seg_disp)

Expand Down Expand Up @@ -204,7 +205,6 @@ def test_InterpolateMotionRecording():
seed=2205,
)


motion = make_fake_motion(rec)

rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,39 @@ def test_find_spikes_from_templates(method, sorting_analyzer):
"templates": templates,
}
method_kwargs = {}
if method in ("naive", "tdc-peeler", "circus", "tdc-peeler2"):
if method in (
"naive",
"tdc-peeler",
"circus",
):
method_kwargs["noise_levels"] = noise_levels

if method == "kilosort-matching":
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

peaks = detect_peaks(sorting_analyzer.recording, method="locally_exclusive", skip_after_n_peaks=5000)
few_wfs = extract_waveform_at_max_channel(sorting_analyzer.recording, peaks, ms_before=1, ms_after=2)

wfs = few_wfs[:, :, 0]
import numpy as np

n_components = 5
from sklearn.cluster import KMeans

wfs /= np.linalg.norm(wfs, axis=1)[:, None]
model = KMeans(n_clusters=n_components, n_init=10).fit(wfs)
temporal_components = model.cluster_centers_
temporal_components = temporal_components / np.linalg.norm(temporal_components[:, None])
temporal_components = temporal_components.astype(np.float32)
from sklearn.decomposition import TruncatedSVD

model = TruncatedSVD(n_components=n_components).fit(wfs)
spatial_components = model.components_.astype(np.float32)
method_kwargs["spatial_components"] = spatial_components
method_kwargs["temporal_components"] = temporal_components

# method_kwargs["wobble"] = {
# "templates": waveform_extractor.get_all_templates(),
# "nbefore": waveform_extractor.nbefore,
Expand Down Expand Up @@ -91,5 +121,7 @@ def test_find_spikes_from_templates(method, sorting_analyzer):
# method = "tdc-peeler"
# method = "circus"
# method = "circus-omp-svd"
method = "wobble"
# method = "wobble"
method = "kilosort-matching"

test_find_spikes_from_templates(method, sorting_analyzer)

0 comments on commit d9ad31f

Please sign in to comment.