Skip to content

Commit

Permalink
Merge pull request #3616 from samuelgarcia/for_sigui
Browse files Browse the repository at this point in the history
Changes for spikeinterface-gui
  • Loading branch information
alejoe91 authored Jan 17, 2025
2 parents 692166b + 7e49ab9 commit bd22413
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 101 deletions.
70 changes: 57 additions & 13 deletions src/spikeinterface/widgets/sorting_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np

import warnings

from .base import BaseWidget, to_attr

from .amplitudes import AmplitudesWidget
Expand All @@ -14,6 +16,9 @@
from ..core import SortingAnalyzer


_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude", "snr", "rp_violation"]


class SortingSummaryWidget(BaseWidget):
"""
Plots spike sorting summary.
Expand Down Expand Up @@ -42,14 +47,24 @@ class SortingSummaryWidget(BaseWidget):
label_choices : list or None, default: None
List of labels to be added to the curation table
(sortingview backend)
unit_table_properties : list or None, default: None
displayed_unit_properties : list or None, default: None
List of properties to be added to the unit table.
These may be drawn from the sorting extractor, and, if available,
the quality_metrics and template_metrics extensions of the SortingAnalyzer.
the quality_metrics/template_metrics/unit_locations extensions of the SortingAnalyzer.
See all properties available with sorting.get_property_keys(), and, if available,
analyzer.get_extension("quality_metrics").get_data().columns and
analyzer.get_extension("template_metrics").get_data().columns.
(sortingview backend)
extra_unit_properties : dict or None, default: None
A dict with extra units properties to display.
curation_dict : dict or None, default: None
When curation is True, optionaly the viewer can get a previous 'curation_dict'
to continue/check previous curations on this analyzer.
In this case label_definitions must be None beacuse it is already included in the curation_dict.
(spikeinterface_gui backend)
label_definitions : dict or None, default: None
When curation is True, optionaly the user can provide a label_definitions dict.
This replaces the label_choices in the curation_format.
(spikeinterface_gui backend)
"""

def __init__(
Expand All @@ -60,11 +75,24 @@ def __init__(
max_amplitudes_per_unit=None,
min_similarity_for_correlograms=0.2,
curation=False,
unit_table_properties=None,
displayed_unit_properties=None,
extra_unit_properties=None,
label_choices=None,
curation_dict=None,
label_definitions=None,
backend=None,
unit_table_properties=None,
**backend_kwargs,
):

if unit_table_properties is not None:
warnings.warn(
"plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead",
category=DeprecationWarning,
stacklevel=2,
)
displayed_unit_properties = unit_table_properties

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)
self.check_extensions(
sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"]
Expand All @@ -74,18 +102,29 @@ def __init__(
if unit_ids is None:
unit_ids = sorting.get_unit_ids()

plot_data = dict(
if curation_dict is not None and label_definitions is not None:
raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both")

if displayed_unit_properties is None:
displayed_unit_properties = list(_default_displayed_unit_properties)
if extra_unit_properties is not None:
displayed_unit_properties += list(extra_unit_properties.keys())

data_plot = dict(
sorting_analyzer=sorting_analyzer,
unit_ids=unit_ids,
sparsity=sparsity,
min_similarity_for_correlograms=min_similarity_for_correlograms,
unit_table_properties=unit_table_properties,
displayed_unit_properties=displayed_unit_properties,
extra_unit_properties=extra_unit_properties,
curation=curation,
label_choices=label_choices,
max_amplitudes_per_unit=max_amplitudes_per_unit,
curation_dict=curation_dict,
label_definitions=label_definitions,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs)

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
Expand Down Expand Up @@ -156,7 +195,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

# unit ids
v_units_table = generate_unit_table_view(
dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores
dp.sorting_analyzer, dp.displayed_unit_properties, similarity_scores=similarity_scores
)

if dp.curation:
Expand Down Expand Up @@ -190,9 +229,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
def plot_spikeinterface_gui(self, data_plot, **backend_kwargs):
sorting_analyzer = data_plot["sorting_analyzer"]

import spikeinterface_gui
from spikeinterface_gui import run_mainwindow

app = spikeinterface_gui.mkQApp()
win = spikeinterface_gui.MainWindow(sorting_analyzer, curation=data_plot["curation"])
win.show()
app.exec_()
run_mainwindow(
sorting_analyzer,
with_traces=True,
curation=data_plot["curation"],
curation_dict=data_plot["curation_dict"],
label_definitions=data_plot["label_definitions"],
extra_unit_properties=data_plot["extra_unit_properties"],
displayed_unit_properties=data_plot["displayed_unit_properties"],
)
6 changes: 3 additions & 3 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,9 @@ def test_plot_motion_info(self):
# mytest.test_plot_unit_presence()
# mytest.test_plot_peak_activity()
# mytest.test_plot_multicomparison()
# mytest.test_plot_sorting_summary()
mytest.test_plot_sorting_summary()
# mytest.test_plot_motion()
mytest.test_plot_motion_info()
plt.show()
# mytest.test_plot_motion_info()
# plt.show()

# TestWidgets.tearDownClass()
90 changes: 90 additions & 0 deletions src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,93 @@ def array_to_image(
output_image = np.frombuffer(image.tobytes(), dtype=np.uint8).reshape(output_image.shape)

return output_image


def make_units_table_from_sorting(sorting, units_table=None):
"""
Make a DataFrame from sorting properties.
Only for properties with ndim=1
Parameters
----------
sorting : Sorting
The Sorting object
units_table : None | pd.DataFrame
Optionally a existing dataframe.
Returns
-------
units_table : pd.DataFrame
Table containing all columns.
"""

if units_table is None:
import pandas as pd

units_table = pd.DataFrame(index=sorting.unit_ids)

for col in sorting.get_property_keys():
values = sorting.get_property(col)
if values.dtype.kind in "iuUSfb" and values.ndim == 1:
units_table.loc[:, col] = values

return units_table


def make_units_table_from_analyzer(
analyzer,
extra_properties=None,
):
"""
Make a DataFrame by aggregating :
* quality metrics
* template metrics
* unit_position
* sorting properties
* extra columns
This used in sortingview and spikeinterface-gui to display the units table in a flexible way.
Parameters
----------
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer object
extra_properties : None | dict
Extra columns given as dict.
Returns
-------
units_table : pd.DataFrame
Table containing all columns.
"""
import pandas as pd

all_df = []

if analyzer.get_extension("unit_locations") is not None:
locs = analyzer.get_extension("unit_locations").get_data()
df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids)
all_df.append(df)

if analyzer.get_extension("quality_metrics") is not None:
df = analyzer.get_extension("quality_metrics").get_data()
all_df.append(df)

if analyzer.get_extension("template_metrics") is not None:
all_df = analyzer.get_extension("template_metrics").get_data()
all_df.append(df)

if len(all_df) > 0:
units_table = pd.concat(all_df, axis=1)
else:
units_table = pd.DataFrame(index=analyzer.unit_ids)

make_units_table_from_sorting(analyzer.sorting, units_table=units_table)

if extra_properties is not None:
for col, values in extra_properties.items():
# the ndim = 1 is important because we need column only for the display in gui.
if values.dtype.kind in "iuUSfb" and values.ndim == 1:
units_table.loc[:, col] = values

return units_table
124 changes: 39 additions & 85 deletions src/spikeinterface/widgets/utils_sortingview.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from warnings import warn

import numpy as np

from ..core import SortingAnalyzer, BaseSorting
from ..core.core_tools import check_json
from warnings import warn
from .utils import make_units_table_from_sorting, make_units_table_from_analyzer


def make_serializable(*args):
Expand Down Expand Up @@ -50,105 +52,57 @@ def handle_display_and_url(widget, view, **backend_kwargs):
def generate_unit_table_view(
sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting,
unit_properties: list[str] | None = None,
similarity_scores: npndarray | None = None,
similarity_scores: np.ndarray | None = None,
):
import sortingview.views as vv

if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer):
analyzer = sorting_or_sorting_analyzer
units_tables = make_units_table_from_analyzer(analyzer)
sorting = analyzer.sorting
else:
sorting = sorting_or_sorting_analyzer
analyzer = None

# Find available unit properties from all sources
sorting_props = list(sorting.get_property_keys())
if analyzer is not None:
if analyzer.get_extension("quality_metrics") is not None:
qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns)
qm_data = analyzer.get_extension("quality_metrics").get_data()
else:
qm_props = []
if analyzer.get_extension("template_metrics") is not None:
tm_props = list(analyzer.get_extension("template_metrics").get_data().columns)
tm_data = analyzer.get_extension("template_metrics").get_data()
else:
tm_props = []
# Check for any overlaps and warn user if any
all_props = sorting_props + qm_props + tm_props
else:
all_props = sorting_props
qm_props = []
tm_props = []
qm_data = None
tm_data = None

overlap_props = [prop for prop in all_props if all_props.count(prop) > 1]
if len(overlap_props) > 0:
warn(
f"Warning: Overlapping properties found in sorting, quality_metrics, and template_metrics: {overlap_props}"
)

# Get unit properties
units_tables = make_units_table_from_sorting(sorting)
# analyzer = None

if unit_properties is None:
ut_columns = []
ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids]
else:
# keep only selected columns
unit_properties = np.array(unit_properties)
keep = np.isin(unit_properties, units_tables.columns)
unit_properties = unit_properties[keep]
units_tables = units_tables.loc[:, unit_properties]

dtype_convertor = {"i": "int", "u": "int", "f": "float", "U": "str", "S": "str", "b": "bool"}

ut_columns = []
ut_rows = []
values = {}
valid_unit_properties = []

# Create columns for each property
for prop_name in unit_properties:

# Get property values from correct location
if prop_name in sorting_props:
property_values = sorting.get_property(prop_name)
elif prop_name in qm_props:
property_values = qm_data[prop_name].to_numpy()
elif prop_name in tm_props:
property_values = tm_data[prop_name].to_numpy()
else:
warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics")
for col in unit_properties:
if col not in units_tables.columns:
continue
values = units_tables[col].to_numpy()
if values.dtype.kind in dtype_convertor:
txt_dtype = dtype_convertor[values.dtype.kind]
ut_columns.append(vv.UnitsTableColumn(key=col, label=col, dtype=txt_dtype))

# make dtype available
val0 = np.array(property_values[0])
if val0.dtype.kind in ("i", "u"):
dtype = "int"
elif val0.dtype.kind in ("U", "S"):
dtype = "str"
elif val0.dtype.kind == "f":
dtype = "float"
elif val0.dtype.kind == "b":
dtype = "bool"
else:
warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping")
continue
ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype))
valid_unit_properties.append(prop_name)

# Create rows for each unit
for ui, unit in enumerate(sorting.unit_ids):
for prop_name in valid_unit_properties:

# Get property values from correct location
if prop_name in sorting_props:
property_values = sorting.get_property(prop_name)
elif prop_name in qm_props:
property_values = qm_data[prop_name].to_numpy()
elif prop_name in tm_props:
property_values = tm_data[prop_name].to_numpy()

# Check for NaN values and round floats
val0 = np.array(property_values[0])
if val0.dtype.kind == "f":
if np.isnan(property_values[ui]):
continue
property_values[ui] = np.format_float_positional(property_values[ui], precision=4, fractional=False)
values[prop_name] = property_values[ui]
ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values)))
ut_rows = []
for unit_index, unit_id in enumerate(sorting.unit_ids):
row_values = {}
for col in unit_properties:
if col not in units_tables.columns:
continue
values = units_tables[col].to_numpy()
if values.dtype.kind in dtype_convertor:
value = values[unit_index]
if values.dtype.kind == "f":
# Check for NaN values and round floats
if np.isnan(values[unit_index]):
continue
value = np.format_float_positional(value, precision=4, fractional=False)
row_values[col] = value
ut_rows.append(vv.UnitsTableRow(unit_id=unit_id, values=check_json(row_values)))

v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores)

return v_units_table

0 comments on commit bd22413

Please sign in to comment.