diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index ca6c0db924..a0e4652d57 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -12,3 +12,11 @@ "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, } + +try: + # Kilosort licence (GPL 3) is forcing us to make and use an external package + from spikeinterface_kilosort_components import KiloSortMatching + + matching_methods["kilosort-matching"] = KiloSortMatching +except ImportError: + pass diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 5c02646497..3186d5ba07 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -632,6 +632,5 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): return time_bin_centers_s, time_bin_edges_s - def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index 807b8e6c9e..616c4fcbf2 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -12,6 +12,7 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset from spikeinterface.core import generate_ground_truth_recording + def make_fake_motion(rec): # make a fake motion object @@ -25,7 +26,7 @@ def make_fake_motion(rec): seg_time_bins = np.arange(0.5, duration - 0.49, 0.5) seg_disp = np.zeros((seg_time_bins.size, spatial_bins.size)) seg_disp[:, :] = np.linspace(-30, 30, seg_time_bins.size)[:, None] - + temporal_bins.append(seg_time_bins) displacement.append(seg_disp) @@ -204,7 +205,6 @@ def test_InterpolateMotionRecording(): seed=2205, ) - motion = make_fake_motion(rec) rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate") diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index 7cd899a3bb..72aabc07de 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -45,9 +45,39 @@ def test_find_spikes_from_templates(method, sorting_analyzer): "templates": templates, } method_kwargs = {} - if method in ("naive", "tdc-peeler", "circus", "tdc-peeler2"): + if method in ( + "naive", + "tdc-peeler", + "circus", + ): method_kwargs["noise_levels"] = noise_levels + if method == "kilosort-matching": + from spikeinterface.sortingcomponents.peak_selection import select_peaks + from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + + peaks = detect_peaks(sorting_analyzer.recording, method="locally_exclusive", skip_after_n_peaks=5000) + few_wfs = extract_waveform_at_max_channel(sorting_analyzer.recording, peaks, ms_before=1, ms_after=2) + + wfs = few_wfs[:, :, 0] + import numpy as np + + n_components = 5 + from sklearn.cluster import KMeans + + wfs /= np.linalg.norm(wfs, axis=1)[:, None] + model = KMeans(n_clusters=n_components, n_init=10).fit(wfs) + temporal_components = model.cluster_centers_ + temporal_components = temporal_components / np.linalg.norm(temporal_components[:, None]) + temporal_components = temporal_components.astype(np.float32) + from sklearn.decomposition import TruncatedSVD + + model = TruncatedSVD(n_components=n_components).fit(wfs) + spatial_components = model.components_.astype(np.float32) + method_kwargs["spatial_components"] = spatial_components + method_kwargs["temporal_components"] = temporal_components + # method_kwargs["wobble"] = { # "templates": waveform_extractor.get_all_templates(), # "nbefore": waveform_extractor.nbefore, @@ -91,5 +121,7 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # method = "tdc-peeler" # method = "circus" # method = "circus-omp-svd" - method = "wobble" + # method = "wobble" + method = "kilosort-matching" + test_find_spikes_from_templates(method, sorting_analyzer)