From 032181746c308ec1b0014f2d1986dab7eef9b444 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 23 Apr 2024 20:50:58 -0400 Subject: [PATCH 01/73] wip --- examples/stack_viewer.py | 2 +- .../_stack_viewer2/__init__.py | 0 .../_stack_viewer2/_stack_viewer.py | 139 ++++++++++++++++++ 3 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 src/pymmcore_widgets/_stack_viewer2/__init__.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py diff --git a/examples/stack_viewer.py b/examples/stack_viewer.py index ef0a47cfe..03dc1a78a 100644 --- a/examples/stack_viewer.py +++ b/examples/stack_viewer.py @@ -20,7 +20,7 @@ sequence = MDASequence( channels=( - # {"config": "DAPI", "exposure": 10}, + {"config": "DAPI", "exposure": 10}, # {"config": "FITC", "exposure": 1}, {"config": "Cy5", "exposure": 1}, ), diff --git a/src/pymmcore_widgets/_stack_viewer2/__init__.py b/src/pymmcore_widgets/_stack_viewer2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py new file mode 100644 index 000000000..96b309fa0 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import Hashable +from warnings import warn + +import numpy as np +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from skimage.data import camera +from superqt import QLabeledSlider +from superqt.iconify import QIconifyIcon +from vispy import scene + + +def noisy_camera() -> np.ndarray: + img = camera() + img = img + 0.3 * img.std() * np.random.standard_normal(img.shape) + return img + + +class PlayButton(QPushButton): + PLAY_ICON = "fa6-solid:play" + PAUSE_ICON = "fa6-solid:pause" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.PLAY_ICON) + icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On) + super().__init__(icn, text, parent) + self.setCheckable(True) + + +class DimsSlider(QWidget): + def __init__(self, dimension_name: str, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._name = dimension_name + self._play_btn = PlayButton(dimension_name) + self._slider = QLabeledSlider(Qt.Orientation.Horizontal, parent=self) + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._play_btn) + layout.addWidget(self._slider) + + def setMaximum(self, max_val: int) -> None: + self._slider.setMaximum(max_val) + + def setValue(self, val: int) -> None: + self._slider.setValue(val) + + +class DimsSliders(QWidget): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + self._sliders: dict[str, DimsSlider] = {} + + def add_dimension(self, name: str) -> None: + self._sliders[name] = slider = DimsSlider(dimension_name=name, parent=self) + self.layout().addWidget(slider) + + def remove_dimension(self, name: str) -> None: + try: + slider = self._sliders.pop(name) + except KeyError: + warn(f"Dimension {name} not found in DimsSliders", stacklevel=2) + return + self.layout().removeWidget(slider) + slider.deleteLater() + + +class ViewerCanvas(QWidget): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + + self._canvas = scene.SceneCanvas(parent=self, keys="interactive") + self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + + central_wdg: scene.Widget = self._canvas.central_widget + self._view: scene.ViewBox = central_wdg.add_view(camera=self._camera) + + # Mapping of image key to Image visual objects + # tbd... determine what the key should be + # could have an image per channel, + # but may also have multiple images per channel... in the case of tiles, etc... + self._images: dict[Hashable, scene.visuals.Image] = {} + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._canvas.native) + + self.add_image("key", noisy_camera()) + self.reset_view() + + def add_image(self, key: Hashable, data: np.ndarray | None = None) -> None: + self._images[key] = img = scene.visuals.Image( + data, cmap="grays", parent=self._view.scene + ) + img.set_gl_state("additive", depth_test=False) + + def remove_image(self, key: Hashable) -> None: + try: + image = self._images.pop(key) + except KeyError: + warn(f"Image {key} not found in ViewerCanvas", stacklevel=2) + return + image.parent = None + + def reset_view(self) -> None: + self._camera.set_range() + + +class StackViewer(QWidget): + """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events.""" + + def __init__(self, *, parent: QWidget | None = None): + super().__init__(parent=parent) + + self._canvas = ViewerCanvas() + self._info_bar = QLabel("Info") + self._dims_sliders = DimsSliders() + self._dims_sliders.add_dimension("z") + + layout = QVBoxLayout(self) + layout.addWidget(self._canvas, 1) + layout.addWidget(self._info_bar) + layout.addWidget(self._dims_sliders) + + +if __name__ == "__main__": + from qtpy.QtWidgets import QApplication + + app = QApplication([]) + + viewer = StackViewer() + viewer.show() + viewer.resize(600, 600) + + app.exec() From 452c66147526e222f108857e25bf60229eb7d994 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 24 Apr 2024 18:04:19 -0400 Subject: [PATCH 02/73] more progress --- examples/stack_viewer2.py | 34 ++ .../_stack_viewer2/_stack_viewer.py | 290 ++++++++++++++++-- 2 files changed, 295 insertions(+), 29 deletions(-) create mode 100644 examples/stack_viewer2.py diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py new file mode 100644 index 000000000..311939968 --- /dev/null +++ b/examples/stack_viewer2.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from pymmcore_plus import CMMCorePlus, configure_logging +from qtpy import QtWidgets +from useq import MDASequence + +from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer + +configure_logging(stderr_level="WARNING") + +mmcore = CMMCorePlus.instance() +mmcore.loadSystemConfiguration() +mmcore.defineConfig("Channel", "DAPI", "Camera", "Mode", "Artificial Waves") +mmcore.defineConfig("Channel", "FITC", "Camera", "Mode", "Noise") + +sequence = MDASequence( + channels=( + {"config": "DAPI", "exposure": 4}, + {"config": "FITC", "exposure": 10}, + # {"config": "Cy5", "exposure": 20}, + ), + stage_positions=[(0, 0), (1, 1)], + z_plan={"range": 2, "step": 0.4}, + time_plan={"interval": 2, "loops": 5}, + grid_plan={"rows": 2, "columns": 1}, +) + + +qapp = QtWidgets.QApplication([]) +v = StackViewer() +v.show() + +mmcore.run_mda(sequence, output=v.datastore) +qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 96b309fa0..e23459446 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -1,24 +1,55 @@ from __future__ import annotations -from typing import Hashable +from contextlib import suppress +from itertools import cycle +from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping from warnings import warn -import numpy as np +import cmap +import superqt +import useq +from psygnal import Signal +from pymmcore_plus import CMMCorePlus +from pymmcore_plus.mda.handlers import OMEZarrWriter from qtpy.QtCore import Qt from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget -from skimage.data import camera from superqt import QLabeledSlider from superqt.iconify import QIconifyIcon from vispy import scene +if TYPE_CHECKING: + import numpy as np + from PySide6.QtCore import QTimerEvent + from vispy.scene.events import SceneMouseEvent -def noisy_camera() -> np.ndarray: - img = camera() - img = img + 0.3 * img.std() * np.random.standard_normal(img.shape) - return img + +CHANNEL = "c" +COLORMAPS = cycle( + [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] +) + + +def try_cast_colormap(val: Any) -> cmap.Colormap | None: + """Try to cast `val` to a cmap.Colormap instance, return None if it fails.""" + if isinstance(val, cmap.Colormap): + return val + with suppress(Exception): + return cmap.Colormap(val) + return None + + +# FIXME: get rid of this thin subclass +class DataStore(OMEZarrWriter): + frame_ready = Signal(useq.MDAEvent) + + def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> None: + super().frameReady(frame, event, meta) + self.frame_ready.emit(event) class PlayButton(QPushButton): + """Just a styled QPushButton that toggles between play and pause icons.""" + PLAY_ICON = "fa6-solid:play" PAUSE_ICON = "fa6-solid:pause" @@ -29,35 +60,117 @@ def __init__(self, text: str = "", parent: QWidget | None = None) -> None: self.setCheckable(True) +class LockButton(QPushButton): + LOCK_ICON = "fa6-solid:lock-open" + UNLOCK_ICON = "fa6-solid:lock" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.LOCK_ICON) + icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On) + super().__init__(icn, text, parent) + self.setCheckable(True) + self.setMaximumWidth(20) + + class DimsSlider(QWidget): + """A single slider in the DimsSliders widget. + + Provides a play/pause button that toggles animation of the slider value. + Has a QLabeledSlider for the actual value. + Adds a label for the maximum value (e.g. "3 / 10") + """ + + valueChanged = Signal(str, int) + def __init__(self, dimension_name: str, parent: QWidget | None = None) -> None: super().__init__(parent) + self._interval = 1000 // 10 self._name = dimension_name + self._play_btn = PlayButton(dimension_name) + self._play_btn.toggled.connect(self._toggle_animation) + # note, this lock button only prevents the slider from updating programmatically + # using self.setValue, it doesn't prevent the user from changing the value. + self._lock_btn = LockButton() + + self._max_label = QLabel("/ 0") self._slider = QLabeledSlider(Qt.Orientation.Horizontal, parent=self) + self._slider.setMaximum(0) + self._slider.rangeChanged.connect(self._on_range_changed) + self._slider.valueChanged.connect(self._on_value_changed) + self._slider.layout().addWidget(self._max_label) + layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self._play_btn) layout.addWidget(self._slider) + layout.addWidget(self._lock_btn) def setMaximum(self, max_val: int) -> None: self._slider.setMaximum(max_val) def setValue(self, val: int) -> None: + # variant of setValue that always updates the maximum + if val > self._slider.maximum(): + self._slider.setMaximum(val) + if self._lock_btn.isChecked(): + return self._slider.setValue(val) + def set_fps(self, fps: int) -> None: + self._interval = 1000 // fps + + def _toggle_animation(self, checked: bool) -> None: + if checked: + self._timer_id = self.startTimer(self._interval) + else: + self.killTimer(self._timer_id) + + def timerEvent(self, event: QTimerEvent) -> None: + val = self._slider.value() + next_val = (val + 1) % (self._slider.maximum() + 1) + self._slider.setValue(next_val) + + def _on_range_changed(self, min: int, max: int) -> None: + self._max_label.setText("/ " + str(max)) + + def _on_value_changed(self, value: int) -> None: + self.valueChanged.emit(self._name, value) + class DimsSliders(QWidget): + """A Collection of DimsSlider widgets for each dimension in the data. + + Maintains the global current index and emits a signal when it changes. + """ + + indexChanged = Signal(dict) + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) + self._sliders: dict[str, DimsSlider] = {} + self._current_index: dict[str, int] = {} + self._invisible_dims: set[str] = set() + self._updating = False + layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(layout) - self._sliders: dict[str, DimsSlider] = {} + layout.setSpacing(0) def add_dimension(self, name: str) -> None: self._sliders[name] = slider = DimsSlider(dimension_name=name, parent=self) + self._current_index[name] = 0 + slider.valueChanged.connect(self._on_value_changed) self.layout().addWidget(slider) + slider.setVisible(name not in self._invisible_dims) + + def set_dimension_visible(self, name: str, visible: bool) -> None: + if visible: + self._invisible_dims.discard(name) + else: + self._invisible_dims.add(name) + if name in self._sliders: + self._sliders[name].setVisible(visible) def remove_dimension(self, name: str) -> None: try: @@ -68,12 +181,50 @@ def remove_dimension(self, name: str) -> None: self.layout().removeWidget(slider) slider.deleteLater() + def _on_value_changed(self, dim_name: str, value: int) -> None: + self._current_index[dim_name] = value + if not self._updating: + self.indexChanged.emit(self._current_index) -class ViewerCanvas(QWidget): - def __init__(self, parent: QWidget | None = None) -> None: + def add_or_update_dimension(self, name: str, value: int) -> None: + if name in self._sliders: + self._sliders[name].setValue(value) + else: + self.add_dimension(name) + + def update_dimensions(self, index: Mapping[str, int]) -> None: + prev = self._current_index.copy() + self._updating = True + try: + for dim, value in index.items(): + self.add_or_update_dimension(dim, value) + if self._current_index != prev: + self.indexChanged.emit(self._current_index) + finally: + self._updating = False + + +class VispyViewerCanvas(QWidget): + """Vispy-based viewer for data. + + All vispy-specific code is encapsulated in this class (and non-vispy canvases + could be swapped in if needed as long as they implement the same interface). + """ + + infoText = Signal(str) + + def __init__( + self, + datastore: OMEZarrWriter, + channel_mode: str = "composite", + parent: QWidget | None = None, + ) -> None: super().__init__(parent) + self._datastore = datastore - self._canvas = scene.SceneCanvas(parent=self, keys="interactive") + self._channel_mode = channel_mode + self._canvas = scene.SceneCanvas(parent=self) + self._canvas.events.mouse_move.connect(self._on_mouse_move) self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) central_wdg: scene.Widget = self._canvas.central_widget @@ -89,16 +240,44 @@ def __init__(self, parent: QWidget | None = None) -> None: layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self._canvas.native) - self.add_image("key", noisy_camera()) - self.reset_view() + self.set_range() + + def _on_mouse_move(self, event: SceneMouseEvent) -> None: + """Mouse moved on the canvas, display the pixel value and position.""" + images = [] + # Get the images the mouse is over + while image := self._canvas.visual_at(event.pos): + if image in self._images.values(): + images.append(image) + image.interactive = False + for img in self._images.values(): + img.interactive = True + if not images: + return + + tform = images[0].get_transform("canvas", "visual") + px, py, *_ = (int(x) for x in tform.map(event.pos)) + text = f"[{py}, {px}]" + for c, img in enumerate(images): + value = f"{img._data[py, px]}" + text += f" c{c}: {value}" + self.infoText.emit(text) def add_image(self, key: Hashable, data: np.ndarray | None = None) -> None: + """Add a new Image node to the scene.""" + if self._channel_mode == "composite": + cmap = next(COLORMAPS).to_vispy() + else: + cmap = "grays" self._images[key] = img = scene.visuals.Image( - data, cmap="grays", parent=self._view.scene + data, cmap=cmap, parent=self._view.scene ) img.set_gl_state("additive", depth_test=False) + img.interactive = True + self.set_range() def remove_image(self, key: Hashable) -> None: + """Remove an Image node from the scene.""" try: image = self._images.pop(key) except KeyError: @@ -106,8 +285,56 @@ def remove_image(self, key: Hashable) -> None: return image.parent = None - def reset_view(self) -> None: - self._camera.set_range() + def set_image_data(self, key: Hashable, data: np.ndarray) -> None: + """Set the data for an existing Image node.""" + self._images[key].set_data(data) + self._canvas.update() + + def set_image_cmap(self, key: Hashable, cmap: str) -> None: + """Set the colormap for an existing Image node.""" + self._images[key].cmap = cmap + self._canvas.update() + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + margin: float | None = 0.05, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + self._camera.set_range(x=x, y=y, margin=margin) + + def _image_key(self, index: dict) -> Hashable: + dims_needing_images = set() + if self._channel_mode == "composite": + dims_needing_images.add(CHANNEL) + + return tuple((dim, index.get(dim)) for dim in dims_needing_images) + + def set_current_index(self, index: Mapping[str, int]) -> None: + """Set the current image index.""" + cidx = ((CHANNEL, index.get("c")),) + if self._channel_mode == "composite" and cidx in self._images: + # if we're in composite mode, we need to update the image for each channel + for key, _ in self._images.items(): + # FIXME + try: + image_data = self._datastore.isel(index, c=key[0][1]) + except IndexError: + print("ERR", key, index) + continue + self.set_image_data(key, image_data) + + else: + # otherwise, we only have a single image to update + frame = self._datastore.isel(index) + if (key := self._image_key(index)) not in self._images: + self.add_image(key, frame) + else: + self.set_image_data(key, frame) class StackViewer(QWidget): @@ -116,24 +343,29 @@ class StackViewer(QWidget): def __init__(self, *, parent: QWidget | None = None): super().__init__(parent=parent) - self._canvas = ViewerCanvas() + channel_mode: Literal["composite", "grayscale"] = "composite" + + self._core = CMMCorePlus.instance() + self.datastore = DataStore() + self._canvas = VispyViewerCanvas(self.datastore, channel_mode=channel_mode) self._info_bar = QLabel("Info") self._dims_sliders = DimsSliders() - self._dims_sliders.add_dimension("z") + + if channel_mode == "composite": + self._dims_sliders.set_dimension_visible(CHANNEL, False) + + self._canvas.infoText.connect(lambda x: self._info_bar.setText(x)) + self.datastore.frame_ready.connect(self.on_frame_ready) + self._dims_sliders.indexChanged.connect(self._on_dims_sliders) layout = QVBoxLayout(self) layout.addWidget(self._canvas, 1) layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) + def _on_dims_sliders(self, index: dict) -> None: + self._canvas.set_current_index(index) -if __name__ == "__main__": - from qtpy.QtWidgets import QApplication - - app = QApplication([]) - - viewer = StackViewer() - viewer.show() - viewer.resize(600, 600) - - app.exec() + @superqt.ensure_main_thread + def on_frame_ready(self, event: useq.MDAEvent) -> None: + self._dims_sliders.update_dimensions(event.index) From 0d5b085ce7951435f291a71f127eb0d8735bb7a0 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 24 Apr 2024 19:45:11 -0400 Subject: [PATCH 03/73] wip --- examples/stack_viewer2.py | 6 +- .../_stack_viewer/_stack_viewer.py | 6 +- .../_stack_viewer2/_stack_viewer.py | 202 ++++++++++++------ 3 files changed, 144 insertions(+), 70 deletions(-) diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 311939968..90d3cbc02 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -15,14 +15,14 @@ sequence = MDASequence( channels=( - {"config": "DAPI", "exposure": 4}, + {"config": "DAPI", "exposure": 16}, {"config": "FITC", "exposure": 10}, # {"config": "Cy5", "exposure": 20}, ), stage_positions=[(0, 0), (1, 1)], z_plan={"range": 2, "step": 0.4}, - time_plan={"interval": 2, "loops": 5}, - grid_plan={"rows": 2, "columns": 1}, + time_plan={"interval": 0.8, "loops": 2}, + # grid_plan={"rows": 2, "columns": 1}, ) diff --git a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py index 987ed495f..9b9211736 100644 --- a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py @@ -464,11 +464,9 @@ def _reload_position(self) -> None: self.cmap_names = self.qt_settings.value("cmaps", ["gray", "cyan", "magenta"]) def _collapse_view(self) -> None: + w, h = self.img_size view_rect = ( - ( - self.view_rect[0][0] - self.img_size[0] / 2, - self.view_rect[0][1] + self.img_size[1] / 2, - ), + (self.view_rect[0][0] - w / 2, self.view_rect[0][1] + h / 2), self.view_rect[1], ) self.view.camera.rect = view_rect diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index e23459446..17007a306 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -1,27 +1,39 @@ from __future__ import annotations from contextlib import suppress +import logging from itertools import cycle -from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping +from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, cast from warnings import warn import cmap +import numpy as np import superqt import useq -from psygnal import Signal +from psygnal import Signal as psygnalSignal from pymmcore_plus import CMMCorePlus from pymmcore_plus.mda.handlers import OMEZarrWriter -from qtpy.QtCore import Qt -from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget -from superqt import QLabeledSlider +from qtpy.QtCore import Qt, Signal +from qtpy.QtWidgets import ( + QCheckBox, + QHBoxLayout, + QLabel, + QPushButton, + QVBoxLayout, + QWidget, +) +from superqt import QLabeledRangeSlider, QLabeledSlider +from superqt.cmap import QColormapComboBox from superqt.iconify import QIconifyIcon from vispy import scene if TYPE_CHECKING: - import numpy as np + import numpy.typing as npt from PySide6.QtCore import QTimerEvent from vispy.scene.events import SceneMouseEvent + ImageKey = tuple[tuple[str, int], ...] + CHANNEL = "c" COLORMAPS = cycle( @@ -29,18 +41,9 @@ ) -def try_cast_colormap(val: Any) -> cmap.Colormap | None: - """Try to cast `val` to a cmap.Colormap instance, return None if it fails.""" - if isinstance(val, cmap.Colormap): - return val - with suppress(Exception): - return cmap.Colormap(val) - return None - - # FIXME: get rid of this thin subclass class DataStore(OMEZarrWriter): - frame_ready = Signal(useq.MDAEvent) + frame_ready = psygnalSignal(useq.MDAEvent) def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> None: super().frameReady(frame, event, meta) @@ -72,6 +75,43 @@ def __init__(self, text: str = "", parent: QWidget | None = None) -> None: self.setMaximumWidth(20) +class ChannelVisControl(QWidget): + visibilityChanged = Signal(bool) + climsChanged = Signal(tuple) + cmapChanged = Signal(cmap.Colormap) + + def __init__(self, idx: int, name: str = "", parent: QWidget | None = None) -> None: + super().__init__(parent) + self.idx = idx + self._name = name + + self._visible = QCheckBox(name) + self._visible.setChecked(True) + self._visible.toggled.connect(self.visibilityChanged) + + self._cmap = QColormapComboBox(allow_user_colormaps=True) + self._cmap.currentColormapChanged.connect(self.cmapChanged) + for color in ["green", "magenta", "cyan"]: + self._cmap.addColormap(color) + + self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + self._clims.setRange(0, 2**12) + self._clims.valueChanged.connect(self.climsChanged) + + self._auto_clim = QCheckBox("Auto") + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._visible) + layout.addWidget(self._cmap) + layout.addWidget(self._clims) + layout.addWidget(self._auto_clim) + + def set_clim_for_dtype(self, dtype: npt.DTypeLike) -> None: + # get maximum possible value for the dtype + self._clims.setRange(0, np.iinfo(dtype).max) + + class DimsSlider(QWidget): """A single slider in the DimsSliders widget. @@ -234,24 +274,24 @@ def __init__( # tbd... determine what the key should be # could have an image per channel, # but may also have multiple images per channel... in the case of tiles, etc... - self._images: dict[Hashable, scene.visuals.Image] = {} + self._images: dict[ImageKey, scene.visuals.Image] = {} layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self._canvas.native) - self.set_range() - def _on_mouse_move(self, event: SceneMouseEvent) -> None: """Mouse moved on the canvas, display the pixel value and position.""" images = [] # Get the images the mouse is over - while image := self._canvas.visual_at(event.pos): - if image in self._images.values(): - images.append(image) - image.interactive = False - for img in self._images.values(): - img.interactive = True + seen = set() + while visual := self._canvas.visual_at(event.pos): + if isinstance(visual, scene.visuals.Image): + images.append(visual) + visual.interactive = False + seen.add(visual) + for visual in seen: + visual.interactive = True if not images: return @@ -259,40 +299,44 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: px, py, *_ = (int(x) for x in tform.map(event.pos)) text = f"[{py}, {px}]" for c, img in enumerate(images): - value = f"{img._data[py, px]}" - text += f" c{c}: {value}" + with suppress(IndexError): + text += f" c{c}: {img._data[py, px]}" self.infoText.emit(text) - def add_image(self, key: Hashable, data: np.ndarray | None = None) -> None: + def add_image(self, key: ImageKey, data: np.ndarray | None = None) -> None: """Add a new Image node to the scene.""" if self._channel_mode == "composite": cmap = next(COLORMAPS).to_vispy() else: cmap = "grays" + + print('add', data) self._images[key] = img = scene.visuals.Image( data, cmap=cmap, parent=self._view.scene ) img.set_gl_state("additive", depth_test=False) - img.interactive = True self.set_range() + img.interactive = True - def remove_image(self, key: Hashable) -> None: - """Remove an Image node from the scene.""" - try: - image = self._images.pop(key) - except KeyError: - warn(f"Image {key} not found in ViewerCanvas", stacklevel=2) - return - image.parent = None + def set_channel_visibility(self, ch_idx: int, visible: bool) -> None: + """Set the visibility of an existing Image node.""" + self._map_func(lambda i: setattr(i, "visible", visible), (CHANNEL, ch_idx)) - def set_image_data(self, key: Hashable, data: np.ndarray) -> None: - """Set the data for an existing Image node.""" - self._images[key].set_data(data) - self._canvas.update() + def set_channel_clims(self, ch_idx: int, clims: tuple) -> None: + """Set the contrast limits for an existing Image node.""" + self._map_func(lambda i: setattr(i, "clim", clims), (CHANNEL, ch_idx)) - def set_image_cmap(self, key: Hashable, cmap: str) -> None: + def set_channel_cmap(self, ch_idx: int, cmap: cmap.Colormap) -> None: """Set the colormap for an existing Image node.""" - self._images[key].cmap = cmap + self._map_func(lambda i: setattr(i, "cmap", cmap.to_vispy()), (CHANNEL, ch_idx)) + + def _map_func( + self, functor: Callable[[scene.visuals.Image], Any], axis_key: tuple + ) -> None: + """Apply a function to all images that match the given axis key.""" + for axis_keys, img in self._images.items(): + if axis_key in axis_keys: + functor(img) self._canvas.update() def set_range( @@ -307,34 +351,47 @@ def set_range( """ self._camera.set_range(x=x, y=y, margin=margin) - def _image_key(self, index: dict) -> Hashable: - dims_needing_images = set() + def _image_key(self, index: Mapping[str, int]) -> ImageKey: + # gather all axes that require a unique image + # and return as, e.g. [('c', 0), ('g', 1)] + keys: list[tuple[str, int]] = [] if self._channel_mode == "composite": - dims_needing_images.add(CHANNEL) - - return tuple((dim, index.get(dim)) for dim in dims_needing_images) + keys.append((CHANNEL, index.get(CHANNEL, 0))) + return tuple(keys) def set_current_index(self, index: Mapping[str, int]) -> None: """Set the current image index.""" - cidx = ((CHANNEL, index.get("c")),) - if self._channel_mode == "composite" and cidx in self._images: - # if we're in composite mode, we need to update the image for each channel - for key, _ in self._images.items(): - # FIXME - try: - image_data = self._datastore.isel(index, c=key[0][1]) - except IndexError: - print("ERR", key, index) - continue - self.set_image_data(key, image_data) - + indices: list[Mapping[str, int]] = [] + if self._channel_mode != "composite": + indices = [index] else: + # if we're in composite mode, we need to update the image for each channel + this_channel = index.get(CHANNEL) + this_channel_exists = False + for key in self._images: + for axis, axis_i in key: + if axis == CHANNEL: + indices.append({**index, axis: axis_i}) + if axis_i == this_channel: + this_channel_exists = True + if not this_channel_exists: + indices.append(index) + + for index in indices: # otherwise, we only have a single image to update - frame = self._datastore.isel(index) + try: + data = self._datastore.isel(index) + except Exception as e: + logging.error(f"Error getting frame for index {index}: {e}") + continue + if (key := self._image_key(index)) not in self._images: - self.add_image(key, frame) + print('add', key) + self.add_image(key, data) else: - self.set_image_data(key, frame) + print('update', key) + self._images[key].set_data(data) + self._canvas.update() class StackViewer(QWidget): @@ -363,6 +420,25 @@ def __init__(self, *, parent: QWidget | None = None): layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) + for i, ch in enumerate(["DAPI", "FITC"]): + c = ChannelVisControl(i, ch) + layout.addWidget(c) + c.climsChanged.connect(self._on_clims_changed) + c.cmapChanged.connect(self._on_cmap_changed) + c.visibilityChanged.connect(self._on_channel_vis_changed) + + def _on_channel_vis_changed(self, checked: bool) -> None: + sender = cast("ChannelVisControl", self.sender()) + self._canvas.set_channel_visibility(sender.idx, checked) + + def _on_clims_changed(self, clims: tuple) -> None: + sender = cast("ChannelVisControl", self.sender()) + self._canvas.set_channel_clims(sender.idx, clims) + + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + sender = cast("ChannelVisControl", self.sender()) + self._canvas.set_channel_cmap(sender.idx, cmap) + def _on_dims_sliders(self, index: dict) -> None: self._canvas.set_current_index(index) From 1f4258859a561724a51f33b838a1e591776f0c4b Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 24 Apr 2024 21:32:00 -0400 Subject: [PATCH 04/73] some fixes --- .../_stack_viewer2/_stack_viewer.py | 60 +++++++++---------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 17007a306..1febb49a1 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import suppress +import itertools import logging from itertools import cycle from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, cast @@ -95,7 +96,7 @@ def __init__(self, idx: int, name: str = "", parent: QWidget | None = None) -> N self._cmap.addColormap(color) self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) - self._clims.setRange(0, 2**12) + self._clims.setRange(0, 2**14) self._clims.valueChanged.connect(self.climsChanged) self._auto_clim = QCheckBox("Auto") @@ -310,7 +311,6 @@ def add_image(self, key: ImageKey, data: np.ndarray | None = None) -> None: else: cmap = "grays" - print('add', data) self._images[key] = img = scene.visuals.Image( data, cmap=cmap, parent=self._view.scene ) @@ -318,24 +318,30 @@ def add_image(self, key: ImageKey, data: np.ndarray | None = None) -> None: self.set_range() img.interactive = True - def set_channel_visibility(self, ch_idx: int, visible: bool) -> None: + def set_channel_visibility(self, visible: bool, ch_idx: int | None = None) -> None: """Set the visibility of an existing Image node.""" - self._map_func(lambda i: setattr(i, "visible", visible), (CHANNEL, ch_idx)) + if ch_idx is None: + ch_idx = getattr(self.sender(), "idx", 0) + self._map_func(lambda i: setattr(i, "visible", visible), ch_idx) - def set_channel_clims(self, ch_idx: int, clims: tuple) -> None: + def set_channel_clims(self, clims: tuple, ch_idx: int | None = None) -> None: """Set the contrast limits for an existing Image node.""" - self._map_func(lambda i: setattr(i, "clim", clims), (CHANNEL, ch_idx)) + if ch_idx is None: + ch_idx = getattr(self.sender(), "idx", 0) + self._map_func(lambda i: setattr(i, "clim", clims), ch_idx) - def set_channel_cmap(self, ch_idx: int, cmap: cmap.Colormap) -> None: + def set_channel_cmap(self, cmap: cmap.Colormap, ch_idx: int | None = None) -> None: """Set the colormap for an existing Image node.""" - self._map_func(lambda i: setattr(i, "cmap", cmap.to_vispy()), (CHANNEL, ch_idx)) + if ch_idx is None: + ch_idx = getattr(self.sender(), "idx", 0) + self._map_func(lambda i: setattr(i, "cmap", cmap.to_vispy()), ch_idx) def _map_func( - self, functor: Callable[[scene.visuals.Image], Any], axis_key: tuple + self, functor: Callable[[scene.visuals.Image], Any], ch_idx: int ) -> None: """Apply a function to all images that match the given axis key.""" for axis_keys, img in self._images.items(): - if axis_key in axis_keys: + if (CHANNEL, ch_idx) in axis_keys: functor(img) self._canvas.update() @@ -375,7 +381,7 @@ def set_current_index(self, index: Mapping[str, int]) -> None: if axis_i == this_channel: this_channel_exists = True if not this_channel_exists: - indices.append(index) + indices.insert(0, index) for index in indices: # otherwise, we only have a single image to update @@ -386,14 +392,15 @@ def set_current_index(self, index: Mapping[str, int]) -> None: continue if (key := self._image_key(index)) not in self._images: - print('add', key) self.add_image(key, data) else: - print('update', key) - self._images[key].set_data(data) + # FIXME + # this is a hack to avoid data that hasn't arrived yet + if data.max() != 0: + self._images[key].set_data(data) self._canvas.update() - +c = itertools.count() class StackViewer(QWidget): """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events.""" @@ -412,7 +419,8 @@ def __init__(self, *, parent: QWidget | None = None): self._dims_sliders.set_dimension_visible(CHANNEL, False) self._canvas.infoText.connect(lambda x: self._info_bar.setText(x)) - self.datastore.frame_ready.connect(self.on_frame_ready) + self._core.mda.events.frameReady.connect(self.on_frame_ready) + # self.datastore.frame_ready.connect(self.on_frame_ready) self._dims_sliders.indexChanged.connect(self._on_dims_sliders) layout = QVBoxLayout(self) @@ -423,25 +431,13 @@ def __init__(self, *, parent: QWidget | None = None): for i, ch in enumerate(["DAPI", "FITC"]): c = ChannelVisControl(i, ch) layout.addWidget(c) - c.climsChanged.connect(self._on_clims_changed) - c.cmapChanged.connect(self._on_cmap_changed) - c.visibilityChanged.connect(self._on_channel_vis_changed) - - def _on_channel_vis_changed(self, checked: bool) -> None: - sender = cast("ChannelVisControl", self.sender()) - self._canvas.set_channel_visibility(sender.idx, checked) - - def _on_clims_changed(self, clims: tuple) -> None: - sender = cast("ChannelVisControl", self.sender()) - self._canvas.set_channel_clims(sender.idx, clims) - - def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: - sender = cast("ChannelVisControl", self.sender()) - self._canvas.set_channel_cmap(sender.idx, cmap) + c.climsChanged.connect(self._canvas.set_channel_clims) + c.cmapChanged.connect(self._canvas.set_channel_cmap) + c.visibilityChanged.connect(self._canvas.set_channel_visibility) def _on_dims_sliders(self, index: dict) -> None: self._canvas.set_current_index(index) @superqt.ensure_main_thread - def on_frame_ready(self, event: useq.MDAEvent) -> None: + def on_frame_ready(self, frame: np.narray, event: useq.MDAEvent) -> None: self._dims_sliders.update_dimensions(event.index) From 1731f7ce6f5a3e042ebeaee4c21d2cc788718729 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 1 May 2024 16:54:00 -0400 Subject: [PATCH 05/73] wip --- examples/stack_viewer2.py | 8 +- .../_stack_viewer2/_dims_slider.py | 254 +++++++++ .../_stack_viewer2/_stack_viewer.py | 517 ++++++------------ .../_stack_viewer2/_vispy_canvas.py | 126 +++++ 4 files changed, 544 insertions(+), 361 deletions(-) create mode 100644 src/pymmcore_widgets/_stack_viewer2/_dims_slider.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 90d3cbc02..9883e3a2a 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -4,7 +4,7 @@ from qtpy import QtWidgets from useq import MDASequence -from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer +from pymmcore_widgets._stack_viewer2._stack_viewer import MDAViewer configure_logging(stderr_level="WARNING") @@ -15,8 +15,8 @@ sequence = MDASequence( channels=( - {"config": "DAPI", "exposure": 16}, - {"config": "FITC", "exposure": 10}, + {"config": "DAPI", "exposure": 10}, + {"config": "FITC", "exposure": 80}, # {"config": "Cy5", "exposure": 20}, ), stage_positions=[(0, 0), (1, 1)], @@ -27,7 +27,7 @@ qapp = QtWidgets.QApplication([]) -v = StackViewer() +v = MDAViewer() v.show() mmcore.run_mda(sequence, output=v.datastore) diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py new file mode 100644 index 000000000..f5f7ad87d --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from warnings import warn + +from qtpy.QtCore import Qt, Signal +from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from superqt import QLabeledRangeSlider, QLabeledSlider +from superqt.iconify import QIconifyIcon +from superqt.utils import signals_blocked + +if TYPE_CHECKING: + from typing import Hashable, Mapping, TypeAlias + + from qtpy.QtGui import QMouseEvent + + DimensionKey: TypeAlias = Hashable + Index: TypeAlias = int | slice + Indices: TypeAlias = Mapping[DimensionKey, Index] + + +class PlayButton(QPushButton): + """Just a styled QPushButton that toggles between play and pause icons.""" + + PLAY_ICON = "fa6-solid:play" + PAUSE_ICON = "fa6-solid:pause" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.PLAY_ICON) + icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On) + super().__init__(icn, text, parent) + self.setCheckable(True) + self.setMaximumWidth(22) + + +class LockButton(QPushButton): + LOCK_ICON = "fa6-solid:lock-open" + UNLOCK_ICON = "fa6-solid:lock" + + def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + icn = QIconifyIcon(self.LOCK_ICON) + icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On) + super().__init__(icn, text, parent) + self.setCheckable(True) + self.setMaximumWidth(20) + + +class DimsSlider(QWidget): + """A single slider in the DimsSliders widget. + + Provides a play/pause button that toggles animation of the slider value. + Has a QLabeledSlider for the actual value. + Adds a label for the maximum value (e.g. "3 / 10") + """ + + valueChanged = Signal(str, object) # where object is int | slice + + def __init__( + self, dimension_key: DimensionKey, parent: QWidget | None = None + ) -> None: + super().__init__(parent) + self._slice_mode = False + self._animation_fps = 10 + self._dim_key = dimension_key + + self._play_btn = PlayButton() + self._play_btn.toggled.connect(self._toggle_animation) + + self._dim_label = QLabel(str(dimension_key)) + + # note, this lock button only prevents the slider from updating programmatically + # using self.setValue, it doesn't prevent the user from changing the value. + self._lock_btn = LockButton() + + self._max_label = QLabel("/ 0") + self._int_slider = QLabeledSlider(Qt.Orientation.Horizontal, parent=self) + self._int_slider.rangeChanged.connect(self._on_range_changed) + self._int_slider.valueChanged.connect(self._on_int_value_changed) + self._int_slider.layout().addWidget(self._max_label) + + self._slice_slider = QLabeledRangeSlider(Qt.Orientation.Horizontal, parent=self) + self._slice_slider.setVisible(False) + # self._slice_slider.rangeChanged.connect(self._on_range_changed) + self._slice_slider.valueChanged.connect(self._on_slice_value_changed) + + self.installEventFilter(self) + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._play_btn) + layout.addWidget(self._dim_label) + layout.addWidget(self._int_slider) + layout.addWidget(self._slice_slider) + layout.addWidget(self._lock_btn) + + def mouseDoubleClickEvent(self, a0: QMouseEvent | None) -> None: + self._set_slice_mode(not self._slice_mode) + return super().mouseDoubleClickEvent(a0) + + def setMaximum(self, max_val: int) -> None: + if max_val > self._int_slider.maximum(): + self._int_slider.setMaximum(max_val) + if max_val > self._slice_slider.maximum(): + self._slice_slider.setMaximum(max_val) + + def setRange(self, min_val: int, max_val: int) -> None: + self._int_slider.setRange(min_val, max_val) + self._slice_slider.setRange(min_val, max_val) + + def value(self) -> Index: + return ( + self._int_slider.value() + if not self._slice_mode + else slice(*self._slice_slider.value()) + ) + + def setValue(self, val: Index) -> None: + # variant of setValue that always updates the maximum + self._set_slice_mode(isinstance(val, slice)) + if self._lock_btn.isChecked(): + return + if isinstance(val, slice): + self._slice_slider.setValue((val.start, val.stop)) + # self._int_slider.setValue(int((val.stop + val.start) / 2)) + else: + self._int_slider.setValue(val) + # self._slice_slider.setValue((val, val + 1)) + + def forceValue(self, val: Index) -> None: + """Set value and increase range if necessary.""" + self.setMaximum(val.stop if isinstance(val, slice) else val) + self.setValue(val) + + def _set_slice_mode(self, mode: bool = True) -> None: + self._slice_mode = mode + if mode: + self._slice_slider.setVisible(True) + self._int_slider.setVisible(False) + else: + self._int_slider.setVisible(True) + self._slice_slider.setVisible(False) + + def set_fps(self, fps: int) -> None: + self._animation_fps = fps + + def _toggle_animation(self, checked: bool) -> None: + if checked: + self._timer_id = self.startTimer(1000 // self._animation_fps) + else: + self.killTimer(self._timer_id) + + def timerEvent(self, event: Any) -> None: + if self._slice_mode: + val = self._slice_slider.value() + next_val = [v + 1 for v in val] + if next_val[1] > self._slice_slider.maximum(): + next_val = [v - val[0] for v in val] + self._slice_slider.setValue(next_val) + else: + val = self._int_slider.value() + val = (val + 1) % (self._int_slider.maximum() + 1) + self._int_slider.setValue(val) + + def _on_range_changed(self, min: int, max: int) -> None: + self._max_label.setText("/ " + str(max)) + + def _on_int_value_changed(self, value: int) -> None: + if not self._slice_mode: + self.valueChanged.emit(self._dim_key, value) + + def _on_slice_value_changed(self, value: tuple[int, int]) -> None: + if self._slice_mode: + self.valueChanged.emit(self._dim_key, slice(*value)) + + +class DimsSliders(QWidget): + """A Collection of DimsSlider widgets for each dimension in the data. + + Maintains the global current index and emits a signal when it changes. + """ + + valueChanged = Signal(dict) + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._sliders: dict[DimensionKey, DimsSlider] = {} + self._current_index: dict[DimensionKey, Index] = {} + self._invisible_dims: set[DimensionKey] = set() + self._updating = False + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + def value(self) -> Indices: + return self._current_index.copy() + + def setValue(self, values: Indices) -> None: + if self._current_index == values: + return + with signals_blocked(self): + for dim, index in values.items(): + self.add_or_update_dimension(dim, index) + self.valueChanged.emit(self.value()) + + def add_dimension(self, name: DimensionKey, val: Index | None = None) -> None: + self._sliders[name] = slider = DimsSlider(dimension_key=name, parent=self) + slider.setRange(0, 1) + val = val if val is not None else 0 + self._current_index[name] = val + slider.forceValue(val) + slider.valueChanged.connect(self._on_dim_slider_value_changed) + slider.setVisible(name not in self._invisible_dims) + self.layout().addWidget(slider) + + def set_dimension_visible(self, name: str, visible: bool) -> None: + if visible: + self._invisible_dims.discard(name) + else: + self._invisible_dims.add(name) + if name in self._sliders: + self._sliders[name].setVisible(visible) + + def remove_dimension(self, name: str) -> None: + try: + slider = self._sliders.pop(name) + except KeyError: + warn(f"Dimension {name} not found in DimsSliders", stacklevel=2) + return + self.layout().removeWidget(slider) + slider.deleteLater() + + def _on_dim_slider_value_changed(self, dim_name: str, value: Index) -> None: + self._current_index[dim_name] = value + if not self._updating: + self.valueChanged.emit(self.value()) + + def add_or_update_dimension(self, name: DimensionKey, value: Index) -> None: + if name in self._sliders: + self._sliders[name].forceValue(value) + else: + self.add_dimension(name, value) + + +if __name__ == "__main__": + from qtpy.QtWidgets import QApplication + + app = QApplication([]) + w = DimsSliders() + w.add_dimension("x") + w.add_dimension("y", slice(5, 9)) + w.add_dimension("z", 10) + w.valueChanged.connect(print) + w.show() + app.exec() diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 1febb49a1..2c30c418c 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -1,39 +1,35 @@ from __future__ import annotations -from contextlib import suppress import itertools -import logging +from collections import defaultdict from itertools import cycle -from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, cast -from warnings import warn +from typing import ( + TYPE_CHECKING, + Any, + Hashable, + Iterable, + Literal, + Mapping, + Protocol, +) import cmap import numpy as np import superqt import useq from psygnal import Signal as psygnalSignal -from pymmcore_plus import CMMCorePlus from pymmcore_plus.mda.handlers import OMEZarrWriter -from qtpy.QtCore import Qt, Signal -from qtpy.QtWidgets import ( - QCheckBox, - QHBoxLayout, - QLabel, - QPushButton, - QVBoxLayout, - QWidget, -) -from superqt import QLabeledRangeSlider, QLabeledSlider +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QCheckBox, QHBoxLayout, QLabel, QVBoxLayout, QWidget +from superqt import QLabeledRangeSlider from superqt.cmap import QColormapComboBox -from superqt.iconify import QIconifyIcon -from vispy import scene +from superqt.utils import signals_blocked -if TYPE_CHECKING: - import numpy.typing as npt - from PySide6.QtCore import QTimerEvent - from vispy.scene.events import SceneMouseEvent +from ._dims_slider import DimsSliders +from ._vispy_canvas import VispyViewerCanvas - ImageKey = tuple[tuple[str, int], ...] +if TYPE_CHECKING: + ImageKey = Hashable CHANNEL = "c" @@ -44,400 +40,207 @@ # FIXME: get rid of this thin subclass class DataStore(OMEZarrWriter): - frame_ready = psygnalSignal(useq.MDAEvent) + frame_ready = psygnalSignal(object, useq.MDAEvent) def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> None: super().frameReady(frame, event, meta) - self.frame_ready.emit(event) - - -class PlayButton(QPushButton): - """Just a styled QPushButton that toggles between play and pause icons.""" - - PLAY_ICON = "fa6-solid:play" - PAUSE_ICON = "fa6-solid:pause" - - def __init__(self, text: str = "", parent: QWidget | None = None) -> None: - icn = QIconifyIcon(self.PLAY_ICON) - icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On) - super().__init__(icn, text, parent) - self.setCheckable(True) - - -class LockButton(QPushButton): - LOCK_ICON = "fa6-solid:lock-open" - UNLOCK_ICON = "fa6-solid:lock" - - def __init__(self, text: str = "", parent: QWidget | None = None) -> None: - icn = QIconifyIcon(self.LOCK_ICON) - icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On) - super().__init__(icn, text, parent) - self.setCheckable(True) - self.setMaximumWidth(20) + self.frame_ready.emit(frame, event) class ChannelVisControl(QWidget): - visibilityChanged = Signal(bool) - climsChanged = Signal(tuple) - cmapChanged = Signal(cmap.Colormap) - - def __init__(self, idx: int, name: str = "", parent: QWidget | None = None) -> None: + def __init__( + self, + name: str = "", + handles: Iterable[PImageHandle] = (), + parent: QWidget | None = None, + ) -> None: super().__init__(parent) - self.idx = idx + self._handles = handles self._name = name self._visible = QCheckBox(name) self._visible.setChecked(True) - self._visible.toggled.connect(self.visibilityChanged) + self._visible.toggled.connect(self._on_visible_changed) self._cmap = QColormapComboBox(allow_user_colormaps=True) - self._cmap.currentColormapChanged.connect(self.cmapChanged) + self._cmap.currentColormapChanged.connect(self._on_cmap_changed) for color in ["green", "magenta", "cyan"]: self._cmap.addColormap(color) - self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) - self._clims.setRange(0, 2**14) - self._clims.valueChanged.connect(self.climsChanged) + self.clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + self.clims.setRange(0, 2**14) + self.clims.valueChanged.connect(self._on_clims_changed) self._auto_clim = QCheckBox("Auto") + self._auto_clim.toggled.connect(self.update_autoscale) layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self._visible) layout.addWidget(self._cmap) - layout.addWidget(self._clims) + layout.addWidget(self.clims) layout.addWidget(self._auto_clim) - def set_clim_for_dtype(self, dtype: npt.DTypeLike) -> None: - # get maximum possible value for the dtype - self._clims.setRange(0, np.iinfo(dtype).max) + def autoscaleChecked(self) -> bool: + return self._auto_clim.isChecked() + def _on_clims_changed(self, clims: tuple[float, float]) -> None: + self._auto_clim.setChecked(False) + for handle in self._handles: + handle.clim = clims -class DimsSlider(QWidget): - """A single slider in the DimsSliders widget. + def _on_visible_changed(self, visible: bool) -> None: + for handle in self._handles: + handle.visible = visible - Provides a play/pause button that toggles animation of the slider value. - Has a QLabeledSlider for the actual value. - Adds a label for the maximum value (e.g. "3 / 10") - """ + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + for handle in self._handles: + handle.cmap = cmap - valueChanged = Signal(str, int) - - def __init__(self, dimension_name: str, parent: QWidget | None = None) -> None: - super().__init__(parent) - self._interval = 1000 // 10 - self._name = dimension_name - - self._play_btn = PlayButton(dimension_name) - self._play_btn.toggled.connect(self._toggle_animation) - # note, this lock button only prevents the slider from updating programmatically - # using self.setValue, it doesn't prevent the user from changing the value. - self._lock_btn = LockButton() - - self._max_label = QLabel("/ 0") - self._slider = QLabeledSlider(Qt.Orientation.Horizontal, parent=self) - self._slider.setMaximum(0) - self._slider.rangeChanged.connect(self._on_range_changed) - self._slider.valueChanged.connect(self._on_value_changed) - self._slider.layout().addWidget(self._max_label) - - layout = QHBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._play_btn) - layout.addWidget(self._slider) - layout.addWidget(self._lock_btn) - - def setMaximum(self, max_val: int) -> None: - self._slider.setMaximum(max_val) - - def setValue(self, val: int) -> None: - # variant of setValue that always updates the maximum - if val > self._slider.maximum(): - self._slider.setMaximum(val) - if self._lock_btn.isChecked(): + def update_autoscale(self) -> None: + if not self._auto_clim.isChecked(): return - self._slider.setValue(val) - - def set_fps(self, fps: int) -> None: - self._interval = 1000 // fps - - def _toggle_animation(self, checked: bool) -> None: - if checked: - self._timer_id = self.startTimer(self._interval) - else: - self.killTimer(self._timer_id) - - def timerEvent(self, event: QTimerEvent) -> None: - val = self._slider.value() - next_val = (val + 1) % (self._slider.maximum() + 1) - self._slider.setValue(next_val) - - def _on_range_changed(self, min: int, max: int) -> None: - self._max_label.setText("/ " + str(max)) - - def _on_value_changed(self, value: int) -> None: - self.valueChanged.emit(self._name, value) - - -class DimsSliders(QWidget): - """A Collection of DimsSlider widgets for each dimension in the data. - - Maintains the global current index and emits a signal when it changes. - """ - - indexChanged = Signal(dict) - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self._sliders: dict[str, DimsSlider] = {} - self._current_index: dict[str, int] = {} - self._invisible_dims: set[str] = set() - self._updating = False - - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(0) - - def add_dimension(self, name: str) -> None: - self._sliders[name] = slider = DimsSlider(dimension_name=name, parent=self) - self._current_index[name] = 0 - slider.valueChanged.connect(self._on_value_changed) - self.layout().addWidget(slider) - slider.setVisible(name not in self._invisible_dims) - - def set_dimension_visible(self, name: str, visible: bool) -> None: - if visible: - self._invisible_dims.discard(name) - else: - self._invisible_dims.add(name) - if name in self._sliders: - self._sliders[name].setVisible(visible) - - def remove_dimension(self, name: str) -> None: - try: - slider = self._sliders.pop(name) - except KeyError: - warn(f"Dimension {name} not found in DimsSliders", stacklevel=2) - return - self.layout().removeWidget(slider) - slider.deleteLater() - - def _on_value_changed(self, dim_name: str, value: int) -> None: - self._current_index[dim_name] = value - if not self._updating: - self.indexChanged.emit(self._current_index) - - def add_or_update_dimension(self, name: str, value: int) -> None: - if name in self._sliders: - self._sliders[name].setValue(value) - else: - self.add_dimension(name) - def update_dimensions(self, index: Mapping[str, int]) -> None: - prev = self._current_index.copy() - self._updating = True - try: - for dim, value in index.items(): - self.add_or_update_dimension(dim, value) - if self._current_index != prev: - self.indexChanged.emit(self._current_index) - finally: - self._updating = False + # find the min and max values for the current channel + clims = [np.inf, -np.inf] + for handle in self._handles: + clims[0] = min(clims[0], np.nanmin(handle.data)) + clims[1] = max(clims[1], np.nanmax(handle.data)) + for handle in self._handles: + handle.clim = clims -class VispyViewerCanvas(QWidget): - """Vispy-based viewer for data. + # set the slider values to the new clims + with signals_blocked(self.clims): + self.clims.setValue(clims) - All vispy-specific code is encapsulated in this class (and non-vispy canvases - could be swapped in if needed as long as they implement the same interface). - """ - infoText = Signal(str) - - def __init__( - self, - datastore: OMEZarrWriter, - channel_mode: str = "composite", - parent: QWidget | None = None, - ) -> None: - super().__init__(parent) - self._datastore = datastore - - self._channel_mode = channel_mode - self._canvas = scene.SceneCanvas(parent=self) - self._canvas.events.mouse_move.connect(self._on_mouse_move) - self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) - - central_wdg: scene.Widget = self._canvas.central_widget - self._view: scene.ViewBox = central_wdg.add_view(camera=self._camera) - - # Mapping of image key to Image visual objects - # tbd... determine what the key should be - # could have an image per channel, - # but may also have multiple images per channel... in the case of tiles, etc... - self._images: dict[ImageKey, scene.visuals.Image] = {} - - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._canvas.native) - - def _on_mouse_move(self, event: SceneMouseEvent) -> None: - """Mouse moved on the canvas, display the pixel value and position.""" - images = [] - # Get the images the mouse is over - seen = set() - while visual := self._canvas.visual_at(event.pos): - if isinstance(visual, scene.visuals.Image): - images.append(visual) - visual.interactive = False - seen.add(visual) - for visual in seen: - visual.interactive = True - if not images: - return +c = itertools.count() - tform = images[0].get_transform("canvas", "visual") - px, py, *_ = (int(x) for x in tform.map(event.pos)) - text = f"[{py}, {px}]" - for c, img in enumerate(images): - with suppress(IndexError): - text += f" c{c}: {img._data[py, px]}" - self.infoText.emit(text) - def add_image(self, key: ImageKey, data: np.ndarray | None = None) -> None: - """Add a new Image node to the scene.""" - if self._channel_mode == "composite": - cmap = next(COLORMAPS).to_vispy() - else: - cmap = "grays" - - self._images[key] = img = scene.visuals.Image( - data, cmap=cmap, parent=self._view.scene - ) - img.set_gl_state("additive", depth_test=False) - self.set_range() - img.interactive = True - - def set_channel_visibility(self, visible: bool, ch_idx: int | None = None) -> None: - """Set the visibility of an existing Image node.""" - if ch_idx is None: - ch_idx = getattr(self.sender(), "idx", 0) - self._map_func(lambda i: setattr(i, "visible", visible), ch_idx) - - def set_channel_clims(self, clims: tuple, ch_idx: int | None = None) -> None: - """Set the contrast limits for an existing Image node.""" - if ch_idx is None: - ch_idx = getattr(self.sender(), "idx", 0) - self._map_func(lambda i: setattr(i, "clim", clims), ch_idx) - - def set_channel_cmap(self, cmap: cmap.Colormap, ch_idx: int | None = None) -> None: - """Set the colormap for an existing Image node.""" - if ch_idx is None: - ch_idx = getattr(self.sender(), "idx", 0) - self._map_func(lambda i: setattr(i, "cmap", cmap.to_vispy()), ch_idx) - - def _map_func( - self, functor: Callable[[scene.visuals.Image], Any], ch_idx: int - ) -> None: - """Apply a function to all images that match the given axis key.""" - for axis_keys, img in self._images.items(): - if (CHANNEL, ch_idx) in axis_keys: - functor(img) - self._canvas.update() +class PDataStore(Protocol): ... - def set_range( - self, - x: tuple[float, float] | None = None, - y: tuple[float, float] | None = None, - margin: float | None = 0.05, - ) -> None: - """Update the range of the PanZoomCamera. - When called with no arguments, the range is set to the full extent of the data. - """ - self._camera.set_range(x=x, y=y, margin=margin) +class PImageHandle(Protocol): + @property + def data(self) -> np.ndarray: ... + @data.setter + def data(self, data: np.ndarray) -> None: ... + @property + def visible(self) -> bool: ... + @visible.setter + def visible(self, visible: bool) -> None: ... + @property + def clim(self) -> Any: ... + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: ... + @property + def cmap(self) -> Any: ... + @cmap.setter + def cmap(self, cmap: Any) -> None: ... - def _image_key(self, index: Mapping[str, int]) -> ImageKey: - # gather all axes that require a unique image - # and return as, e.g. [('c', 0), ('g', 1)] - keys: list[tuple[str, int]] = [] - if self._channel_mode == "composite": - keys.append((CHANNEL, index.get(CHANNEL, 0))) - return tuple(keys) - def set_current_index(self, index: Mapping[str, int]) -> None: - """Set the current image index.""" - indices: list[Mapping[str, int]] = [] - if self._channel_mode != "composite": - indices = [index] - else: - # if we're in composite mode, we need to update the image for each channel - this_channel = index.get(CHANNEL) - this_channel_exists = False - for key in self._images: - for axis, axis_i in key: - if axis == CHANNEL: - indices.append({**index, axis: axis_i}) - if axis_i == this_channel: - this_channel_exists = True - if not this_channel_exists: - indices.insert(0, index) - - for index in indices: - # otherwise, we only have a single image to update - try: - data = self._datastore.isel(index) - except Exception as e: - logging.error(f"Error getting frame for index {index}: {e}") - continue - - if (key := self._image_key(index)) not in self._images: - self.add_image(key, data) - else: - # FIXME - # this is a hack to avoid data that hasn't arrived yet - if data.max() != 0: - self._images[key].set_data(data) - self._canvas.update() - -c = itertools.count() class StackViewer(QWidget): """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events.""" - def __init__(self, *, parent: QWidget | None = None): + def __init__(self, datastore: PDataStore, *, parent: QWidget | None = None): super().__init__(parent=parent) - channel_mode: Literal["composite", "grayscale"] = "composite" + self._channels: defaultdict[Hashable, list[PImageHandle]] = defaultdict(list) - self._core = CMMCorePlus.instance() - self.datastore = DataStore() - self._canvas = VispyViewerCanvas(self.datastore, channel_mode=channel_mode) + self.datastore = datastore + self._canvas = VispyViewerCanvas() self._info_bar = QLabel("Info") self._dims_sliders = DimsSliders() - - if channel_mode == "composite": - self._dims_sliders.set_dimension_visible(CHANNEL, False) + self.set_channel_mode("composite") self._canvas.infoText.connect(lambda x: self._info_bar.setText(x)) - self._core.mda.events.frameReady.connect(self.on_frame_ready) - # self.datastore.frame_ready.connect(self.on_frame_ready) - self._dims_sliders.indexChanged.connect(self._on_dims_sliders) + self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) layout = QVBoxLayout(self) layout.addWidget(self._canvas, 1) layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) + self._channel_controls: dict[Hashable, ChannelVisControl] = {} for i, ch in enumerate(["DAPI", "FITC"]): - c = ChannelVisControl(i, ch) + self._channel_controls[i] = c = ChannelVisControl(ch, self._channels[i]) layout.addWidget(c) - c.climsChanged.connect(self._canvas.set_channel_clims) - c.cmapChanged.connect(self._canvas.set_channel_cmap) - c.visibilityChanged.connect(self._canvas.set_channel_visibility) - def _on_dims_sliders(self, index: dict) -> None: - self._canvas.set_current_index(index) + def set_channel_mode(self, mode: Literal["composite", "grayscale"]) -> None: + if mode == getattr(self, "_channel_mode", None): + return + + self._channel_mode = mode + if mode == "composite": + self._dims_sliders.set_dimension_visible(CHANNEL, False) + else: + self._dims_sliders.set_dimension_visible(CHANNEL, True) + + def _image_key(self, index: Mapping[str, int]) -> Hashable: + if self._channel_mode == "composite": + return index.get("c", 0) + return 0 + + def _isel(self, index: dict) -> np.ndarray: + return isel(self.datastore, index) + + def _on_dims_sliders_changed(self, index: dict) -> None: + """Set the current image index.""" + c = index.get(CHANNEL, 0) + indices = [index] + if self._channel_mode == "composite": + for i, handles in self._channels.items(): + if handles and c != i: + indices.append({**index, CHANNEL: i}) + + for idx in indices: + self._update_data_for_index(idx) + self._canvas.refresh() + + def _update_data_for_index(self, index: dict) -> None: + key = self._image_key(index) + data = self._isel(index) + if handles := self._channels.get(key): + for handle in handles: + handle.data = data + if ctrl := self._channel_controls.get(key, None): + ctrl.update_autoscale() + else: + cm = ( + next(COLORMAPS) + if self._channel_mode == "composite" + else cmap.Colormap("gray") + ) + new_img = self._canvas.add_image(data, cmap=cm) + self._channels[key].append(new_img) + + +class MDAViewer(StackViewer): + def __init__(self, *, parent: QWidget | None = None): + # self._core = CMMCorePlus.instance() + # self._core.mda.events.frameReady.connect(self.on_frame_ready) + super().__init__(DataStore(), parent=parent) + self.datastore.frame_ready.connect(self.on_frame_ready) @superqt.ensure_main_thread - def on_frame_ready(self, frame: np.narray, event: useq.MDAEvent) -> None: - self._dims_sliders.update_dimensions(event.index) + def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: + self._dims_sliders.setValue(event.index) + + +def isel(writer: OMEZarrWriter, indexers: Mapping[str, int | slice]) -> np.ndarray: + p_index = indexers.get("p", 0) + if isinstance(p_index, slice): + raise NotImplementedError("Cannot slice over position index") # TODO + + try: + sizes = [*list(writer.position_sizes[p_index]), "y", "x"] + except IndexError as e: + raise IndexError( + f"Position index {p_index} out of range for {len(writer.position_sizes)}" + ) from e + + data = writer.position_arrays[writer.get_position_key(p_index)] + full = slice(None, None) + index = tuple(indexers.get(k, full) for k in sizes) + return data[index] # type: ignore diff --git a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py new file mode 100644 index 000000000..fcb37383a --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from contextlib import suppress +from typing import TYPE_CHECKING, Any + +import cmap +from qtpy.QtCore import Signal +from qtpy.QtWidgets import QVBoxLayout, QWidget +from vispy import scene + +if TYPE_CHECKING: + import numpy as np + from vispy.scene.events import SceneMouseEvent + + +class VispyImageHandle: + def __init__(self, image: scene.visuals.Image) -> None: + self._image = image + + @property + def data(self) -> np.ndarray: + return self._image._data # type: ignore + + @data.setter + def data(self, data: np.ndarray) -> None: + self._image.set_data(data) + + @property + def visible(self) -> bool: + return bool(self._image.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._image.visible = visible + + @property + def clim(self) -> Any: + return self._image.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + self._image.clim = clims + + @property + def cmap(self) -> cmap.Colormap: + return cmap.Colormap(self._image.cmap) + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + self._image.cmap = cmap.to_vispy() + + +class VispyViewerCanvas(QWidget): + """Vispy-based viewer for data. + + All vispy-specific code is encapsulated in this class (and non-vispy canvases + could be swapped in if needed as long as they implement the same interface). + """ + + infoText = Signal(str) + + def __init__( + self, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._canvas = scene.SceneCanvas(parent=self) + self._canvas.events.mouse_move.connect(self._on_mouse_move) + self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + + central_wdg: scene.Widget = self._canvas.central_widget + self._view: scene.ViewBox = central_wdg.add_view(camera=self._camera) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._canvas.native) + + def refresh(self) -> None: + self._canvas.update() + + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> VispyImageHandle: + """Add a new Image node to the scene.""" + if cmap is not None: + cmap = cmap.to_vispy() + img = scene.visuals.Image(data, cmap=cmap, parent=self._view.scene) + img.set_gl_state("additive", depth_test=False) + img.interactive = True + self.set_range() + return VispyImageHandle(img) + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + margin: float | None = 0.05, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + self._camera.set_range(x=x, y=y, margin=margin) + + def _on_mouse_move(self, event: SceneMouseEvent) -> None: + """Mouse moved on the canvas, display the pixel value and position.""" + images = [] + # Get the images the mouse is over + seen = set() + while visual := self._canvas.visual_at(event.pos): + if isinstance(visual, scene.visuals.Image): + images.append(visual) + visual.interactive = False + seen.add(visual) + for visual in seen: + visual.interactive = True + if not images: + return + + tform = images[0].get_transform("canvas", "visual") + px, py, *_ = (int(x) for x in tform.map(event.pos)) + text = f"[{py}, {px}]" + for c, img in enumerate(images): + with suppress(IndexError): + text += f" c{c}: {img._data[py, px]}" + self.infoText.emit(text) From a3322289151cf1f8042af10c140d46bd09ebf181 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 1 May 2024 20:06:15 -0400 Subject: [PATCH 06/73] getting better --- examples/stack_viewer2.py | 6 +- .../_stack_viewer2/_dims_slider.py | 4 +- .../_stack_viewer2/_flow_layout.py | 99 ++++++++++++++ .../_stack_viewer2/_lut_control.py | 103 ++++++++++++++ .../_stack_viewer2/_stack_viewer.py | 128 ++---------------- .../_stack_viewer2/_vispy_canvas.py | 16 ++- x.py | 20 +++ y.py | 35 +++++ z.py | 10 ++ 9 files changed, 295 insertions(+), 126 deletions(-) create mode 100644 src/pymmcore_widgets/_stack_viewer2/_flow_layout.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_lut_control.py create mode 100644 x.py create mode 100644 y.py create mode 100644 z.py diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 9883e3a2a..185710a80 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -1,5 +1,5 @@ from __future__ import annotations - +from PySide6 import QtWidgets from pymmcore_plus import CMMCorePlus, configure_logging from qtpy import QtWidgets from useq import MDASequence @@ -15,8 +15,8 @@ sequence = MDASequence( channels=( - {"config": "DAPI", "exposure": 10}, - {"config": "FITC", "exposure": 80}, + {"config": "DAPI", "exposure": 5}, + {"config": "FITC", "exposure": 20}, # {"config": "Cy5", "exposure": 20}, ), stage_positions=[(0, 0), (1, 1)], diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index f5f7ad87d..e8b9619bc 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -60,7 +60,7 @@ def __init__( ) -> None: super().__init__(parent) self._slice_mode = False - self._animation_fps = 10 + self._animation_fps = 30 self._dim_key = dimension_key self._play_btn = PlayButton() @@ -80,7 +80,7 @@ def __init__( self._slice_slider = QLabeledRangeSlider(Qt.Orientation.Horizontal, parent=self) self._slice_slider.setVisible(False) - # self._slice_slider.rangeChanged.connect(self._on_range_changed) + self._slice_slider.rangeChanged.connect(self._on_range_changed) self._slice_slider.valueChanged.connect(self._on_slice_value_changed) self.installEventFilter(self) diff --git a/src/pymmcore_widgets/_stack_viewer2/_flow_layout.py b/src/pymmcore_widgets/_stack_viewer2/_flow_layout.py new file mode 100644 index 000000000..c3633c670 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_flow_layout.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from qtpy.QtCore import QMargins, QPoint, QRect, QSize, Qt +from qtpy.QtWidgets import QLayout, QLayoutItem, QSizePolicy, QWidget + + +class FlowLayout(QLayout): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + + if parent is not None: + self.setContentsMargins(QMargins(0, 0, 0, 0)) + + self._item_list: list[QLayoutItem] = [] + + def __del__(self) -> None: + item = self.takeAt(0) + while item: + item = self.takeAt(0) + + def addItem(self, item: QLayoutItem) -> None: + self._item_list.append(item) + + def count(self) -> int: + return len(self._item_list) + + def itemAt(self, index: int) -> QLayoutItem | None: + if 0 <= index < len(self._item_list): + return self._item_list[index] + return None + + def takeAt(self, index: int) -> QLayoutItem | None: + if 0 <= index < len(self._item_list): + return self._item_list.pop(index) + return None + + def expandingDirections(self) -> Qt.Orientation: + return Qt.Orientation(0) + + def hasHeightForWidth(self) -> bool: + return True + + def heightForWidth(self, width: int) -> int: + height = self._do_layout(QRect(0, 0, width, 0), True) + return height + + def setGeometry(self, rect: QRect) -> None: + super().setGeometry(rect) + self._do_layout(rect, False) + + def sizeHint(self) -> QSize: + return self.minimumSize() + + def minimumSize(self) -> QSize: + size = QSize() + + for item in self._item_list: + size = size.expandedTo(item.minimumSize()) + + size += QSize( + 2 * self.contentsMargins().top(), 2 * self.contentsMargins().top() + ) + return size + + def _do_layout(self, rect: QRect, test_only: bool) -> int: + x = rect.x() + y = rect.y() + line_height = 0 + spacing = self.spacing() + + for item in self._item_list: + style = item.widget().style() + layout_spacing_x = style.layoutSpacing( + QSizePolicy.ControlType.PushButton, + QSizePolicy.ControlType.PushButton, + Qt.Orientation.Horizontal, + ) + layout_spacing_y = style.layoutSpacing( + QSizePolicy.ControlType.PushButton, + QSizePolicy.ControlType.PushButton, + Qt.Orientation.Vertical, + ) + space_x = spacing + layout_spacing_x + space_y = spacing + layout_spacing_y + next_x = x + item.sizeHint().width() + space_x + if next_x - space_x > rect.right() and line_height > 0: + x = rect.x() + y = y + line_height + space_y + next_x = x + item.sizeHint().width() + space_x + line_height = 0 + + if not test_only: + print(x, y) + item.setGeometry(QRect(QPoint(x, y), item.sizeHint())) + + x = next_x + line_height = max(line_height, item.sizeHint().height()) + + return y + line_height - rect.y() diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py new file mode 100644 index 000000000..c21ab5018 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterable, Protocol + +import numpy as np +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QCheckBox, QHBoxLayout, QWidget +from superqt import QLabeledRangeSlider +from superqt.cmap import QColormapComboBox +from superqt.utils import signals_blocked + +if TYPE_CHECKING: + import cmap + + +class PImageHandle(Protocol): + @property + def data(self) -> np.ndarray: ... + @data.setter + def data(self, data: np.ndarray) -> None: ... + @property + def visible(self) -> bool: ... + @visible.setter + def visible(self, visible: bool) -> None: ... + @property + def clim(self) -> Any: ... + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: ... + @property + def cmap(self) -> Any: ... + @cmap.setter + def cmap(self, cmap: Any) -> None: ... + + +class LutControl(QWidget): + def __init__( + self, + name: str = "", + handles: Iterable[PImageHandle] = (), + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._handles = handles + self._name = name + + self._visible = QCheckBox(name) + self._visible.setChecked(True) + self._visible.toggled.connect(self._on_visible_changed) + + self._cmap = QColormapComboBox(allow_user_colormaps=True) + self._cmap.currentColormapChanged.connect(self._on_cmap_changed) + for handle in handles: + self._cmap.addColormap(handle.cmap) + for color in ["green", "magenta", "cyan"]: + self._cmap.addColormap(color) + + self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + self._clims.setRange(0, 2**14) + self._clims.valueChanged.connect(self._on_clims_changed) + + self._auto_clim = QCheckBox("Auto") + self._auto_clim.toggled.connect(self.update_autoscale) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._visible) + layout.addWidget(self._cmap) + layout.addWidget(self._clims) + layout.addWidget(self._auto_clim) + + def autoscaleChecked(self) -> bool: + return self._auto_clim.isChecked() + + def _on_clims_changed(self, clims: tuple[float, float]) -> None: + self._auto_clim.setChecked(False) + for handle in self._handles: + handle.clim = clims + + def _on_visible_changed(self, visible: bool) -> None: + for handle in self._handles: + handle.visible = visible + + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + for handle in self._handles: + handle.cmap = cmap + + def update_autoscale(self) -> None: + if not self._auto_clim.isChecked(): + return + + # find the min and max values for the current channel + clims = [np.inf, -np.inf] + for handle in self._handles: + clims[0] = min(clims[0], np.nanmin(handle.data)) + clims[1] = max(clims[1], np.nanmax(handle.data)) + + clims_ = tuple(int(x) for x in clims) + for handle in self._handles: + handle.clim = clims_ + + # set the slider values to the new clims + with signals_blocked(self._clims): + self._clims.setValue(clims_) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 2c30c418c..9d50e539b 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -3,36 +3,25 @@ import itertools from collections import defaultdict from itertools import cycle -from typing import ( - TYPE_CHECKING, - Any, - Hashable, - Iterable, - Literal, - Mapping, - Protocol, -) +from typing import TYPE_CHECKING, Hashable, Literal, Mapping, Protocol import cmap -import numpy as np import superqt import useq from psygnal import Signal as psygnalSignal from pymmcore_plus.mda.handlers import OMEZarrWriter -from qtpy.QtCore import Qt -from qtpy.QtWidgets import QCheckBox, QHBoxLayout, QLabel, QVBoxLayout, QWidget -from superqt import QLabeledRangeSlider -from superqt.cmap import QColormapComboBox -from superqt.utils import signals_blocked +from qtpy.QtWidgets import QLabel, QVBoxLayout, QWidget from ._dims_slider import DimsSliders +from ._lut_control import LutControl, PImageHandle from ._vispy_canvas import VispyViewerCanvas if TYPE_CHECKING: - ImageKey = Hashable + import numpy as np CHANNEL = "c" +GRAYS = cmap.Colormap("gray") COLORMAPS = cycle( [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] ) @@ -47,99 +36,12 @@ def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> Non self.frame_ready.emit(frame, event) -class ChannelVisControl(QWidget): - def __init__( - self, - name: str = "", - handles: Iterable[PImageHandle] = (), - parent: QWidget | None = None, - ) -> None: - super().__init__(parent) - self._handles = handles - self._name = name - - self._visible = QCheckBox(name) - self._visible.setChecked(True) - self._visible.toggled.connect(self._on_visible_changed) - - self._cmap = QColormapComboBox(allow_user_colormaps=True) - self._cmap.currentColormapChanged.connect(self._on_cmap_changed) - for color in ["green", "magenta", "cyan"]: - self._cmap.addColormap(color) - - self.clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) - self.clims.setRange(0, 2**14) - self.clims.valueChanged.connect(self._on_clims_changed) - - self._auto_clim = QCheckBox("Auto") - self._auto_clim.toggled.connect(self.update_autoscale) - - layout = QHBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._visible) - layout.addWidget(self._cmap) - layout.addWidget(self.clims) - layout.addWidget(self._auto_clim) - - def autoscaleChecked(self) -> bool: - return self._auto_clim.isChecked() - - def _on_clims_changed(self, clims: tuple[float, float]) -> None: - self._auto_clim.setChecked(False) - for handle in self._handles: - handle.clim = clims - - def _on_visible_changed(self, visible: bool) -> None: - for handle in self._handles: - handle.visible = visible - - def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: - for handle in self._handles: - handle.cmap = cmap - - def update_autoscale(self) -> None: - if not self._auto_clim.isChecked(): - return - - # find the min and max values for the current channel - clims = [np.inf, -np.inf] - for handle in self._handles: - clims[0] = min(clims[0], np.nanmin(handle.data)) - clims[1] = max(clims[1], np.nanmax(handle.data)) - - for handle in self._handles: - handle.clim = clims - - # set the slider values to the new clims - with signals_blocked(self.clims): - self.clims.setValue(clims) - - c = itertools.count() class PDataStore(Protocol): ... -class PImageHandle(Protocol): - @property - def data(self) -> np.ndarray: ... - @data.setter - def data(self, data: np.ndarray) -> None: ... - @property - def visible(self) -> bool: ... - @visible.setter - def visible(self, visible: bool) -> None: ... - @property - def clim(self) -> Any: ... - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: ... - @property - def cmap(self) -> Any: ... - @cmap.setter - def cmap(self, cmap: Any) -> None: ... - - class StackViewer(QWidget): """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events.""" @@ -162,10 +64,7 @@ def __init__(self, datastore: PDataStore, *, parent: QWidget | None = None): layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) - self._channel_controls: dict[Hashable, ChannelVisControl] = {} - for i, ch in enumerate(["DAPI", "FITC"]): - self._channel_controls[i] = c = ChannelVisControl(ch, self._channels[i]) - layout.addWidget(c) + self._channel_controls: dict[Hashable, LutControl] = {} def set_channel_mode(self, mode: Literal["composite", "grayscale"]) -> None: if mode == getattr(self, "_channel_mode", None): @@ -201,19 +100,18 @@ def _on_dims_sliders_changed(self, index: dict) -> None: def _update_data_for_index(self, index: dict) -> None: key = self._image_key(index) data = self._isel(index) - if handles := self._channels.get(key): + if handles := self._channels[key]: for handle in handles: handle.data = data if ctrl := self._channel_controls.get(key, None): ctrl.update_autoscale() else: - cm = ( - next(COLORMAPS) - if self._channel_mode == "composite" - else cmap.Colormap("gray") - ) - new_img = self._canvas.add_image(data, cmap=cm) - self._channels[key].append(new_img) + cm = next(COLORMAPS) if self._channel_mode == "composite" else GRAYS + handles.append(self._canvas.add_image(data, cmap=cm)) + if key not in self._channel_controls: + channel_name = f"Channel {key}" + self._channel_controls[key] = c = LutControl(channel_name, handles) + self.layout().addWidget(c) class MDAViewer(StackViewer): diff --git a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py index fcb37383a..25d379421 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py @@ -3,12 +3,13 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any -import cmap from qtpy.QtCore import Signal from qtpy.QtWidgets import QVBoxLayout, QWidget from vispy import scene +from superqt.utils import qthrottled if TYPE_CHECKING: + import cmap import numpy as np from vispy.scene.events import SceneMouseEvent @@ -43,10 +44,11 @@ def clim(self, clims: tuple[float, float]) -> None: @property def cmap(self) -> cmap.Colormap: - return cmap.Colormap(self._image.cmap) + return self._cmap @cmap.setter def cmap(self, cmap: cmap.Colormap) -> None: + self._cmap = cmap self._image.cmap = cmap.to_vispy() @@ -82,13 +84,14 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: """Add a new Image node to the scene.""" - if cmap is not None: - cmap = cmap.to_vispy() - img = scene.visuals.Image(data, cmap=cmap, parent=self._view.scene) + img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True self.set_range() - return VispyImageHandle(img) + handle = VispyImageHandle(img) + if cmap is not None: + handle.cmap = cmap + return handle def set_range( self, @@ -102,6 +105,7 @@ def set_range( """ self._camera.set_range(x=x, y=y, margin=margin) + @qthrottled(timeout=50) def _on_mouse_move(self, event: SceneMouseEvent) -> None: """Mouse moved on the canvas, display the pixel value and position.""" images = [] diff --git a/x.py b/x.py new file mode 100644 index 000000000..ecd5f2014 --- /dev/null +++ b/x.py @@ -0,0 +1,20 @@ +import useq +from pymmcore_plus import CMMCorePlus +from pymmcore_plus.mda.handlers import OMEZarrWriter + +core = CMMCorePlus() +core.loadSystemConfiguration() +seq = useq.MDASequence( + channels=["DAPI", "FITC"], + stage_positions=[(1, 2, 3)], + time_plan={"interval": 0, "loops": 3}, + grid_plan={"rows": 2, "columns": 1}, + z_plan={"range": 2, "step": 0.7}, +) +writer = OMEZarrWriter() +core.mda.run(seq, output=writer) + +xa = writer.as_xarray() +da = xa["p0"] +print(da) +print(da.dims) diff --git a/y.py b/y.py new file mode 100644 index 000000000..79f579143 --- /dev/null +++ b/y.py @@ -0,0 +1,35 @@ +import zarr +import numpy as np +import xarray as xr +from zarr.storage import KVStore + +# Create an in-memory store +memory_store = KVStore(dict()) + +# Create some data +data = np.random.randn(3, 2, 512, 512) # Shape corresponding to (t, c, y, x) + +# Create a Zarr group in the memory store +root = zarr.group(store=memory_store, overwrite=True) + +# Add dimensions and coordinates +t = np.array([0, 1, 2]) # Time coordinates +c = np.array(['DAPI', 'FITC']) # Channel labels + +# Create the dataset within the group +dset = root.create_dataset('data', data=data, chunks=(1, 1, 256, 256), dtype='float32') + +# Add attributes for xarray compatibility +dset.attrs['_ARRAY_DIMENSIONS'] = ['t', 'c', 'y', 'x'] + +# Create coordinate datasets +root['t'] = t +# root['c'] = c +root['t'].attrs['_ARRAY_DIMENSIONS'] = ['t'] +# root['c'].attrs['_ARRAY_DIMENSIONS'] = ['c'] + +# Open the Zarr group with xarray directly using the in-memory store +ds = xr.open_zarr(memory_store, consolidated=False) + +# Print the xarray dataset +print(ds['data']) diff --git a/z.py b/z.py new file mode 100644 index 000000000..8aa22b8ba --- /dev/null +++ b/z.py @@ -0,0 +1,10 @@ +from PySide6 import QtWidgets +from qtpy.QtWidgets import QApplication +from superqt import QLabeledRangeSlider + +app = QApplication([]) +sld = QLabeledRangeSlider() +# sld = QSlider() +sld.valueChanged.connect(lambda x: print(x)) +sld.show() +app.exec_() From d066d77b2e79b0587a5d9eb97325abc32a293db4 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 1 May 2024 20:06:27 -0400 Subject: [PATCH 07/73] some linting --- examples/stack_viewer2.py | 3 ++- .../_stack_viewer2/_vispy_canvas.py | 2 +- y.py | 16 ++++++++-------- z.py | 1 - 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 185710a80..288d8cd8d 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -1,6 +1,7 @@ from __future__ import annotations -from PySide6 import QtWidgets + from pymmcore_plus import CMMCorePlus, configure_logging +from PySide6 import QtWidgets from qtpy import QtWidgets from useq import MDASequence diff --git a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py index 25d379421..750a868ed 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py @@ -5,8 +5,8 @@ from qtpy.QtCore import Signal from qtpy.QtWidgets import QVBoxLayout, QWidget -from vispy import scene from superqt.utils import qthrottled +from vispy import scene if TYPE_CHECKING: import cmap diff --git a/y.py b/y.py index 79f579143..3787a15bb 100644 --- a/y.py +++ b/y.py @@ -1,10 +1,10 @@ -import zarr import numpy as np import xarray as xr +import zarr from zarr.storage import KVStore # Create an in-memory store -memory_store = KVStore(dict()) +memory_store = KVStore({}) # Create some data data = np.random.randn(3, 2, 512, 512) # Shape corresponding to (t, c, y, x) @@ -14,22 +14,22 @@ # Add dimensions and coordinates t = np.array([0, 1, 2]) # Time coordinates -c = np.array(['DAPI', 'FITC']) # Channel labels +c = np.array(["DAPI", "FITC"]) # Channel labels # Create the dataset within the group -dset = root.create_dataset('data', data=data, chunks=(1, 1, 256, 256), dtype='float32') +dset = root.create_dataset("data", data=data, chunks=(1, 1, 256, 256), dtype="float32") # Add attributes for xarray compatibility -dset.attrs['_ARRAY_DIMENSIONS'] = ['t', 'c', 'y', 'x'] +dset.attrs["_ARRAY_DIMENSIONS"] = ["t", "c", "y", "x"] # Create coordinate datasets -root['t'] = t +root["t"] = t # root['c'] = c -root['t'].attrs['_ARRAY_DIMENSIONS'] = ['t'] +root["t"].attrs["_ARRAY_DIMENSIONS"] = ["t"] # root['c'].attrs['_ARRAY_DIMENSIONS'] = ['c'] # Open the Zarr group with xarray directly using the in-memory store ds = xr.open_zarr(memory_store, consolidated=False) # Print the xarray dataset -print(ds['data']) +print(ds["data"]) diff --git a/z.py b/z.py index 8aa22b8ba..f35c64285 100644 --- a/z.py +++ b/z.py @@ -1,4 +1,3 @@ -from PySide6 import QtWidgets from qtpy.QtWidgets import QApplication from superqt import QLabeledRangeSlider From be222a468c50efd278fd711cd2d3d2d97e92dcba Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 1 May 2024 20:47:07 -0400 Subject: [PATCH 08/73] wip on composite mode --- examples/stack_viewer2.py | 1 - .../_stack_viewer2/_dims_slider.py | 3 + .../_stack_viewer2/_lut_control.py | 2 + .../_stack_viewer2/_stack_viewer.py | 63 +++++++++++++++---- .../_stack_viewer2/_vispy_canvas.py | 8 ++- 5 files changed, 63 insertions(+), 14 deletions(-) diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 288d8cd8d..58ea55168 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -1,7 +1,6 @@ from __future__ import annotations from pymmcore_plus import CMMCorePlus, configure_logging -from PySide6 import QtWidgets from qtpy import QtWidgets from useq import MDASequence diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index e8b9619bc..770c5c1cc 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -202,6 +202,9 @@ def setValue(self, values: Indices) -> None: self.add_or_update_dimension(dim, index) self.valueChanged.emit(self.value()) + def maximum(self) -> dict[DimensionKey, int]: + return {k: v._int_slider.maximum() for k, v in self._sliders.items()} + def add_dimension(self, name: DimensionKey, val: Index | None = None) -> None: self._sliders[name] = slider = DimsSlider(dimension_key=name, parent=self) slider.setRange(0, 1) diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py index c21ab5018..49a51932b 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -30,6 +30,7 @@ def clim(self, clims: tuple[float, float]) -> None: ... def cmap(self) -> Any: ... @cmap.setter def cmap(self, cmap: Any) -> None: ... + def remove(self) -> None: ... class LutControl(QWidget): @@ -60,6 +61,7 @@ def __init__( self._auto_clim = QCheckBox("Auto") self._auto_clim.toggled.connect(self.update_autoscale) + self._auto_clim.setChecked(True) layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 9d50e539b..0c1434210 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -3,14 +3,14 @@ import itertools from collections import defaultdict from itertools import cycle -from typing import TYPE_CHECKING, Hashable, Literal, Mapping, Protocol +from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping import cmap import superqt import useq from psygnal import Signal as psygnalSignal from pymmcore_plus.mda.handlers import OMEZarrWriter -from qtpy.QtWidgets import QLabel, QVBoxLayout, QWidget +from qtpy.QtWidgets import QLabel, QPushButton, QVBoxLayout, QWidget, QHBoxLayout from ._dims_slider import DimsSliders from ._lut_control import LutControl, PImageHandle @@ -25,6 +25,7 @@ COLORMAPS = cycle( [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] ) +c = itertools.count() # FIXME: get rid of this thin subclass @@ -36,45 +37,83 @@ def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> Non self.frame_ready.emit(frame, event) -c = itertools.count() +class ColorModeButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + self._modes = cycle(["grayscale", "composite"]) + super().__init__(parent) + self.clicked.connect(self._on_clicked) + self.setText(next(self._modes)) + def _on_clicked(self) -> None: + self.setText(next(self._modes)) -class PDataStore(Protocol): ... + def mode(self) -> str: + return self.text() class StackViewer(QWidget): """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events.""" - def __init__(self, datastore: PDataStore, *, parent: QWidget | None = None): + def __init__(self, datastore: Any, *, parent: QWidget | None = None): super().__init__(parent=parent) self._channels: defaultdict[Hashable, list[PImageHandle]] = defaultdict(list) + self._channel_controls: dict[Hashable, LutControl] = {} self.datastore = datastore self._canvas = VispyViewerCanvas() self._info_bar = QLabel("Info") self._dims_sliders = DimsSliders() - self.set_channel_mode("composite") + self.set_channel_mode("grayscale") self._canvas.infoText.connect(lambda x: self._info_bar.setText(x)) self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) + self._channel_mode_picker = ColorModeButton("Channel Mode") + self._channel_mode_picker.clicked.connect(self.set_channel_mode) + self._set_range_btn = QPushButton("Set Range") + self._set_range_btn.clicked.connect(self._set_range_clicked) + + btns = QHBoxLayout() + btns.addWidget(self._channel_mode_picker) + btns.addWidget(self._set_range_btn) layout = QVBoxLayout(self) + layout.addLayout(btns) layout.addWidget(self._canvas, 1) layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) - self._channel_controls: dict[Hashable, LutControl] = {} + def _set_range_clicked(self) -> None: + self._canvas.set_range() - def set_channel_mode(self, mode: Literal["composite", "grayscale"]) -> None: + def set_channel_mode( + self, mode: Literal["composite", "grayscale"] | None = None + ) -> None: + if mode is None or isinstance(mode, bool): + mode = self._channel_mode_picker.mode() if mode == getattr(self, "_channel_mode", None): return self._channel_mode = mode - if mode == "composite": - self._dims_sliders.set_dimension_visible(CHANNEL, False) - else: - self._dims_sliders.set_dimension_visible(CHANNEL, True) + c_visible = mode != "composite" + self._dims_sliders.set_dimension_visible(CHANNEL, c_visible) + num_channels = self._dims_sliders.maximum().get(CHANNEL, -1) + 1 + value = self._dims_sliders.value() + if self._channels: + for handles in self._channels.values(): + for handle in handles: + handle.remove() + self._channels.clear() + for c in self._channel_controls.values(): + self.layout().removeWidget(c) + c.deleteLater() + self._channel_controls.clear() + if c_visible: + self._update_data_for_index(value) + else: + for i in range(num_channels): + self._update_data_for_index({**value, CHANNEL: i}) + self._canvas.refresh() def _image_key(self, index: Mapping[str, int]) -> Hashable: if self._channel_mode == "composite": diff --git a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py index 750a868ed..49658ffa8 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py @@ -51,6 +51,9 @@ def cmap(self, cmap: cmap.Colormap) -> None: self._cmap = cmap self._image.cmap = cmap.to_vispy() + def remove(self) -> None: + self._image.parent = None + class VispyViewerCanvas(QWidget): """Vispy-based viewer for data. @@ -69,6 +72,7 @@ def __init__( self._canvas = scene.SceneCanvas(parent=self) self._canvas.events.mouse_move.connect(self._on_mouse_move) self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + self._has_set_range = False central_wdg: scene.Widget = self._canvas.central_widget self._view: scene.ViewBox = central_wdg.add_view(camera=self._camera) @@ -87,7 +91,9 @@ def add_image( img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True - self.set_range() + if not self._has_set_range: + self.set_range() + self._has_set_range = True handle = VispyImageHandle(img) if cmap is not None: handle.cmap = cmap From 0c3edb51b21adb2dab68dad052bc5f1839ea954b Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Thu, 2 May 2024 09:14:27 -0400 Subject: [PATCH 09/73] starting pygfx --- examples/stack_viewer2.py | 5 +- .../_stack_viewer2/_dims_slider.py | 10 +- .../_stack_viewer2/_flow_layout.py | 99 ----------- .../_stack_viewer2/_lut_control.py | 28 +--- .../_stack_viewer2/_protocols.py | 42 +++++ .../_stack_viewer2/_pygfx_canvas.py | 155 ++++++++++++++++++ .../_stack_viewer2/_stack_viewer.py | 75 +++++---- .../_stack_viewer2/_vispy_canvas.py | 27 +-- 8 files changed, 262 insertions(+), 179 deletions(-) delete mode 100644 src/pymmcore_widgets/_stack_viewer2/_flow_layout.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_protocols.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 58ea55168..9d3d8c9b4 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -31,4 +31,7 @@ v.show() mmcore.run_mda(sequence, output=v.datastore) -qapp.exec() +# qapp.exec() + +from wgpu.gui.auto import WgpuCanvas, run +run() \ No newline at end of file diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index 770c5c1cc..043488f3f 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from warnings import warn from qtpy.QtCore import Qt, Signal @@ -150,15 +150,15 @@ def _toggle_animation(self, checked: bool) -> None: def timerEvent(self, event: Any) -> None: if self._slice_mode: - val = self._slice_slider.value() + val = cast(tuple[int, int], self._slice_slider.value()) next_val = [v + 1 for v in val] if next_val[1] > self._slice_slider.maximum(): next_val = [v - val[0] for v in val] self._slice_slider.setValue(next_val) else: - val = self._int_slider.value() - val = (val + 1) % (self._int_slider.maximum() + 1) - self._int_slider.setValue(val) + ival = self._int_slider.value() + ival = (ival + 1) % (self._int_slider.maximum() + 1) + self._int_slider.setValue(ival) def _on_range_changed(self, min: int, max: int) -> None: self._max_label.setText("/ " + str(max)) diff --git a/src/pymmcore_widgets/_stack_viewer2/_flow_layout.py b/src/pymmcore_widgets/_stack_viewer2/_flow_layout.py deleted file mode 100644 index c3633c670..000000000 --- a/src/pymmcore_widgets/_stack_viewer2/_flow_layout.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from qtpy.QtCore import QMargins, QPoint, QRect, QSize, Qt -from qtpy.QtWidgets import QLayout, QLayoutItem, QSizePolicy, QWidget - - -class FlowLayout(QLayout): - def __init__(self, parent: QWidget | None = None): - super().__init__(parent) - - if parent is not None: - self.setContentsMargins(QMargins(0, 0, 0, 0)) - - self._item_list: list[QLayoutItem] = [] - - def __del__(self) -> None: - item = self.takeAt(0) - while item: - item = self.takeAt(0) - - def addItem(self, item: QLayoutItem) -> None: - self._item_list.append(item) - - def count(self) -> int: - return len(self._item_list) - - def itemAt(self, index: int) -> QLayoutItem | None: - if 0 <= index < len(self._item_list): - return self._item_list[index] - return None - - def takeAt(self, index: int) -> QLayoutItem | None: - if 0 <= index < len(self._item_list): - return self._item_list.pop(index) - return None - - def expandingDirections(self) -> Qt.Orientation: - return Qt.Orientation(0) - - def hasHeightForWidth(self) -> bool: - return True - - def heightForWidth(self, width: int) -> int: - height = self._do_layout(QRect(0, 0, width, 0), True) - return height - - def setGeometry(self, rect: QRect) -> None: - super().setGeometry(rect) - self._do_layout(rect, False) - - def sizeHint(self) -> QSize: - return self.minimumSize() - - def minimumSize(self) -> QSize: - size = QSize() - - for item in self._item_list: - size = size.expandedTo(item.minimumSize()) - - size += QSize( - 2 * self.contentsMargins().top(), 2 * self.contentsMargins().top() - ) - return size - - def _do_layout(self, rect: QRect, test_only: bool) -> int: - x = rect.x() - y = rect.y() - line_height = 0 - spacing = self.spacing() - - for item in self._item_list: - style = item.widget().style() - layout_spacing_x = style.layoutSpacing( - QSizePolicy.ControlType.PushButton, - QSizePolicy.ControlType.PushButton, - Qt.Orientation.Horizontal, - ) - layout_spacing_y = style.layoutSpacing( - QSizePolicy.ControlType.PushButton, - QSizePolicy.ControlType.PushButton, - Qt.Orientation.Vertical, - ) - space_x = spacing + layout_spacing_x - space_y = spacing + layout_spacing_y - next_x = x + item.sizeHint().width() + space_x - if next_x - space_x > rect.right() and line_height > 0: - x = rect.x() - y = y + line_height + space_y - next_x = x + item.sizeHint().width() + space_x - line_height = 0 - - if not test_only: - print(x, y) - item.setGeometry(QRect(QPoint(x, y), item.sizeHint())) - - x = next_x - line_height = max(line_height, item.sizeHint().height()) - - return y + line_height - rect.y() diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py index 49a51932b..c106ab561 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Protocol +from typing import TYPE_CHECKING, Iterable import numpy as np from qtpy.QtCore import Qt @@ -12,25 +12,7 @@ if TYPE_CHECKING: import cmap - -class PImageHandle(Protocol): - @property - def data(self) -> np.ndarray: ... - @data.setter - def data(self, data: np.ndarray) -> None: ... - @property - def visible(self) -> bool: ... - @visible.setter - def visible(self, visible: bool) -> None: ... - @property - def clim(self) -> Any: ... - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: ... - @property - def cmap(self) -> Any: ... - @cmap.setter - def cmap(self, cmap: Any) -> None: ... - def remove(self) -> None: ... + from ._protocols import PImageHandle class LutControl(QWidget): @@ -96,9 +78,9 @@ def update_autoscale(self) -> None: clims[0] = min(clims[0], np.nanmin(handle.data)) clims[1] = max(clims[1], np.nanmax(handle.data)) - clims_ = tuple(int(x) for x in clims) - for handle in self._handles: - handle.clim = clims_ + if (clims_ := tuple(int(x) for x in clims)) != (0, 0): + for handle in self._handles: + handle.clim = clims_ # set the slider values to the new clims with signals_blocked(self._clims): diff --git a/src/pymmcore_widgets/_stack_viewer2/_protocols.py b/src/pymmcore_widgets/_stack_viewer2/_protocols.py new file mode 100644 index 000000000..8c34543f5 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_protocols.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + import cmap + import numpy as np + from qtpy.QtWidgets import QWidget + + +class PImageHandle(Protocol): + @property + def data(self) -> np.ndarray: ... + @data.setter + def data(self, data: np.ndarray) -> None: ... + @property + def visible(self) -> bool: ... + @visible.setter + def visible(self, visible: bool) -> None: ... + @property + def clim(self) -> Any: ... + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: ... + @property + def cmap(self) -> Any: ... + @cmap.setter + def cmap(self, cmap: Any) -> None: ... + def remove(self) -> None: ... + + +class PCanvas(Protocol): + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + margin: float | None = 0.05, + ) -> None: ... + def refresh(self) -> None: ... + def qwidget(self) -> QWidget: ... + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PImageHandle: ... diff --git a/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py new file mode 100644 index 000000000..9c1408fba --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any, Callable, TypeGuard, cast + +import pygfx +from wgpu.gui.qt import QWgpuCanvas + +if TYPE_CHECKING: + import cmap + import numpy as np + from qtpy.QtWidgets import QWidget + from wgpu.gui import qt + + +class PyGFXImageHandle: + def __init__(self, image: pygfx.Image) -> None: + self._image = image + self._texture = cast("pygfx.Texture", image.geometry.grid) + self._material = cast("pygfx.ImageBasicMaterial", image.material) + + @property + def data(self) -> np.ndarray: + return self._texture._data # type: ignore + + @data.setter + def data(self, data: np.ndarray) -> None: + self._texture._data = data + + @property + def visible(self) -> bool: + return bool(self._image.visible) + + @visible.setter + def visible(self, visible: bool) -> None: + self._image.visible = visible + + @property + def clim(self) -> Any: + return self._material.clim + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + self._material.clim = clims + + @property + def cmap(self) -> cmap.Colormap: + return self._cmap + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + self._cmap = cmap + self._image.cmap = cmap.to_pygfx() + + def remove(self) -> None: + # self._image.parent = None + ... + + +def _is_qt_canvas_type(obj: type) -> TypeGuard[type[qt.WgpuCanvas]]: + if wgpu_qt := sys.modules.get("wgpu.gui.qt"): + return issubclass(obj, wgpu_qt.WgpuCanvas) + return False + + +class PyGFXViewerCanvas: + """Vispy-based viewer for data. + + All vispy-specific code is encapsulated in this class (and non-vispy canvases + could be swapped in if needed as long as they implement the same interface). + """ + + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + + self._canvas = QWgpuCanvas() + self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) + self._viewport: pygfx.Viewport = pygfx.Viewport(self._renderer) + self._scene = pygfx.Scene() + + self._camera = cam = pygfx.OrthographicCamera(512, 512) + cam.local.position = (256, 256, 0) + cam.scale_y = -1 + self._controller = pygfx.PanZoomController(cam) + + # TODO: background_color + # the qt backend, this shows by default... + # if we need to prevent it, we could potentially monkeypatch during init. + # if hasattr(self._canvas, "hide"): + # self._canvas.hide() + + def qwidget(self) -> QWidget: + return self._canvas + + def refresh(self) -> None: + self._canvas.request_draw(self._animate) + + def _animate(self, viewport: pygfx.Viewport | None = None) -> None: + vp = viewport or self._viewport + + print("rendering") + vp.render(self._scene, self._camera) + if hasattr(vp.renderer, "flush"): + vp.renderer.flush() + if viewport is None: + self._canvas.request_draw() + + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + """Add a new Image node to the scene.""" + img = pygfx.Image( + pygfx.Geometry(grid=pygfx.Texture(data, dim=2)), + pygfx.ImageBasicMaterial(clim=(0, 255)), + ) + self._scene.add(img) + handle = PyGFXImageHandle(img) + if cmap is not None: + handle.cmap = cmap + return handle + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + margin: float | None = 0.05, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + # self._camera.set_range(x=x, y=y, margin=margin) + + # def _on_mouse_move(self, event: SceneMouseEvent) -> None: + # """Mouse moved on the canvas, display the pixel value and position.""" + # images = [] + # # Get the images the mouse is over + # seen = set() + # while visual := self._canvas.visual_at(event.pos): + # if isinstance(visual, scene.visuals.Image): + # images.append(visual) + # visual.interactive = False + # seen.add(visual) + # for visual in seen: + # visual.interactive = True + # if not images: + # return + + # tform = images[0].get_transform("canvas", "visual") + # px, py, *_ = (int(x) for x in tform.map(event.pos)) + # text = f"[{py}, {px}]" + # for c, img in enumerate(images): + # with suppress(IndexError): + # text += f" c{c}: {img._data[py, px]}" + # self._set_info(text) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 0c1434210..ad01d2f41 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -1,31 +1,31 @@ from __future__ import annotations -import itertools from collections import defaultdict from itertools import cycle -from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping +from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, cast import cmap import superqt import useq from psygnal import Signal as psygnalSignal -from pymmcore_plus.mda.handlers import OMEZarrWriter -from qtpy.QtWidgets import QLabel, QPushButton, QVBoxLayout, QWidget, QHBoxLayout +from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter +from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget from ._dims_slider import DimsSliders -from ._lut_control import LutControl, PImageHandle +from ._lut_control import LutControl from ._vispy_canvas import VispyViewerCanvas +from ._pygfx_canvas import PyGFXViewerCanvas if TYPE_CHECKING: import numpy as np + from ._protocols import PCanvas, PImageHandle + + ColorMode = Literal["composite", "grayscale"] CHANNEL = "c" GRAYS = cmap.Colormap("gray") -COLORMAPS = cycle( - [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] -) -c = itertools.count() +COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] # FIXME: get rid of this thin subclass @@ -39,16 +39,18 @@ def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> Non class ColorModeButton(QPushButton): def __init__(self, parent: QWidget | None = None): - self._modes = cycle(["grayscale", "composite"]) - super().__init__(parent) - self.clicked.connect(self._on_clicked) - self.setText(next(self._modes)) - - def _on_clicked(self) -> None: + modes = ["composite", "grayscale"] + self._modes = cycle(modes) + super().__init__(modes[-1], parent) + self.clicked.connect(self.next_mode) + self.next_mode() + + def next_mode(self) -> None: + self._mode = self.text() self.setText(next(self._modes)) - def mode(self) -> str: - return self.text() + def mode(self) -> ColorMode: + return self._mode # type: ignore class StackViewer(QWidget): @@ -61,17 +63,18 @@ def __init__(self, datastore: Any, *, parent: QWidget | None = None): self._channel_controls: dict[Hashable, LutControl] = {} self.datastore = datastore - self._canvas = VispyViewerCanvas() self._info_bar = QLabel("Info") + # self._canvas: PCanvas = VispyViewerCanvas(self._info_bar.setText) + self._canvas: PCanvas = PyGFXViewerCanvas(self._info_bar.setText) self._dims_sliders = DimsSliders() + self._cmaps = cycle(COLORMAPS) self.set_channel_mode("grayscale") - self._canvas.infoText.connect(lambda x: self._info_bar.setText(x)) self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) - self._channel_mode_picker = ColorModeButton("Channel Mode") + self._channel_mode_picker = ColorModeButton() self._channel_mode_picker.clicked.connect(self.set_channel_mode) - self._set_range_btn = QPushButton("Set Range") + self._set_range_btn = QPushButton("reset zoom") self._set_range_btn.clicked.connect(self._set_range_clicked) btns = QHBoxLayout() @@ -79,21 +82,20 @@ def __init__(self, datastore: Any, *, parent: QWidget | None = None): btns.addWidget(self._set_range_btn) layout = QVBoxLayout(self) layout.addLayout(btns) - layout.addWidget(self._canvas, 1) + layout.addWidget(self._canvas.qwidget(), 1) layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) def _set_range_clicked(self) -> None: self._canvas.set_range() - def set_channel_mode( - self, mode: Literal["composite", "grayscale"] | None = None - ) -> None: + def set_channel_mode(self, mode: ColorMode | None = None) -> None: if mode is None or isinstance(mode, bool): mode = self._channel_mode_picker.mode() if mode == getattr(self, "_channel_mode", None): return + self._cmaps = cycle(COLORMAPS) self._channel_mode = mode c_visible = mode != "composite" self._dims_sliders.set_dimension_visible(CHANNEL, c_visible) @@ -105,7 +107,7 @@ def set_channel_mode( handle.remove() self._channels.clear() for c in self._channel_controls.values(): - self.layout().removeWidget(c) + cast("QVBoxLayout", self.layout()).removeWidget(c) c.deleteLater() self._channel_controls.clear() if c_visible: @@ -120,7 +122,7 @@ def _image_key(self, index: Mapping[str, int]) -> Hashable: return index.get("c", 0) return 0 - def _isel(self, index: dict) -> np.ndarray: + def _isel(self, index: Mapping) -> np.ndarray: return isel(self.datastore, index) def _on_dims_sliders_changed(self, index: dict) -> None: @@ -136,7 +138,7 @@ def _on_dims_sliders_changed(self, index: dict) -> None: self._update_data_for_index(idx) self._canvas.refresh() - def _update_data_for_index(self, index: dict) -> None: + def _update_data_for_index(self, index: Mapping) -> None: key = self._image_key(index) data = self._isel(index) if handles := self._channels[key]: @@ -145,18 +147,16 @@ def _update_data_for_index(self, index: dict) -> None: if ctrl := self._channel_controls.get(key, None): ctrl.update_autoscale() else: - cm = next(COLORMAPS) if self._channel_mode == "composite" else GRAYS + cm = next(self._cmaps) if self._channel_mode == "composite" else GRAYS handles.append(self._canvas.add_image(data, cmap=cm)) if key not in self._channel_controls: channel_name = f"Channel {key}" self._channel_controls[key] = c = LutControl(channel_name, handles) - self.layout().addWidget(c) + cast("QVBoxLayout", self.layout()).addWidget(c) class MDAViewer(StackViewer): def __init__(self, *, parent: QWidget | None = None): - # self._core = CMMCorePlus.instance() - # self._core.mda.events.frameReady.connect(self.on_frame_ready) super().__init__(DataStore(), parent=parent) self.datastore.frame_ready.connect(self.on_frame_ready) @@ -165,7 +165,16 @@ def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: self._dims_sliders.setValue(event.index) -def isel(writer: OMEZarrWriter, indexers: Mapping[str, int | slice]) -> np.ndarray: +def isel(store: Any, indexers: Mapping[str, int | slice]) -> np.ndarray: + if isinstance(store, (OMEZarrWriter, OMETiffWriter)): + return isel_mmcore_5dbase(store, indexers) + + raise NotImplementedError(f"Unknown datastore type {type(store)}") + + +def isel_mmcore_5dbase( + writer: OMEZarrWriter | OMETiffWriter, indexers: Mapping[str, int | slice] +) -> np.ndarray: p_index = indexers.get("p", 0) if isinstance(p_index, slice): raise NotImplementedError("Cannot slice over position index") # TODO diff --git a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py index 49658ffa8..2721b4a46 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py @@ -1,16 +1,14 @@ from __future__ import annotations from contextlib import suppress -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable, cast -from qtpy.QtCore import Signal -from qtpy.QtWidgets import QVBoxLayout, QWidget -from superqt.utils import qthrottled from vispy import scene if TYPE_CHECKING: import cmap import numpy as np + from qtpy.QtWidgets import QWidget from vispy.scene.events import SceneMouseEvent @@ -55,21 +53,16 @@ def remove(self) -> None: self._image.parent = None -class VispyViewerCanvas(QWidget): +class VispyViewerCanvas: """Vispy-based viewer for data. All vispy-specific code is encapsulated in this class (and non-vispy canvases could be swapped in if needed as long as they implement the same interface). """ - infoText = Signal(str) - - def __init__( - self, - parent: QWidget | None = None, - ) -> None: - super().__init__(parent) - self._canvas = scene.SceneCanvas(parent=self) + def __init__(self, set_info: Callable[[str], None]) -> None: + self._set_info = set_info + self._canvas = scene.SceneCanvas() self._canvas.events.mouse_move.connect(self._on_mouse_move) self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) self._has_set_range = False @@ -77,9 +70,8 @@ def __init__( central_wdg: scene.Widget = self._canvas.central_widget self._view: scene.ViewBox = central_wdg.add_view(camera=self._camera) - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._canvas.native) + def qwidget(self) -> QWidget: + return cast("QWidget", self._canvas.native) def refresh(self) -> None: self._canvas.update() @@ -111,7 +103,6 @@ def set_range( """ self._camera.set_range(x=x, y=y, margin=margin) - @qthrottled(timeout=50) def _on_mouse_move(self, event: SceneMouseEvent) -> None: """Mouse moved on the canvas, display the pixel value and position.""" images = [] @@ -133,4 +124,4 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: for c, img in enumerate(images): with suppress(IndexError): text += f" c{c}: {img._data[py, px]}" - self.infoText.emit(text) + self._set_info(text) From aa412e48cc007463b021863220c9304feaddc8ae Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Thu, 2 May 2024 11:42:00 -0400 Subject: [PATCH 10/73] more updates --- examples/stack_viewer2.py | 7 +- .../_stack_viewer2/_dims_slider.py | 10 ++- .../_stack_viewer2/_pygfx_canvas.py | 64 ++++++--------- .../_stack_viewer2/_stack_viewer.py | 80 ++++++++++++++----- .../_stack_viewer2/_vispy_canvas.py | 8 +- z.py | 68 +++++++++++++--- 6 files changed, 161 insertions(+), 76 deletions(-) diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 9d3d8c9b4..6e1677fbf 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -30,8 +30,5 @@ v = MDAViewer() v.show() -mmcore.run_mda(sequence, output=v.datastore) -# qapp.exec() - -from wgpu.gui.auto import WgpuCanvas, run -run() \ No newline at end of file +mmcore.run_mda(sequence, output=v._datastore) +qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index 043488f3f..3e5c889a1 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -53,7 +53,7 @@ class DimsSlider(QWidget): Adds a label for the maximum value (e.g. "3 / 10") """ - valueChanged = Signal(str, object) # where object is int | slice + valueChanged = Signal(object, object) # where object is int | slice def __init__( self, dimension_key: DimensionKey, parent: QWidget | None = None @@ -205,6 +205,12 @@ def setValue(self, values: Indices) -> None: def maximum(self) -> dict[DimensionKey, int]: return {k: v._int_slider.maximum() for k, v in self._sliders.items()} + def setMaximum(self, values: Mapping[DimensionKey, int]) -> None: + for name, max_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMaximum(max_val) + def add_dimension(self, name: DimensionKey, val: Index | None = None) -> None: self._sliders[name] = slider = DimsSlider(dimension_key=name, parent=self) slider.setRange(0, 1) @@ -215,7 +221,7 @@ def add_dimension(self, name: DimensionKey, val: Index | None = None) -> None: slider.setVisible(name not in self._invisible_dims) self.layout().addWidget(slider) - def set_dimension_visible(self, name: str, visible: bool) -> None: + def set_dimension_visible(self, name: Hashable, visible: bool) -> None: if visible: self._invisible_dims.discard(name) else: diff --git a/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py index 9c1408fba..05ac13c37 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py @@ -1,31 +1,31 @@ from __future__ import annotations -import sys -from typing import TYPE_CHECKING, Any, Callable, TypeGuard, cast +from typing import TYPE_CHECKING, Any, Callable, cast import pygfx +import pygfx.geometries +import pygfx.materials from wgpu.gui.qt import QWgpuCanvas if TYPE_CHECKING: import cmap import numpy as np from qtpy.QtWidgets import QWidget - from wgpu.gui import qt class PyGFXImageHandle: def __init__(self, image: pygfx.Image) -> None: self._image = image - self._texture = cast("pygfx.Texture", image.geometry.grid) - self._material = cast("pygfx.ImageBasicMaterial", image.material) + self._geom = cast("pygfx.geometries.Geometry", image.geometry.grid) + self._material = cast("pygfx.materials.ImageBasicMaterial", image.material) @property def data(self) -> np.ndarray: - return self._texture._data # type: ignore + return self._geom._data # type: ignore @data.setter def data(self, data: np.ndarray) -> None: - self._texture._data = data + self._geom.grid = pygfx.Texture(data, dim=2) @property def visible(self) -> bool: @@ -50,17 +50,11 @@ def cmap(self) -> cmap.Colormap: @cmap.setter def cmap(self, cmap: cmap.Colormap) -> None: self._cmap = cmap - self._image.cmap = cmap.to_pygfx() + self._material.map = cmap.to_pygfx() def remove(self) -> None: - # self._image.parent = None - ... - - -def _is_qt_canvas_type(obj: type) -> TypeGuard[type[qt.WgpuCanvas]]: - if wgpu_qt := sys.modules.get("wgpu.gui.qt"): - return issubclass(obj, wgpu_qt.WgpuCanvas) - return False + if (par := self._image.parent) is not None: + par.remove(self._image) class PyGFXViewerCanvas: @@ -73,48 +67,38 @@ class PyGFXViewerCanvas: def __init__(self, set_info: Callable[[str], None]) -> None: self._set_info = set_info - self._canvas = QWgpuCanvas() + self._canvas = QWgpuCanvas(size=(512, 512)) self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) - self._viewport: pygfx.Viewport = pygfx.Viewport(self._renderer) self._scene = pygfx.Scene() - self._camera = cam = pygfx.OrthographicCamera(512, 512) - cam.local.position = (256, 256, 0) - cam.scale_y = -1 - self._controller = pygfx.PanZoomController(cam) - # TODO: background_color - # the qt backend, this shows by default... - # if we need to prevent it, we could potentially monkeypatch during init. - # if hasattr(self._canvas, "hide"): - # self._canvas.hide() + cam.local.position = (256, 256, 0) + cam.local.scale_y = -1 + controller = pygfx.PanZoomController(cam, register_events=self._renderer) + # increase zoom wheel gain + controller.controls.update({"wheel": ("zoom_to_point", "push", -0.005)}) def qwidget(self) -> QWidget: return self._canvas def refresh(self) -> None: + self._canvas.update() self._canvas.request_draw(self._animate) - def _animate(self, viewport: pygfx.Viewport | None = None) -> None: - vp = viewport or self._viewport - - print("rendering") - vp.render(self._scene, self._camera) - if hasattr(vp.renderer, "flush"): - vp.renderer.flush() - if viewport is None: - self._canvas.request_draw() + def _animate(self) -> None: + print("animate") + self._renderer.render(self._scene, self._camera) def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> PyGFXImageHandle: """Add a new Image node to the scene.""" - img = pygfx.Image( + image = pygfx.Image( pygfx.Geometry(grid=pygfx.Texture(data, dim=2)), - pygfx.ImageBasicMaterial(clim=(0, 255)), + pygfx.ImageBasicMaterial(), ) - self._scene.add(img) - handle = PyGFXImageHandle(img) + self._scene.add(image) + handle = PyGFXImageHandle(image) if cmap is not None: handle.cmap = cmap return handle diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index ad01d2f41..9a3798d08 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, cast import cmap +import numpy as np import superqt import useq from psygnal import Signal as psygnalSignal @@ -13,17 +14,15 @@ from ._dims_slider import DimsSliders from ._lut_control import LutControl + +# from ._pygfx_canvas import PyGFXViewerCanvas from ._vispy_canvas import VispyViewerCanvas -from ._pygfx_canvas import PyGFXViewerCanvas if TYPE_CHECKING: - import numpy as np - from ._protocols import PCanvas, PImageHandle ColorMode = Literal["composite", "grayscale"] -CHANNEL = "c" GRAYS = cmap.Colormap("gray") COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] @@ -56,16 +55,25 @@ def mode(self) -> ColorMode: class StackViewer(QWidget): """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events.""" - def __init__(self, datastore: Any, *, parent: QWidget | None = None): + def __init__( + self, + data: Any, + *, + parent: QWidget | None = None, + channel_axis: int | str = 0, + ): super().__init__(parent=parent) self._channels: defaultdict[Hashable, list[PImageHandle]] = defaultdict(list) self._channel_controls: dict[Hashable, LutControl] = {} - self.datastore = datastore + self._sizes = {} + self.set_data(data) + self._channel_axis = channel_axis + self._info_bar = QLabel("Info") - # self._canvas: PCanvas = VispyViewerCanvas(self._info_bar.setText) - self._canvas: PCanvas = PyGFXViewerCanvas(self._info_bar.setText) + self._canvas: PCanvas = VispyViewerCanvas(self._info_bar.setText) + # self._canvas: PCanvas = PyGFXViewerCanvas(self._info_bar.setText) self._dims_sliders = DimsSliders() self._cmaps = cycle(COLORMAPS) self.set_channel_mode("grayscale") @@ -86,6 +94,33 @@ def __init__(self, datastore: Any, *, parent: QWidget | None = None): layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) + def set_data(self, data: Any, sizes: Mapping | None = None) -> None: + if sizes is not None: + self._sizes = dict(sizes) + else: + if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): + self._sizes = sz + elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): + self._sizes = {k: v - 1 for k, v in enumerate(shp[:-2])} + else: + self._sizes = {} + self._datastore = data + + @property + def sizes(self) -> Mapping[Hashable, int]: + return self._sizes + + def update_slider_maxima( + self, sizes: Any | tuple[int, ...] | Mapping[Hashable, int] | None = None + ) -> None: + if sizes is None: + _sizes = self.sizes + elif isinstance(sizes, tuple): + _sizes = {k: v - 1 for k, v in enumerate(sizes[:-2])} + elif not isinstance(sizes, Mapping): + raise ValueError(f"Invalid shape {sizes}") + self._dims_sliders.setMaximum(_sizes) + def _set_range_clicked(self) -> None: self._canvas.set_range() @@ -98,8 +133,8 @@ def set_channel_mode(self, mode: ColorMode | None = None) -> None: self._cmaps = cycle(COLORMAPS) self._channel_mode = mode c_visible = mode != "composite" - self._dims_sliders.set_dimension_visible(CHANNEL, c_visible) - num_channels = self._dims_sliders.maximum().get(CHANNEL, -1) + 1 + self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) + num_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 value = self._dims_sliders.value() if self._channels: for handles in self._channels.values(): @@ -114,7 +149,7 @@ def set_channel_mode(self, mode: ColorMode | None = None) -> None: self._update_data_for_index(value) else: for i in range(num_channels): - self._update_data_for_index({**value, CHANNEL: i}) + self._update_data_for_index({**value, self._channel_axis: i}) self._canvas.refresh() def _image_key(self, index: Mapping[str, int]) -> Hashable: @@ -123,16 +158,16 @@ def _image_key(self, index: Mapping[str, int]) -> Hashable: return 0 def _isel(self, index: Mapping) -> np.ndarray: - return isel(self.datastore, index) + return isel(self._datastore, index) def _on_dims_sliders_changed(self, index: dict) -> None: """Set the current image index.""" - c = index.get(CHANNEL, 0) + c = index.get(self._channel_axis, 0) indices = [index] if self._channel_mode == "composite": for i, handles in self._channels.items(): if handles and c != i: - indices.append({**index, CHANNEL: i}) + indices.append({**index, self._channel_axis: i}) for idx in indices: self._update_data_for_index(idx) @@ -154,24 +189,33 @@ def _update_data_for_index(self, index: Mapping) -> None: self._channel_controls[key] = c = LutControl(channel_name, handles) cast("QVBoxLayout", self.layout()).addWidget(c) + def setIndex(self, index: Mapping[str, int]) -> None: + self._dims_sliders.setValue(index) + class MDAViewer(StackViewer): def __init__(self, *, parent: QWidget | None = None): - super().__init__(DataStore(), parent=parent) - self.datastore.frame_ready.connect(self.on_frame_ready) + super().__init__(DataStore(), parent=parent, channel_axis="c") + self._datastore.frame_ready.connect(self.on_frame_ready) @superqt.ensure_main_thread def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: - self._dims_sliders.setValue(event.index) + self.setIndex(event.index) def isel(store: Any, indexers: Mapping[str, int | slice]) -> np.ndarray: if isinstance(store, (OMEZarrWriter, OMETiffWriter)): return isel_mmcore_5dbase(store, indexers) - + if isinstance(store, np.ndarray): + return isel_np_array(store, indexers) raise NotImplementedError(f"Unknown datastore type {type(store)}") +def isel_np_array(data: np.ndarray, indexers: Mapping[str, int | slice]) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(data.ndim)) + return data[idx] + + def isel_mmcore_5dbase( writer: OMEZarrWriter | OMETiffWriter, indexers: Mapping[str, int | slice] ) -> np.ndarray: diff --git a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py index 2721b4a46..c07e78f7e 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py @@ -3,11 +3,11 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, Callable, cast +import numpy as np from vispy import scene if TYPE_CHECKING: import cmap - import numpy as np from qtpy.QtWidgets import QWidget from vispy.scene.events import SceneMouseEvent @@ -22,6 +22,8 @@ def data(self) -> np.ndarray: @data.setter def data(self, data: np.ndarray) -> None: + if data.dtype == np.float64: + data = data.astype(np.float32) self._image.set_data(data) @property @@ -80,6 +82,8 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: """Add a new Image node to the scene.""" + if data is not None and data.dtype == np.float64: + data = data.astype(np.float32) img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True @@ -95,7 +99,7 @@ def set_range( self, x: tuple[float, float] | None = None, y: tuple[float, float] | None = None, - margin: float | None = 0.05, + margin: float | None = 0.01, ) -> None: """Update the range of the PanZoomCamera. diff --git a/z.py b/z.py index f35c64285..63b79486f 100644 --- a/z.py +++ b/z.py @@ -1,9 +1,59 @@ -from qtpy.QtWidgets import QApplication -from superqt import QLabeledRangeSlider - -app = QApplication([]) -sld = QLabeledRangeSlider() -# sld = QSlider() -sld.valueChanged.connect(lambda x: print(x)) -sld.show() -app.exec_() +from __future__ import annotations + +import numpy as np +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer + + +def generate_5d_sine_wave( + shape: tuple[int, int, int, int, int], + amplitude: float = 1.0, + base_frequency: float = 5, +) -> np.ndarray: + # Unpack the dimensions + angle_dim, freq_dim, phase_dim, ny, nx = shape + + # Create an empty array to hold the data + output = np.zeros(shape) + + # Define spatial coordinates for the last two dimensions + half_per = base_frequency * np.pi + x = np.linspace(-half_per, half_per, nx) + y = np.linspace(-half_per, half_per, ny) + y, x = np.meshgrid(y, x) + + # Iterate through each parameter in the higher dimensions + for phase_idx in range(phase_dim): + for freq_idx in range(freq_dim): + for angle_idx in range(angle_dim): + # Calculate phase and frequency + phase = np.pi / phase_dim * phase_idx + frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step + + # Calculate angle + angle = np.pi / angle_dim * angle_idx + # Rotate x and y coordinates + xr = np.cos(angle) * x - np.sin(angle) * y + np.sin(angle) * x + np.cos(angle) * y + + # Compute the sine wave + sine_wave = amplitude * np.sin(frequency * xr + phase) + + # Assign to the output array + output[angle_idx, freq_idx, phase_idx] = sine_wave + + return output + + +# Example usage +array_shape = (10, 5, 5, 512, 512) # Specify the desired dimensions +sine_wave_5d = generate_5d_sine_wave(array_shape) + + +qapp = QtWidgets.QApplication([]) +v = StackViewer(sine_wave_5d) +v.show() +v.update_slider_maxima() +v.setIndex({0: 1, 1: 0, 2: 0}) +qapp.exec() From 60101e3529bcec9deff7f71588d6284c44349fac Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 3 May 2024 15:46:28 -0400 Subject: [PATCH 11/73] more wip --- examples/stack_viewer2.py | 2 +- z.py => examples/stack_viewer_numpy.py | 19 +- examples/stack_viewer_xr.py | 19 ++ .../_stack_viewer2/_canvas/__init__.py | 0 .../{_pygfx_canvas.py => _canvas/_pygfx.py} | 2 +- .../_stack_viewer2/_canvas/_qt.py | 243 ++++++++++++++++++ .../{_vispy_canvas.py => _canvas/_vispy.py} | 2 +- .../_stack_viewer2/_canvas/gl.py | 94 +++++++ .../_stack_viewer2/_dims_slider.py | 48 ++-- .../_stack_viewer2/_indexing.py | 52 ++++ .../_stack_viewer2/_lut_control.py | 15 +- .../_stack_viewer2/_mda_viewer.py | 35 +++ .../_stack_viewer2/_stack_viewer.py | 129 ++++------ x.py | 25 +- zz.py | 115 +++++++++ 15 files changed, 669 insertions(+), 131 deletions(-) rename z.py => examples/stack_viewer_numpy.py (82%) create mode 100644 examples/stack_viewer_xr.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py rename src/pymmcore_widgets/_stack_viewer2/{_pygfx_canvas.py => _canvas/_pygfx.py} (98%) create mode 100644 src/pymmcore_widgets/_stack_viewer2/_canvas/_qt.py rename src/pymmcore_widgets/_stack_viewer2/{_vispy_canvas.py => _canvas/_vispy.py} (98%) create mode 100644 src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_indexing.py create mode 100644 src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py create mode 100644 zz.py diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 6e1677fbf..e43ae166d 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -4,7 +4,7 @@ from qtpy import QtWidgets from useq import MDASequence -from pymmcore_widgets._stack_viewer2._stack_viewer import MDAViewer +from pymmcore_widgets._stack_viewer2._mda_viewer import MDAViewer configure_logging(stderr_level="WARNING") diff --git a/z.py b/examples/stack_viewer_numpy.py similarity index 82% rename from z.py rename to examples/stack_viewer_numpy.py index 63b79486f..f6fe394f2 100644 --- a/z.py +++ b/examples/stack_viewer_numpy.py @@ -8,7 +8,7 @@ def generate_5d_sine_wave( shape: tuple[int, int, int, int, int], - amplitude: float = 1.0, + amplitude: float = 240, base_frequency: float = 5, ) -> np.ndarray: # Unpack the dimensions @@ -38,7 +38,8 @@ def generate_5d_sine_wave( np.sin(angle) * x + np.cos(angle) * y # Compute the sine wave - sine_wave = amplitude * np.sin(frequency * xr + phase) + sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase) + sine_wave += amplitude * 0.5 # Assign to the output array output[angle_idx, freq_idx, phase_idx] = sine_wave @@ -50,10 +51,10 @@ def generate_5d_sine_wave( array_shape = (10, 5, 5, 512, 512) # Specify the desired dimensions sine_wave_5d = generate_5d_sine_wave(array_shape) - -qapp = QtWidgets.QApplication([]) -v = StackViewer(sine_wave_5d) -v.show() -v.update_slider_maxima() -v.setIndex({0: 1, 1: 0, 2: 0}) -qapp.exec() +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = StackViewer(sine_wave_5d) + v.show() + v.update_slider_maxima() + v.setIndex({0: 1, 1: 0, 2: 0}) + qapp.exec() diff --git a/examples/stack_viewer_xr.py b/examples/stack_viewer_xr.py new file mode 100644 index 000000000..78f37fe63 --- /dev/null +++ b/examples/stack_viewer_xr.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +# from stack_viewer_numpy import generate_5d_sine_wave +import nd2 +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer + +# array_shape = (10, 5, 3, 512, 512) # Specify the desired dimensions +# sine_wave_5d = generate_5d_sine_wave(array_shape) +# data = xr.DataArray(sine_wave_5d, dims=["a", "f", "p", "y", "x"]) + +data = and2.imread("~/dev/self/nd2/tests/data/t3p3z5c3.and2", xarray=True) +qapp = QtWidgets.QApplication([]) +v = StackViewer(data, channel_axis="C") +v.show() +v.update_slider_maxima() +v.setIndex({}) +qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py similarity index 98% rename from src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py rename to src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py index 05ac13c37..25529a848 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_pygfx_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py @@ -79,7 +79,7 @@ def __init__(self, set_info: Callable[[str], None]) -> None: controller.controls.update({"wheel": ("zoom_to_point", "push", -0.005)}) def qwidget(self) -> QWidget: - return self._canvas + return cast("QWidget", self._canvas) def refresh(self) -> None: self._canvas.update() diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_qt.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_qt.py new file mode 100644 index 000000000..e72ebcd36 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_qt.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import sys +from typing import Any, Callable + +import cmap +import numpy as np +from qtpy.QtCore import Qt, QTimer +from qtpy.QtGui import QImage, QPixmap +from qtpy.QtWidgets import ( + QApplication, + QGraphicsPixmapItem, + QGraphicsScene, + QGraphicsView, + QVBoxLayout, + QWidget, +) + +_FORMATS: dict[tuple[np.dtype, int], QImage.Format] = { + (np.dtype(np.uint8), 1): QImage.Format.Format_Grayscale8, + (np.dtype(np.uint8), 3): QImage.Format.Format_RGB888, + (np.dtype(np.uint8), 4): QImage.Format.Format_RGBA8888, + (np.dtype(np.uint16), 1): QImage.Format.Format_Grayscale16, + (np.dtype(np.uint16), 3): QImage.Format.Format_RGB16, + (np.dtype(np.uint16), 4): QImage.Format.Format_RGBA64, + (np.dtype(np.float32), 1): QImage.Format.Format_Grayscale8, + (np.dtype(np.float32), 3): QImage.Format.Format_RGBA16FPx4, + (np.dtype(np.float32), 4): QImage.Format.Format_RGBA32FPx4, +} + + +def _normalize255( + array: np.ndarray, + normalize: tuple[bool, bool] | bool, + clip: tuple[int, int] = (0, 255), +) -> np.ndarray: + # by default, we do not want to clip in-place + # (the input array should not be modified): + clip_target = None + + if normalize: + if normalize is True: + if array.dtype == bool: + normalize = (False, True) + else: + normalize = array.min(), array.max() + if clip == (0, 255): + clip = None + elif np.isscalar(normalize): + normalize = (0, normalize) + + nmin, nmax = normalize + + if nmin: + array = array - nmin + clip_target = array + + if nmax != nmin: + if array.dtype == bool: + scale = 255.0 + else: + scale = 255.0 / (nmax - nmin) + + if scale != 1.0: + array = array * scale + clip_target = array + + if clip: + low, high = clip + array = np.clip(array, low, high, clip_target) + + return array + + +def np2qimg(data: np.ndarray) -> QImage: + if np.ndim(data) == 2: + data = data[..., None] + elif np.ndim(data) != 3: + raise ValueError("data must be 2D or 3D") + if data.shape[-1] not in (1, 3, 4): + raise ValueError( + "Last dimension must contain one (scalar/gray), " + "three (R,G,B), or four (R,G,B,A) channels" + ) + h, w, nc = data.shape + + fmt = _FORMATS.get((data.dtype, data.shape[-1])) + if fmt is None: + raise ValueError(f"Unsupported data type {data.dtype} with {nc} channels") + + if data.dtype == np.float32 and data.shape[-1] == 1: + dmin = data.min() + data = ((data - dmin) / (data.max() - dmin) * 255).astype(np.uint8) + fmt = QImage.Format.Format_Grayscale8 + print(data.shape, w, h, fmt, data.min(), data.max()) + qimage = QImage(data, w, h, fmt) + return qimage + + +class QtImageHandle: + def __init__(self, item: QGraphicsPixmapItem, data: np.ndarray) -> None: + self._data = data + self._item = item + + @property + def data(self) -> np.ndarray: + return self._data + + @data.setter + def data(self, data: np.ndarray) -> None: + self._data = data.squeeze() + self._item.setPixmap(QPixmap.fromImage(np2qimg(self._data))) + + @property + def visible(self) -> bool: + return self._item.isVisible() + + @visible.setter + def visible(self, visible: bool) -> None: + self._item.setVisible(visible) + + @property + def clim(self) -> Any: + return (0, 255) + + @clim.setter + def clim(self, clims: tuple[float, float]) -> None: + pass + + @property + def cmap(self) -> cmap.Colormap: + return cmap.Colormap("viridis") + + @cmap.setter + def cmap(self, cmap: cmap.Colormap) -> None: + pass + + def remove(self) -> None: + """Remove the image from the scene.""" + if scene := self._item.scene(): + scene.removeItem(self._item) + + +class QtViewerCanvas(QWidget): + """Vispy-based viewer for data. + + All vispy-specific code is encapsulated in this class (and non-vispy canvases + could be swapped in if needed as long as they implement the same interface). + """ + + def __init__(self, set_info: Callable[[str], None]) -> None: + super().__init__() + + # Create a QGraphicsScene which holds the graphics items + self.scene = QGraphicsScene() + self.view = QGraphicsView(self.scene, self) + self.view.setBackgroundBrush(Qt.GlobalColor.black) + + # make baground of this widget black + self.setStyleSheet("background-color: black;") + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.view) + + def qwidget(self) -> QWidget: + return self + + def refresh(self) -> None: + """Refresh the canvas.""" + self.update() + + def add_image( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> QtImageHandle: + """Add a new Image node to the scene.""" + item = QGraphicsPixmapItem(QPixmap.fromImage(np2qimg(data))) + self.scene.addItem(item) + return QtImageHandle(item, data) + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + margin: float | None = 0.01, + ) -> None: + """Update the range of the PanZoomCamera. + + When called with no arguments, the range is set to the full extent of the data. + """ + + +class ImageWindow(QWidget): + def __init__(self) -> None: + super().__init__() + + # Create a QGraphicsScene which holds the graphics items + self.scene = QGraphicsScene() + + # Create a QGraphicsView which provides a widget for displaying the contents of a QGraphicsScene + self.view = QGraphicsView(self.scene, self) + self.view.setBackgroundBrush(Qt.GlobalColor.black) + + # make baground of this widget black + self.setStyleSheet("background-color: black;") + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.view) + + # Create a QImage from random data + self.image_data = next(images) + qimage = QImage(self.image_data, *shape, QImage.Format.Format_RGB888) + + # Convert QImage to QPixmap and add it to the scene using QGraphicsPixmapItem + self.pixmap_item = QGraphicsPixmapItem(QPixmap.fromImage(qimage)) + self.scene.addItem(self.pixmap_item) + + # Use a timer to update the image + self.timer = QTimer() + self.timer.timeout.connect(self.update_image) + self.timer.start(10) + + def resizeEvent(self, event: Any) -> None: + self.fitInView() + + def fitInView(self) -> None: + # Scale view to fit the pixmap preserving the aspect ratio + if not self.pixmap_item.pixmap().isNull(): + self.view.fitInView(self.pixmap_item, Qt.AspectRatioMode.KeepAspectRatio) + + def update_image(self) -> None: + # Update the image with new random data + self.image_data = next(images) + qimage = QImage(self.image_data, *shape, QImage.Format.Format_RGB888) + self.pixmap_item.setPixmap(QPixmap.fromImage(qimage)) + + +if __name__ == "__main__": + app = QApplication(sys.argv) + window = ImageWindow() + window.show() + sys.exit(app.exec()) diff --git a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py similarity index 98% rename from src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py rename to src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py index c07e78f7e..895be0214 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_vispy_canvas.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py @@ -127,5 +127,5 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: text = f"[{py}, {px}]" for c, img in enumerate(images): with suppress(IndexError): - text += f" c{c}: {img._data[py, px]}" + text += f" c{c}: {img._data[py, px]:0.2f}" self._set_info(text) diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py new file mode 100644 index 000000000..a98b5119b --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py @@ -0,0 +1,94 @@ +import sys +from itertools import cycle + +import numpy as np +from OpenGL.GL import * # noqa +from qtpy.QtCore import QTimer +from qtpy.QtWidgets import QApplication, QMainWindow, QOpenGLWidget + +shape = (1024, 1024) +images = cycle((np.random.rand(100, *shape, 3) * 255).astype(np.uint8)) + + +class GLWidget(QOpenGLWidget): + def __init__(self, parent=None) -> None: + super().__init__(parent) + self.image_data = next(images) + + def initializeGL(self) -> None: + glClearColor(0, 0, 0, 1) + glEnable(GL_TEXTURE_2D) + self.texture = glGenTextures(1) + glBindTexture(GL_TEXTURE_2D, self.texture) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + # Set unpack alignment to 1 (important for images with width not multiple of 4) + glPixelStorei(GL_UNPACK_ALIGNMENT, 1) + + def paintGL(self): + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + glBindTexture(GL_TEXTURE_2D, self.texture) + glTexImage2D( + GL_TEXTURE_2D, + 0, + GL_RGB, + *shape, + 0, + GL_RGB, + GL_UNSIGNED_BYTE, + self.image_data, + ) + + # Calculate aspect ratio of the window + width = self.width() + height = self.height() + aspect_ratio = width / height + + # Adjust vertices to maintain 1:1 aspect ratio in the center of the viewport + if aspect_ratio > 1: + # Wider than tall: limit width to match height + scale = height / width + x0, x1 = -scale, scale + y0, y1 = -1, 1 + else: + # Taller than wide: limit height to match width + scale = width / height + x0, x1 = -1, 1 + y0, y1 = -scale, scale + + glBegin(GL_QUADS) + glTexCoord2f(0, 0) + glVertex2f(x0, y0) + glTexCoord2f(1, 0) + glVertex2f(x1, y0) + glTexCoord2f(1, 1) + glVertex2f(x1, y1) + glTexCoord2f(0, 1) + glVertex2f(x0, y1) + glEnd() + + def update_image(self, new_image: np.ndarray) -> None: + self.image_data = new_image + self.update() # Request a repaint + + +class MainWindow(QMainWindow): + def __init__(self) -> None: + super().__init__() + self.gl_widget = GLWidget(self) + self.setCentralWidget(self.gl_widget) + self.timer = QTimer() + self.timer.timeout.connect(self.on_timer) + self.timer.start(1) # Update image every 100 ms + + def on_timer(self) -> None: + # Generate a new random image + new_image = (next(images)).astype(np.uint8) + self.gl_widget.update_image(new_image) + + +if __name__ == "__main__": + app = QApplication(sys.argv) + window = MainWindow() + window.show() + sys.exit(app.exec_()) diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index 3e5c889a1..0d9c449ce 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -14,9 +14,15 @@ from qtpy.QtGui import QMouseEvent + # any hashable represent a single dimension in a AND array DimensionKey: TypeAlias = Hashable + # any object that can be used to index a single dimension in an AND array Index: TypeAlias = int | slice + # a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) + # this object is used frequently to query or set the currently displayed slice Indices: TypeAlias = Mapping[DimensionKey, Index] + # mapping of dimension keys to the maximum value for that dimension + Sizes: TypeAlias = Mapping[DimensionKey, int] class PlayButton(QPushButton): @@ -178,14 +184,13 @@ class DimsSliders(QWidget): Maintains the global current index and emits a signal when it changes. """ - valueChanged = Signal(dict) + valueChanged = Signal(dict) # dict is of type Indices def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) self._sliders: dict[DimensionKey, DimsSlider] = {} self._current_index: dict[DimensionKey, Index] = {} self._invisible_dims: set[DimensionKey] = set() - self._updating = False layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) @@ -202,10 +207,10 @@ def setValue(self, values: Indices) -> None: self.add_or_update_dimension(dim, index) self.valueChanged.emit(self.value()) - def maximum(self) -> dict[DimensionKey, int]: + def maximum(self) -> Sizes: return {k: v._int_slider.maximum() for k, v in self._sliders.items()} - def setMaximum(self, values: Mapping[DimensionKey, int]) -> None: + def setMaximum(self, values: Sizes) -> None: for name, max_val in values.items(): if name not in self._sliders: self.add_dimension(name) @@ -219,35 +224,34 @@ def add_dimension(self, name: DimensionKey, val: Index | None = None) -> None: slider.forceValue(val) slider.valueChanged.connect(self._on_dim_slider_value_changed) slider.setVisible(name not in self._invisible_dims) - self.layout().addWidget(slider) + cast("QVBoxLayout", self.layout()).addWidget(slider) - def set_dimension_visible(self, name: Hashable, visible: bool) -> None: + def set_dimension_visible(self, key: DimensionKey, visible: bool) -> None: if visible: - self._invisible_dims.discard(name) + self._invisible_dims.discard(key) else: - self._invisible_dims.add(name) - if name in self._sliders: - self._sliders[name].setVisible(visible) + self._invisible_dims.add(key) + if key in self._sliders: + self._sliders[key].setVisible(visible) - def remove_dimension(self, name: str) -> None: + def remove_dimension(self, key: DimensionKey) -> None: try: - slider = self._sliders.pop(name) + slider = self._sliders.pop(key) except KeyError: - warn(f"Dimension {name} not found in DimsSliders", stacklevel=2) + warn(f"Dimension {key} not found in DimsSliders", stacklevel=2) return - self.layout().removeWidget(slider) + cast("QVBoxLayout", self.layout()).removeWidget(slider) slider.deleteLater() - def _on_dim_slider_value_changed(self, dim_name: str, value: Index) -> None: - self._current_index[dim_name] = value - if not self._updating: - self.valueChanged.emit(self.value()) + def _on_dim_slider_value_changed(self, key: DimensionKey, value: Index) -> None: + self._current_index[key] = value + self.valueChanged.emit(self.value()) - def add_or_update_dimension(self, name: DimensionKey, value: Index) -> None: - if name in self._sliders: - self._sliders[name].forceValue(value) + def add_or_update_dimension(self, key: DimensionKey, value: Index) -> None: + if key in self._sliders: + self._sliders[key].forceValue(value) else: - self.add_dimension(name, value) + self.add_dimension(key, value) if __name__ == "__main__": diff --git a/src/pymmcore_widgets/_stack_viewer2/_indexing.py b/src/pymmcore_widgets/_stack_viewer2/_indexing.py new file mode 100644 index 000000000..f1af64ac6 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_indexing.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any + +import numpy as np +from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter + +# from ._pygfx_canvas import PyGFXViewerCanvas + +if TYPE_CHECKING: + import xarray as xr # noqa + + from ._dims_slider import Indices + + +def isel(store: Any, indexers: Indices) -> np.ndarray: + """Select a slice from a data store.""" + if isinstance(store, (OMEZarrWriter, OMETiffWriter)): + return isel_mmcore_5dbase(store, indexers) + if isinstance(store, np.ndarray): + return isel_np_array(store, indexers) + if not TYPE_CHECKING: + xr = sys.modules.get("xarray") + if xr and isinstance(store, xr.DataArray): + return store.isel(indexers).to_numpy() + raise NotImplementedError(f"Unknown datastore type {type(store)}") + + +def isel_np_array(data: np.ndarray, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(data.ndim)) + return data[idx] + + +def isel_mmcore_5dbase( + writer: OMEZarrWriter | OMETiffWriter, indexers: Indices +) -> np.ndarray: + p_index = indexers.get("p", 0) + if isinstance(p_index, slice): + raise NotImplementedError("Cannot slice over position index") # TODO + + try: + sizes = [*list(writer.position_sizes[p_index]), "y", "x"] + except IndexError as e: + raise IndexError( + f"Position index {p_index} out of range for {len(writer.position_sizes)}" + ) from e + + data = writer.position_arrays[writer.get_position_key(p_index)] + full = slice(None, None) + index = tuple(indexers.get(k, full) for k in sizes) + return data[index] # type: ignore diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py index c106ab561..258d8ddb7 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -38,7 +38,7 @@ def __init__( self._cmap.addColormap(color) self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) - self._clims.setRange(0, 2**14) + self._clims.setRange(0, 2**8) self._clims.valueChanged.connect(self._on_clims_changed) self._auto_clim = QCheckBox("Auto") @@ -78,10 +78,13 @@ def update_autoscale(self) -> None: clims[0] = min(clims[0], np.nanmin(handle.data)) clims[1] = max(clims[1], np.nanmax(handle.data)) - if (clims_ := tuple(int(x) for x in clims)) != (0, 0): + mi, ma = tuple(int(x) for x in clims) + if mi != ma: for handle in self._handles: - handle.clim = clims_ + handle.clim = (mi, ma) - # set the slider values to the new clims - with signals_blocked(self._clims): - self._clims.setValue(clims_) + # set the slider values to the new clims + with signals_blocked(self._clims): + self._clims.setMinimum(min(mi, self._clims.minimum())) + self._clims.setMaximum(max(ma, self._clims.maximum())) + self._clims.setValue((mi, ma)) diff --git a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py new file mode 100644 index 000000000..a1a5f63d4 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import superqt +import useq +from psygnal import Signal as psygnalSignal +from pymmcore_plus.mda.handlers import OMEZarrWriter + +from ._stack_viewer import StackViewer + +if TYPE_CHECKING: + import numpy as np + from qtpy.QtWidgets import QWidget + + +# FIXME: get rid of this thin subclass +class DataStore(OMEZarrWriter): + frame_ready = psygnalSignal(object, useq.MDAEvent) + + def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> None: + super().frameReady(frame, event, meta) + self.frame_ready.emit(frame, event) + + +class MDAViewer(StackViewer): + """StackViewer specialized for pymmcore-plus MDA acquisitions.""" + + def __init__(self, *, parent: QWidget | None = None): + super().__init__(DataStore(), parent=parent, channel_axis="c") + self._datastore.frame_ready.connect(self.on_frame_ready) + + @superqt.ensure_main_thread + def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: + self.setIndex(event.index) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 9a3798d08..ed0af28fb 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -2,40 +2,29 @@ from collections import defaultdict from itertools import cycle -from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, cast +from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, Mapping, cast import cmap -import numpy as np -import superqt -import useq -from psygnal import Signal as psygnalSignal -from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from ._canvas._vispy import VispyViewerCanvas from ._dims_slider import DimsSliders +from ._indexing import isel from ._lut_control import LutControl -# from ._pygfx_canvas import PyGFXViewerCanvas -from ._vispy_canvas import VispyViewerCanvas - if TYPE_CHECKING: + import numpy as np + + from ._dims_slider import DimensionKey, Indices, Sizes from ._protocols import PCanvas, PImageHandle ColorMode = Literal["composite", "grayscale"] + GRAYS = cmap.Colormap("gray") COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] -# FIXME: get rid of this thin subclass -class DataStore(OMEZarrWriter): - frame_ready = psygnalSignal(object, useq.MDAEvent) - - def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> None: - super().frameReady(frame, event, meta) - self.frame_ready.emit(frame, event) - - class ColorModeButton(QPushButton): def __init__(self, parent: QWidget | None = None): modes = ["composite", "grayscale"] @@ -60,19 +49,24 @@ def __init__( data: Any, *, parent: QWidget | None = None, - channel_axis: int | str = 0, + channel_axis: DimensionKey = 0, ): super().__init__(parent=parent) - self._channels: defaultdict[Hashable, list[PImageHandle]] = defaultdict(list) - self._channel_controls: dict[Hashable, LutControl] = {} + self._channels: defaultdict[DimensionKey, list[PImageHandle]] = defaultdict( + list + ) + self._channel_controls: dict[DimensionKey, LutControl] = {} + + self._sizes: Sizes = {} + # the set of dimensions we are currently visualizing (e.g. XY) + self._visible_dims: set[DimensionKey] = set() - self._sizes = {} - self.set_data(data) self._channel_axis = channel_axis self._info_bar = QLabel("Info") self._canvas: PCanvas = VispyViewerCanvas(self._info_bar.setText) + # self._canvas: PCanvas = QtViewerCanvas(self._info_bar.setText) # self._canvas: PCanvas = PyGFXViewerCanvas(self._info_bar.setText) self._dims_sliders = DimsSliders() self._cmaps = cycle(COLORMAPS) @@ -85,6 +79,8 @@ def __init__( self._set_range_btn = QPushButton("reset zoom") self._set_range_btn.clicked.connect(self._set_range_clicked) + self.set_data(data) + btns = QHBoxLayout() btns.addWidget(self._channel_mode_picker) btns.addWidget(self._set_range_btn) @@ -94,32 +90,42 @@ def __init__( layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) - def set_data(self, data: Any, sizes: Mapping | None = None) -> None: + def set_data(self, data: Any, sizes: Sizes | None = None) -> None: if sizes is not None: self._sizes = dict(sizes) else: if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): self._sizes = sz elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): - self._sizes = {k: v - 1 for k, v in enumerate(shp[:-2])} + self._sizes = dict(enumerate(shp[:-2])) else: self._sizes = {} + self._datastore = data + self.set_visible_dims(list(self._sizes)[-2:]) + + def set_visible_dims(self, dims: Iterable[DimensionKey]) -> None: + self._visible_dims = set(dims) + for d in self._visible_dims: + self._dims_sliders.set_dimension_visible(d, False) @property - def sizes(self) -> Mapping[Hashable, int]: + def sizes(self) -> Sizes: return self._sizes def update_slider_maxima( - self, sizes: Any | tuple[int, ...] | Mapping[Hashable, int] | None = None + self, sizes: tuple[int, ...] | Sizes | None = None ) -> None: if sizes is None: _sizes = self.sizes elif isinstance(sizes, tuple): - _sizes = {k: v - 1 for k, v in enumerate(sizes[:-2])} + _sizes = dict(enumerate(sizes[:-2])) elif not isinstance(sizes, Mapping): raise ValueError(f"Invalid shape {sizes}") - self._dims_sliders.setMaximum(_sizes) + + for dim in list(_sizes.values())[-2:]: + self._dims_sliders.set_dimension_visible(dim, False) + self._dims_sliders.setMaximum({k: v - 1 for k, v in _sizes.items()}) def _set_range_clicked(self) -> None: self._canvas.set_range() @@ -152,21 +158,29 @@ def set_channel_mode(self, mode: ColorMode | None = None) -> None: self._update_data_for_index({**value, self._channel_axis: i}) self._canvas.refresh() - def _image_key(self, index: Mapping[str, int]) -> Hashable: + def _image_key(self, index: Indices) -> Hashable: if self._channel_mode == "composite": - return index.get("c", 0) + val = index.get(self._channel_axis, 0) + if isinstance(val, slice): + return (val.start, val.stop) + return val return 0 - def _isel(self, index: Mapping) -> np.ndarray: - return isel(self._datastore, index) + def _isel(self, index: Indices) -> np.ndarray: + idx = {k: v for k, v in index.items() if k not in self._visible_dims} + try: + return isel(self._datastore, idx) + except Exception as e: + raise type(e)(f"Failed to index data with {idx}: {e}") from e - def _on_dims_sliders_changed(self, index: dict) -> None: + def _on_dims_sliders_changed(self, index: Indices) -> None: """Set the current image index.""" c = index.get(self._channel_axis, 0) - indices = [index] + indices: list[Indices] = [index] if self._channel_mode == "composite": for i, handles in self._channels.items(): if handles and c != i: + # FIXME: type error is legit indices.append({**index, self._channel_axis: i}) for idx in indices: @@ -189,48 +203,5 @@ def _update_data_for_index(self, index: Mapping) -> None: self._channel_controls[key] = c = LutControl(channel_name, handles) cast("QVBoxLayout", self.layout()).addWidget(c) - def setIndex(self, index: Mapping[str, int]) -> None: + def setIndex(self, index: Indices) -> None: self._dims_sliders.setValue(index) - - -class MDAViewer(StackViewer): - def __init__(self, *, parent: QWidget | None = None): - super().__init__(DataStore(), parent=parent, channel_axis="c") - self._datastore.frame_ready.connect(self.on_frame_ready) - - @superqt.ensure_main_thread - def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: - self.setIndex(event.index) - - -def isel(store: Any, indexers: Mapping[str, int | slice]) -> np.ndarray: - if isinstance(store, (OMEZarrWriter, OMETiffWriter)): - return isel_mmcore_5dbase(store, indexers) - if isinstance(store, np.ndarray): - return isel_np_array(store, indexers) - raise NotImplementedError(f"Unknown datastore type {type(store)}") - - -def isel_np_array(data: np.ndarray, indexers: Mapping[str, int | slice]) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(data.ndim)) - return data[idx] - - -def isel_mmcore_5dbase( - writer: OMEZarrWriter | OMETiffWriter, indexers: Mapping[str, int | slice] -) -> np.ndarray: - p_index = indexers.get("p", 0) - if isinstance(p_index, slice): - raise NotImplementedError("Cannot slice over position index") # TODO - - try: - sizes = [*list(writer.position_sizes[p_index]), "y", "x"] - except IndexError as e: - raise IndexError( - f"Position index {p_index} out of range for {len(writer.position_sizes)}" - ) from e - - data = writer.position_arrays[writer.get_position_key(p_index)] - full = slice(None, None) - index = tuple(indexers.get(k, full) for k in sizes) - return data[index] # type: ignore diff --git a/x.py b/x.py index ecd5f2014..28c8d632f 100644 --- a/x.py +++ b/x.py @@ -1,20 +1,21 @@ import useq -from pymmcore_plus import CMMCorePlus -from pymmcore_plus.mda.handlers import OMEZarrWriter +from rich import print -core = CMMCorePlus() -core.loadSystemConfiguration() seq = useq.MDASequence( channels=["DAPI", "FITC"], - stage_positions=[(1, 2, 3)], + stage_positions=[ + (1, 2, 3), + { + "x": 4, + "y": 5, + "z": 6, + "sequence": useq.MDASequence(grid_plan={"rows": 2, "columns": 1}), + }, + ], time_plan={"interval": 0, "loops": 3}, - grid_plan={"rows": 2, "columns": 1}, z_plan={"range": 2, "step": 0.7}, ) -writer = OMEZarrWriter() -core.mda.run(seq, output=writer) -xa = writer.as_xarray() -da = xa["p0"] -print(da) -print(da.dims) +print("main", seq.sizes) +print("p0", seq.stage_positions[0].sequence) +print("p1", seq.stage_positions[1].sequence.sizes) diff --git a/zz.py b/zz.py new file mode 100644 index 000000000..fde0470c7 --- /dev/null +++ b/zz.py @@ -0,0 +1,115 @@ +import sys +from itertools import cycle +from typing import Any + +import numpy as np +from qtpy.QtCore import Qt, QTimer +from qtpy.QtGui import QImage, QPixmap +from qtpy.QtWidgets import ( + QApplication, + QGraphicsPixmapItem, + QGraphicsScene, + QGraphicsView, + QVBoxLayout, + QWidget, +) + +shape = (512, 512) +images = cycle((np.random.rand(100, *shape) * 255).astype(np.uint8)) + + +def np2qimg(data: np.ndarray) -> QImage: + if np.ndim(data) == 2: + data = data[..., None] + elif np.ndim(data) != 3: + raise ValueError("data must be 2D or 3D") + if data.shape[-1] not in (1, 2, 3, 4): + raise ValueError( + "Last dimension must contain one (scalar/gray), two (gray+alpha), " + "three (R,G,B), or four (R,G,B,A) channels" + ) + h, w, nc = data.shape + np_dtype = data.dtype + hasAlpha = nc in (2, 4) + isRGB = nc in (3, 4) + if np_dtype == np.uint8: + if hasAlpha: + fmt = QImage.Format.Format_RGBA8888 + elif isRGB: + fmt = QImage.Format.Format_RGB888 + else: + fmt = QImage.Format.Format_Grayscale8 + elif np_dtype == np.uint16: + if hasAlpha: + fmt = QImage.Format.Format_RGBA64 + elif isRGB: + fmt = QImage.Format.Format_RGB16 + else: + fmt = QImage.Format.Format_Grayscale16 + elif np_dtype == np.float32: + if hasAlpha: + fmt = QImage.Format.Format_RGBA32FPx4 + elif isRGB: + fmt = QImage.Format.Format_RGBA16FPx4 + else: + dmin = data.min() + data = ((data - dmin) / (data.max() - dmin) * 255).astype(np.uint8) + fmt = QImage.Format.Format_Grayscale8 + qimage = QImage(data, w, h, fmt) + return qimage + + +class ImageWindow(QWidget): + def __init__(self) -> None: + super().__init__() + + # Create a QGraphicsScene which holds the graphics items + self.scene = QGraphicsScene() + + # Create a QGraphicsView which provides a widget for displaying the contents of a QGraphicsScene + self.view = QGraphicsView(self.scene, self) + self.view.setBackgroundBrush(Qt.GlobalColor.black) + + # make baground of this widget black + self.setStyleSheet("background-color: black;") + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.view) + + # Create a QImage from random data + self.add_image() + + # Use a timer to update the image + self.timer = QTimer() + self.timer.timeout.connect(self.update_image) + self.timer.start(10) + + def add_image(self) -> None: + self.image_data = next(images) + qimage = np2qimg(self.image_data) + + # Convert QImage to QPixmap and add it to the scene using QGraphicsPixmapItem + self.pixmap_item = QGraphicsPixmapItem(QPixmap.fromImage(qimage)) + self.scene.addItem(self.pixmap_item) + + def resizeEvent(self, event: Any) -> None: + self.fitInView() + + def fitInView(self) -> None: + # Scale view to fit the pixmap preserving the aspect ratio + if not self.pixmap_item.pixmap().isNull(): + self.view.fitInView(self.pixmap_item, Qt.AspectRatioMode.KeepAspectRatio) + + def update_image(self) -> None: + # Update the image with new random data + self.image_data = next(images) + qimage = np2qimg(self.image_data) + self.pixmap_item.setPixmap(QPixmap.fromImage(qimage)) + + +if __name__ == "__main__": + app = QApplication(sys.argv) + window = ImageWindow() + window.show() + sys.exit(app.exec()) From d727dcd4c4a006207d00993d663bec7e643fc568 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 3 May 2024 17:10:40 -0400 Subject: [PATCH 12/73] wip --- examples/stack_viewer_numpy.py | 2 +- .../_stack_viewer2/_canvas/_vispy.py | 2 +- .../_stack_viewer2/_dims_slider.py | 22 ++++++++++++---- .../_stack_viewer2/_stack_viewer.py | 25 +++++++++++-------- 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/examples/stack_viewer_numpy.py b/examples/stack_viewer_numpy.py index f6fe394f2..798331020 100644 --- a/examples/stack_viewer_numpy.py +++ b/examples/stack_viewer_numpy.py @@ -53,7 +53,7 @@ def generate_5d_sine_wave( if __name__ == "__main__": qapp = QtWidgets.QApplication([]) - v = StackViewer(sine_wave_5d) + v = StackViewer(sine_wave_5d, channel_axis=2) v.show() v.update_slider_maxima() v.setIndex({0: 1, 1: 0, 2: 0}) diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py index 895be0214..2388c49d1 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py @@ -127,5 +127,5 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: text = f"[{py}, {px}]" for c, img in enumerate(images): with suppress(IndexError): - text += f" c{c}: {img._data[py, px]:0.2f}" + text += f" {c}: {img._data[py, px]:0.2f}" self._set_info(text) diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index 0d9c449ce..527b3a95e 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -3,9 +3,16 @@ from typing import TYPE_CHECKING, Any, cast from warnings import warn -from qtpy.QtCore import Qt, Signal -from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget -from superqt import QLabeledRangeSlider, QLabeledSlider +from qtpy.QtCore import QSize, Qt, Signal +from qtpy.QtWidgets import ( + QHBoxLayout, + QLabel, + QPushButton, + QSlider, + QVBoxLayout, + QWidget, +) +from superqt import QLabeledRangeSlider from superqt.iconify import QIconifyIcon from superqt.utils import signals_blocked @@ -37,6 +44,7 @@ def __init__(self, text: str = "", parent: QWidget | None = None) -> None: super().__init__(icn, text, parent) self.setCheckable(True) self.setMaximumWidth(22) + self.setIconSize(QSize(10, 10)) class LockButton(QPushButton): @@ -49,6 +57,7 @@ def __init__(self, text: str = "", parent: QWidget | None = None) -> None: super().__init__(icn, text, parent) self.setCheckable(True) self.setMaximumWidth(20) + self.setIconSize(QSize(10, 10)) class DimsSlider(QWidget): @@ -79,10 +88,12 @@ def __init__( self._lock_btn = LockButton() self._max_label = QLabel("/ 0") - self._int_slider = QLabeledSlider(Qt.Orientation.Horizontal, parent=self) + + # self._int_slider = QLabeledSlider(Qt.Orientation.Horizontal, parent=self) + self._int_slider = QSlider(Qt.Orientation.Horizontal, parent=self) self._int_slider.rangeChanged.connect(self._on_range_changed) self._int_slider.valueChanged.connect(self._on_int_value_changed) - self._int_slider.layout().addWidget(self._max_label) + # self._int_slider.layout().addWidget(self._max_label) self._slice_slider = QLabeledRangeSlider(Qt.Orientation.Horizontal, parent=self) self._slice_slider.setVisible(False) @@ -92,6 +103,7 @@ def __init__( self.installEventFilter(self) layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(4) layout.addWidget(self._play_btn) layout.addWidget(self._dim_label) layout.addWidget(self._int_slider) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index ed0af28fb..c9d077c8c 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -60,7 +60,7 @@ def __init__( self._sizes: Sizes = {} # the set of dimensions we are currently visualizing (e.g. XY) - self._visible_dims: set[DimensionKey] = set() + self._visualized_dims: set[DimensionKey] = set() self._channel_axis = channel_axis @@ -97,7 +97,7 @@ def set_data(self, data: Any, sizes: Sizes | None = None) -> None: if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): self._sizes = sz elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): - self._sizes = dict(enumerate(shp[:-2])) + self._sizes = dict(enumerate(shp)) else: self._sizes = {} @@ -105,8 +105,10 @@ def set_data(self, data: Any, sizes: Sizes | None = None) -> None: self.set_visible_dims(list(self._sizes)[-2:]) def set_visible_dims(self, dims: Iterable[DimensionKey]) -> None: - self._visible_dims = set(dims) - for d in self._visible_dims: + self._visualized_dims = set(dims) + for d in self._dims_sliders._sliders: + self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) + for d in self._visualized_dims: self._dims_sliders.set_dimension_visible(d, False) @property @@ -119,7 +121,7 @@ def update_slider_maxima( if sizes is None: _sizes = self.sizes elif isinstance(sizes, tuple): - _sizes = dict(enumerate(sizes[:-2])) + _sizes = dict(enumerate(sizes)) elif not isinstance(sizes, Mapping): raise ValueError(f"Invalid shape {sizes}") @@ -167,14 +169,18 @@ def _image_key(self, index: Indices) -> Hashable: return 0 def _isel(self, index: Indices) -> np.ndarray: - idx = {k: v for k, v in index.items() if k not in self._visible_dims} + idx = {k: v for k, v in index.items() if k not in self._visualized_dims} try: return isel(self._datastore, idx) except Exception as e: raise type(e)(f"Failed to index data with {idx}: {e}") from e + def setIndex(self, index: Indices) -> None: + """Set the index of the displayed image.""" + self._dims_sliders.setValue(index) + def _on_dims_sliders_changed(self, index: Indices) -> None: - """Set the current image index.""" + """Update the displayed image when the sliders are changed.""" c = index.get(self._channel_axis, 0) indices: list[Indices] = [index] if self._channel_mode == "composite": @@ -187,7 +193,7 @@ def _on_dims_sliders_changed(self, index: Indices) -> None: self._update_data_for_index(idx) self._canvas.refresh() - def _update_data_for_index(self, index: Mapping) -> None: + def _update_data_for_index(self, index: Indices) -> None: key = self._image_key(index) data = self._isel(index) if handles := self._channels[key]: @@ -202,6 +208,3 @@ def _update_data_for_index(self, index: Mapping) -> None: channel_name = f"Channel {key}" self._channel_controls[key] = c = LutControl(channel_name, handles) cast("QVBoxLayout", self.layout()).addWidget(c) - - def setIndex(self, index: Indices) -> None: - self._dims_sliders.setValue(index) From 57187d21e03cdbdf6f9d6aa484f82d7b38abfa84 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 3 May 2024 20:58:16 -0400 Subject: [PATCH 13/73] wip --- .../_stack_viewer2/_canvas/_vispy.py | 2 +- .../_stack_viewer2/_dims_slider.py | 150 +++++++++++++++--- .../_stack_viewer2/_lut_control.py | 5 +- .../_stack_viewer2/_stack_viewer.py | 2 + 4 files changed, 132 insertions(+), 27 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py index 2388c49d1..154d9d60e 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py @@ -127,5 +127,5 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: text = f"[{py}, {px}]" for c, img in enumerate(images): with suppress(IndexError): - text += f" {c}: {img._data[py, px]:0.2f}" + text += f" {c}: {round(img._data[py, px], 2)}" self._set_info(text) diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index 527b3a95e..b1466e988 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -3,23 +3,27 @@ from typing import TYPE_CHECKING, Any, cast from warnings import warn -from qtpy.QtCore import QSize, Qt, Signal +from qtpy.QtCore import QPointF, QSize, Qt, Signal from qtpy.QtWidgets import ( + QDialog, QHBoxLayout, QLabel, QPushButton, + QSizePolicy, QSlider, + QSpinBox, QVBoxLayout, QWidget, ) -from superqt import QLabeledRangeSlider +from superqt import QElidingLabel, QLabeledRangeSlider from superqt.iconify import QIconifyIcon from superqt.utils import signals_blocked if TYPE_CHECKING: from typing import Hashable, Mapping, TypeAlias - from qtpy.QtGui import QMouseEvent + from PyQt6.QtGui import QResizeEvent + from qtpy.QtGui import QKeyEvent # any hashable represent a single dimension in a AND array DimensionKey: TypeAlias = Hashable @@ -32,32 +36,94 @@ Sizes: TypeAlias = Mapping[DimensionKey, int] +SS = """ +QSlider::groove:horizontal { + height: 6px; + background: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(128, 128, 128, 0.25), + stop:1 rgba(128, 128, 128, 0.1) + ); + border-radius: 3px; + margin: 2px 0; +} + +QSlider::handle:horizontal { + background: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(180, 180, 180, 1), + stop:1 rgba(180, 180, 180, 1) + ); + width: 28px; + margin: -3px 0; + border-radius: 3px; +}""" + + +class _DissmissableDialog(QDialog): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setWindowFlags( + self.windowFlags() | Qt.WindowType.FramelessWindowHint | Qt.WindowType.Popup + ) + + def keyPressEvent(self, e: QKeyEvent | None) -> None: + if e and e.key() in (Qt.Key.Key_Enter, Qt.Key.Key_Return, Qt.Key.Key_Escape): + self.accept() + print("accept") + + class PlayButton(QPushButton): """Just a styled QPushButton that toggles between play and pause icons.""" - PLAY_ICON = "fa6-solid:play" - PAUSE_ICON = "fa6-solid:pause" + fpsChanged = Signal(int) + + PLAY_ICON = "bi:play-fill" + PAUSE_ICON = "bi:pause-fill" def __init__(self, text: str = "", parent: QWidget | None = None) -> None: icn = QIconifyIcon(self.PLAY_ICON) icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On) super().__init__(icn, text, parent) self.setCheckable(True) - self.setMaximumWidth(22) - self.setIconSize(QSize(10, 10)) + self.setFixedSize(14, 18) + self.setIconSize(QSize(16, 16)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") + + # def mousePressEvent(self, e: QMouseEvent | None) -> None: + # if e and e.button() == Qt.MouseButton.RightButton: + # self._show_fps_dialog(e.globalPosition()) + # else: + # super().mousePressEvent(e) + + def _show_fps_dialog(self, pos: QPointF) -> None: + dialog = _DissmissableDialog() + + sb = QSpinBox() + sb.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) + sb.valueChanged.connect(self.fpsChanged) + + layout = QHBoxLayout(dialog) + layout.setContentsMargins(4, 0, 4, 0) + layout.addWidget(QLabel("FPS")) + layout.addWidget(sb) + + dialog.setGeometry(int(pos.x()) - 20, int(pos.y()) - 50, 40, 40) + dialog.exec() class LockButton(QPushButton): - LOCK_ICON = "fa6-solid:lock-open" - UNLOCK_ICON = "fa6-solid:lock" + LOCK_ICON = "uis:unlock" + UNLOCK_ICON = "uis:lock" def __init__(self, text: str = "", parent: QWidget | None = None) -> None: icn = QIconifyIcon(self.LOCK_ICON) icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On) super().__init__(icn, text, parent) self.setCheckable(True) - self.setMaximumWidth(20) - self.setIconSize(QSize(10, 10)) + self.setFixedSize(20, 20) + self.setIconSize(QSize(14, 14)) + self.setStyleSheet("border: none; padding: 0; margin: 0;") class DimsSlider(QWidget): @@ -74,22 +140,29 @@ def __init__( self, dimension_key: DimensionKey, parent: QWidget | None = None ) -> None: super().__init__(parent) + self.setStyleSheet(SS) self._slice_mode = False self._animation_fps = 30 self._dim_key = dimension_key self._play_btn = PlayButton() + self._play_btn.fpsChanged.connect(self.set_fps) self._play_btn.toggled.connect(self._toggle_animation) - self._dim_label = QLabel(str(dimension_key)) + self._dim_label = QElidingLabel(str(dimension_key).upper()) # note, this lock button only prevents the slider from updating programmatically # using self.setValue, it doesn't prevent the user from changing the value. self._lock_btn = LockButton() - self._max_label = QLabel("/ 0") + self._pos_label = QSpinBox() + self._pos_label.valueChanged.connect(self._on_pos_label_edited) + self._pos_label.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) + self._pos_label.setAlignment(Qt.AlignmentFlag.AlignRight) + self._pos_label.setStyleSheet( + "border: none; padding: 0; margin: 0; background: transparent" + ) - # self._int_slider = QLabeledSlider(Qt.Orientation.Horizontal, parent=self) self._int_slider = QSlider(Qt.Orientation.Horizontal, parent=self) self._int_slider.rangeChanged.connect(self._on_range_changed) self._int_slider.valueChanged.connect(self._on_int_value_changed) @@ -97,22 +170,36 @@ def __init__( self._slice_slider = QLabeledRangeSlider(Qt.Orientation.Horizontal, parent=self) self._slice_slider.setVisible(False) - self._slice_slider.rangeChanged.connect(self._on_range_changed) - self._slice_slider.valueChanged.connect(self._on_slice_value_changed) + # self._slice_slider.rangeChanged.connect(self._on_range_changed) + # self._slice_slider.valueChanged.connect(self._on_slice_value_changed) self.installEventFilter(self) layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(4) + layout.setSpacing(2) layout.addWidget(self._play_btn) layout.addWidget(self._dim_label) layout.addWidget(self._int_slider) - layout.addWidget(self._slice_slider) + layout.addWidget(self._pos_label) layout.addWidget(self._lock_btn) - - def mouseDoubleClickEvent(self, a0: QMouseEvent | None) -> None: - self._set_slice_mode(not self._slice_mode) - return super().mouseDoubleClickEvent(a0) + self.setMinimumHeight(22) + + def resizeEvent(self, a0: QResizeEvent | None) -> None: + # align all labels + if isinstance(par := self.parent(), DimsSliders): + lbl_width = max( + x._dim_label.sizeHint().width() for x in par.findChildren(DimsSlider) + ) + self._dim_label.setFixedWidth(min(lbl_width + 2, 40)) + pos_lbl_width = max( + x._pos_label.sizeHint().width() for x in par.findChildren(DimsSlider) + ) + self._pos_label.setFixedWidth(min(pos_lbl_width + 2, 40)) + return super().resizeEvent(a0) + + # def mouseDoubleClickEvent(self, a0: QMouseEvent | None) -> None: + # self._set_slice_mode(not self._slice_mode) + # return super().mouseDoubleClickEvent(a0) def setMaximum(self, max_val: int) -> None: if max_val > self._int_slider.maximum(): @@ -178,10 +265,20 @@ def timerEvent(self, event: Any) -> None: ival = (ival + 1) % (self._int_slider.maximum() + 1) self._int_slider.setValue(ival) + def _on_pos_label_edited(self) -> None: + if self._slice_mode: + self._slice_slider.setValue( + (self._pos_label.value(), self._pos_label.value() + 1) + ) + else: + self._int_slider.setValue(self._pos_label.value()) + def _on_range_changed(self, min: int, max: int) -> None: - self._max_label.setText("/ " + str(max)) + self._pos_label.setSuffix(" / " + str(max)) + self._pos_label.setRange(min, max) def _on_int_value_changed(self, value: int) -> None: + self._pos_label.setValue(value) if not self._slice_mode: self.valueChanged.emit(self._dim_key, value) @@ -204,6 +301,8 @@ def __init__(self, parent: QWidget | None = None) -> None: self._current_index: dict[DimensionKey, Index] = {} self._invisible_dims: set[DimensionKey] = set() + self.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Maximum) + layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) @@ -271,9 +370,10 @@ def add_or_update_dimension(self, key: DimensionKey, value: Index) -> None: app = QApplication([]) w = DimsSliders() - w.add_dimension("x") - w.add_dimension("y", slice(5, 9)) + w.add_dimension("x", 5) + w.add_dimension("ysadfdasas", 20) w.add_dimension("z", 10) + w.add_dimension("w", 10) w.valueChanged.connect(print) w.show() app.exec() diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py index 258d8ddb7..dad41abc5 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -30,7 +30,10 @@ def __init__( self._visible.setChecked(True) self._visible.toggled.connect(self._on_visible_changed) - self._cmap = QColormapComboBox(allow_user_colormaps=True) + self._cmap = QColormapComboBox( + allow_user_colormaps=True, add_colormap_text="Add..." + ) + self._cmap.setMinimumWidth(100) self._cmap.currentColormapChanged.connect(self._on_cmap_changed) for handle in handles: self._cmap.addColormap(handle.cmap) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index c9d077c8c..35726dc01 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -82,9 +82,11 @@ def __init__( self.set_data(data) btns = QHBoxLayout() + btns.addStretch() btns.addWidget(self._channel_mode_picker) btns.addWidget(self._set_range_btn) layout = QVBoxLayout(self) + layout.setSpacing(4) layout.addLayout(btns) layout.addWidget(self._canvas.qwidget(), 1) layout.addWidget(self._info_bar) From b066f3a4ad92b1dd09d9e354c047a5451b0133ab Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 11:51:20 -0400 Subject: [PATCH 14/73] good progress --- examples/stack_viewer2.py | 5 +- .../_stack_viewer2/_canvas/_vispy.py | 3 +- .../_stack_viewer2/_dims_slider.py | 126 +++++--- .../_stack_viewer2/_indexing.py | 64 +++- .../_stack_viewer2/_lut_control.py | 35 ++- .../_stack_viewer2/_stack_viewer.py | 283 +++++++++++++----- 6 files changed, 368 insertions(+), 148 deletions(-) diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index e43ae166d..e6a81d74a 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -20,14 +20,15 @@ # {"config": "Cy5", "exposure": 20}, ), stage_positions=[(0, 0), (1, 1)], - z_plan={"range": 2, "step": 0.4}, - time_plan={"interval": 0.8, "loops": 2}, + z_plan={"range": 9, "step": 0.4}, + time_plan={"interval": 0.2, "loops": 4}, # grid_plan={"rows": 2, "columns": 1}, ) qapp = QtWidgets.QApplication([]) v = MDAViewer() +v.dims_sliders.setLocksVisible(False) v.show() mmcore.run_mda(sequence, output=v._datastore) diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py index 154d9d60e..5cd1d62f8 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py @@ -40,7 +40,8 @@ def clim(self) -> Any: @clim.setter def clim(self, clims: tuple[float, float]) -> None: - self._image.clim = clims + with suppress(ZeroDivisionError): + self._image.clim = clims @property def cmap(self) -> cmap.Colormap: diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index b1466e988..410f21279 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, cast from warnings import warn +from PyQt6.QtGui import QResizeEvent from qtpy.QtCore import QPointF, QSize, Qt, Signal from qtpy.QtWidgets import ( QDialog, @@ -26,38 +27,47 @@ from qtpy.QtGui import QKeyEvent # any hashable represent a single dimension in a AND array - DimensionKey: TypeAlias = Hashable + DimKey: TypeAlias = Hashable # any object that can be used to index a single dimension in an AND array Index: TypeAlias = int | slice # a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) # this object is used frequently to query or set the currently displayed slice - Indices: TypeAlias = Mapping[DimensionKey, Index] + Indices: TypeAlias = Mapping[DimKey, Index] # mapping of dimension keys to the maximum value for that dimension - Sizes: TypeAlias = Mapping[DimensionKey, int] + Sizes: TypeAlias = Mapping[DimKey, int] +BAR_COLOR = "#24007AAB" SS = """ QSlider::groove:horizontal { - height: 6px; + height: 14px; background: qlineargradient( x1:0, y1:0, x2:0, y2:1, stop:0 rgba(128, 128, 128, 0.25), stop:1 rgba(128, 128, 128, 0.1) ); border-radius: 3px; - margin: 2px 0; } QSlider::handle:horizontal { background: qlineargradient( x1:0, y1:0, x2:0, y2:1, - stop:0 rgba(180, 180, 180, 1), - stop:1 rgba(180, 180, 180, 1) + stop:0 rgba(148, 148, 148, 1), + stop:1 rgba(148, 148, 148, 1) ); - width: 28px; - margin: -3px 0; + width: 32px; border-radius: 3px; -}""" +} + +QLabel { + font-size: 12px; +} + +SliderLabel { + font-size: 12px; + color: white; +} +""" class _DissmissableDialog(QDialog): @@ -136,9 +146,7 @@ class DimsSlider(QWidget): valueChanged = Signal(object, object) # where object is int | slice - def __init__( - self, dimension_key: DimensionKey, parent: QWidget | None = None - ) -> None: + def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None: super().__init__(parent) self.setStyleSheet(SS) self._slice_mode = False @@ -162,16 +170,20 @@ def __init__( self._pos_label.setStyleSheet( "border: none; padding: 0; margin: 0; background: transparent" ) + self._out_of_label = QLabel() self._int_slider = QSlider(Qt.Orientation.Horizontal, parent=self) self._int_slider.rangeChanged.connect(self._on_range_changed) self._int_slider.valueChanged.connect(self._on_int_value_changed) # self._int_slider.layout().addWidget(self._max_label) - self._slice_slider = QLabeledRangeSlider(Qt.Orientation.Horizontal, parent=self) - self._slice_slider.setVisible(False) - # self._slice_slider.rangeChanged.connect(self._on_range_changed) - # self._slice_slider.valueChanged.connect(self._on_slice_value_changed) + self._slice_slider = slc = QLabeledRangeSlider(Qt.Orientation.Horizontal) + slc._slider.barColor = BAR_COLOR + slc.setHandleLabelPosition(QLabeledRangeSlider.LabelPosition.LabelsOnHandle) + slc.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + slc.setVisible(False) + slc.rangeChanged.connect(self._on_range_changed) + slc.valueChanged.connect(self._on_slice_value_changed) self.installEventFilter(self) layout = QHBoxLayout(self) @@ -180,26 +192,19 @@ def __init__( layout.addWidget(self._play_btn) layout.addWidget(self._dim_label) layout.addWidget(self._int_slider) + layout.addWidget(self._slice_slider) layout.addWidget(self._pos_label) + layout.addWidget(self._out_of_label) layout.addWidget(self._lock_btn) self.setMinimumHeight(22) def resizeEvent(self, a0: QResizeEvent | None) -> None: - # align all labels if isinstance(par := self.parent(), DimsSliders): - lbl_width = max( - x._dim_label.sizeHint().width() for x in par.findChildren(DimsSlider) - ) - self._dim_label.setFixedWidth(min(lbl_width + 2, 40)) - pos_lbl_width = max( - x._pos_label.sizeHint().width() for x in par.findChildren(DimsSlider) - ) - self._pos_label.setFixedWidth(min(pos_lbl_width + 2, 40)) - return super().resizeEvent(a0) + par.resizeEvent(None) - # def mouseDoubleClickEvent(self, a0: QMouseEvent | None) -> None: - # self._set_slice_mode(not self._slice_mode) - # return super().mouseDoubleClickEvent(a0) + def mouseDoubleClickEvent(self, a0: Any) -> None: + self._set_slice_mode(not self._slice_mode) + return super().mouseDoubleClickEvent(a0) def setMaximum(self, max_val: int) -> None: if max_val > self._int_slider.maximum(): @@ -212,11 +217,12 @@ def setRange(self, min_val: int, max_val: int) -> None: self._slice_slider.setRange(min_val, max_val) def value(self) -> Index: - return ( - self._int_slider.value() - if not self._slice_mode - else slice(*self._slice_slider.value()) - ) + if not self._slice_mode: + return self._int_slider.value() + start, *_, stop = cast("tuple[int, ...]", self._slice_slider.value()) + if start == stop: + return start + return slice(start, stop) def setValue(self, val: Index) -> None: # variant of setValue that always updates the maximum @@ -236,6 +242,8 @@ def forceValue(self, val: Index) -> None: self.setValue(val) def _set_slice_mode(self, mode: bool = True) -> None: + if mode == self._slice_mode: + return self._slice_mode = mode if mode: self._slice_slider.setVisible(True) @@ -243,6 +251,7 @@ def _set_slice_mode(self, mode: bool = True) -> None: else: self._int_slider.setVisible(True) self._slice_slider.setVisible(False) + self.valueChanged.emit(self._dim_key, self.value()) def set_fps(self, fps: int) -> None: self._animation_fps = fps @@ -274,8 +283,9 @@ def _on_pos_label_edited(self) -> None: self._int_slider.setValue(self._pos_label.value()) def _on_range_changed(self, min: int, max: int) -> None: - self._pos_label.setSuffix(" / " + str(max)) + self._out_of_label.setText(f"| {max}") self._pos_label.setRange(min, max) + self.resizeEvent(None) def _on_int_value_changed(self, value: int) -> None: self._pos_label.setValue(value) @@ -297,16 +307,19 @@ class DimsSliders(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self._sliders: dict[DimensionKey, DimsSlider] = {} - self._current_index: dict[DimensionKey, Index] = {} - self._invisible_dims: set[DimensionKey] = set() + self._sliders: dict[DimKey, DimsSlider] = {} + self._current_index: dict[DimKey, Index] = {} + self._invisible_dims: set[DimKey] = set() - self.setSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Maximum) + self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) + def sizeHint(self) -> QSize: + return super().sizeHint().boundedTo(QSize(9999, 0)) + def value(self) -> Indices: return self._current_index.copy() @@ -327,8 +340,19 @@ def setMaximum(self, values: Sizes) -> None: self.add_dimension(name) self._sliders[name].setMaximum(max_val) - def add_dimension(self, name: DimensionKey, val: Index | None = None) -> None: + def setLocksVisible(self, visible: bool | Mapping[DimKey, bool]) -> None: + self._lock_visible = visible + for dim, slider in self._sliders.items(): + viz = visible if isinstance(visible, bool) else visible.get(dim, False) + slider._lock_btn.setVisible(viz) + + def add_dimension(self, name: DimKey, val: Index | None = None) -> None: self._sliders[name] = slider = DimsSlider(dimension_key=name, parent=self) + if isinstance(self._lock_visible, dict) and name in self._lock_visible: + slider._lock_btn.setVisible(self._lock_visible[name]) + else: + slider._lock_btn.setVisible(bool(self._lock_visible)) + slider.setRange(0, 1) val = val if val is not None else 0 self._current_index[name] = val @@ -337,7 +361,7 @@ def add_dimension(self, name: DimensionKey, val: Index | None = None) -> None: slider.setVisible(name not in self._invisible_dims) cast("QVBoxLayout", self.layout()).addWidget(slider) - def set_dimension_visible(self, key: DimensionKey, visible: bool) -> None: + def set_dimension_visible(self, key: DimKey, visible: bool) -> None: if visible: self._invisible_dims.discard(key) else: @@ -345,7 +369,7 @@ def set_dimension_visible(self, key: DimensionKey, visible: bool) -> None: if key in self._sliders: self._sliders[key].setVisible(visible) - def remove_dimension(self, key: DimensionKey) -> None: + def remove_dimension(self, key: DimKey) -> None: try: slider = self._sliders.pop(key) except KeyError: @@ -354,16 +378,26 @@ def remove_dimension(self, key: DimensionKey) -> None: cast("QVBoxLayout", self.layout()).removeWidget(slider) slider.deleteLater() - def _on_dim_slider_value_changed(self, key: DimensionKey, value: Index) -> None: + def _on_dim_slider_value_changed(self, key: DimKey, value: Index) -> None: self._current_index[key] = value self.valueChanged.emit(self.value()) - def add_or_update_dimension(self, key: DimensionKey, value: Index) -> None: + def add_or_update_dimension(self, key: DimKey, value: Index) -> None: if key in self._sliders: self._sliders[key].forceValue(value) else: self.add_dimension(key, value) + def resizeEvent(self, a0: QResizeEvent | None) -> None: + # align all labels + if sliders := list(self._sliders.values()): + for lbl in ("_dim_label", "_pos_label", "_out_of_label"): + lbl_width = max(getattr(s, lbl).sizeHint().width() for s in sliders) + for s in sliders: + getattr(s, lbl).setFixedWidth(lbl_width) + + return super().resizeEvent(a0) + if __name__ == "__main__": from qtpy.QtWidgets import QApplication @@ -372,7 +406,7 @@ def add_or_update_dimension(self, key: DimensionKey, value: Index) -> None: w = DimsSliders() w.add_dimension("x", 5) w.add_dimension("ysadfdasas", 20) - w.add_dimension("z", 10) + w.add_dimension("z", slice(10, 20)) w.add_dimension("w", 10) w.valueChanged.connect(print) w.show() diff --git a/src/pymmcore_widgets/_stack_viewer2/_indexing.py b/src/pymmcore_widgets/_stack_viewer2/_indexing.py index f1af64ac6..5281278fe 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer2/_indexing.py @@ -1,7 +1,8 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any +import warnings +from typing import TYPE_CHECKING import numpy as np from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter @@ -9,27 +10,61 @@ # from ._pygfx_canvas import PyGFXViewerCanvas if TYPE_CHECKING: - import xarray as xr # noqa + from typing import Any, Protocol, TypeGuard - from ._dims_slider import Indices + import dask.array as da + import numpy.typing as npt + import xarray as xr + + from ._dims_slider import Index, Indices + + class SupportsIndexing(Protocol): + def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... + @property + def shape(self) -> tuple[int, ...]: ... + + +def is_xr_dataarray(obj: Any) -> TypeGuard[xr.DataArray]: + if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): + return True + return False + + +def is_dask_array(obj: Any) -> TypeGuard[da.Array]: + if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): + return True + return False + + +def is_duck_array(obj: Any) -> TypeGuard[SupportsIndexing]: + if ( + isinstance(obj, np.ndarray) + or hasattr(obj, "__array_function__") + or hasattr(obj, "__array_namespace__") + ): + return True + return False def isel(store: Any, indexers: Indices) -> np.ndarray: - """Select a slice from a data store.""" + """Select a slice from a data store using (possibly) named indices. + + For xarray.DataArray, use the built-in isel method. + For any other duck-typed array, use numpy-style indexing, where indexers + is a mapping of axis to slice objects or indices. + """ if isinstance(store, (OMEZarrWriter, OMETiffWriter)): return isel_mmcore_5dbase(store, indexers) - if isinstance(store, np.ndarray): - return isel_np_array(store, indexers) - if not TYPE_CHECKING: - xr = sys.modules.get("xarray") - if xr and isinstance(store, xr.DataArray): + if is_xr_dataarray(store): return store.isel(indexers).to_numpy() - raise NotImplementedError(f"Unknown datastore type {type(store)}") + if is_duck_array(store): + return isel_np_array(store, indexers) + raise NotImplementedError(f"Don't know how to index into type {type(store)}") -def isel_np_array(data: np.ndarray, indexers: Indices) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(data.ndim)) - return data[idx] +def isel_np_array(data: SupportsIndexing, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(data.shape))) + return np.asarray(data[idx]) def isel_mmcore_5dbase( @@ -37,7 +72,8 @@ def isel_mmcore_5dbase( ) -> np.ndarray: p_index = indexers.get("p", 0) if isinstance(p_index, slice): - raise NotImplementedError("Cannot slice over position index") # TODO + warnings.warn("Cannot slice over position index", stacklevel=2) # TODO + p_index = p_index.start try: sizes = [*list(writer.position_sizes[p_index]), "y", "x"] diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py index dad41abc5..554ab9803 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -4,17 +4,32 @@ import numpy as np from qtpy.QtCore import Qt -from qtpy.QtWidgets import QCheckBox, QHBoxLayout, QWidget +from qtpy.QtWidgets import QCheckBox, QFrame, QHBoxLayout, QPushButton, QWidget from superqt import QLabeledRangeSlider from superqt.cmap import QColormapComboBox from superqt.utils import signals_blocked +from ._dims_slider import BAR_COLOR, SS + if TYPE_CHECKING: import cmap from ._protocols import PImageHandle +class CmapCombo(QColormapComboBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent, allow_user_colormaps=True, add_colormap_text="Add...") + self.setMinimumSize(100, 22) + self.setStyleSheet("background-color: transparent;") + + def showPopup(self) -> None: + super().showPopup() + popup = self.findChild(QFrame) + popup.setMinimumWidth(self.width() + 100) + popup.move(popup.x(), popup.y() - self.height() - popup.height()) + + class LutControl(QWidget): def __init__( self, @@ -30,10 +45,7 @@ def __init__( self._visible.setChecked(True) self._visible.toggled.connect(self._on_visible_changed) - self._cmap = QColormapComboBox( - allow_user_colormaps=True, add_colormap_text="Add..." - ) - self._cmap.setMinimumWidth(100) + self._cmap = CmapCombo() self._cmap.currentColormapChanged.connect(self._on_cmap_changed) for handle in handles: self._cmap.addColormap(handle.cmap) @@ -41,12 +53,21 @@ def __init__( self._cmap.addColormap(color) self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + if hasattr(self._clims, "_slider"): + self._clims._slider.barColor = BAR_COLOR + self._clims.setStyleSheet(SS) + self._clims.setHandleLabelPosition( + QLabeledRangeSlider.LabelPosition.LabelsOnHandle + ) + self._clims.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) self._clims.setRange(0, 2**8) self._clims.valueChanged.connect(self._on_clims_changed) - self._auto_clim = QCheckBox("Auto") - self._auto_clim.toggled.connect(self.update_autoscale) + self._auto_clim = QPushButton("Auto") + self._auto_clim.setMaximumWidth(42) + self._auto_clim.setCheckable(True) self._auto_clim.setChecked(True) + self._auto_clim.toggled.connect(self.update_autoscale) layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 35726dc01..e7794432f 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -2,10 +2,12 @@ from collections import defaultdict from itertools import cycle -from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, Mapping, cast +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast import cmap +import numpy as np from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from superqt import QCollapsible, QIconifyIcon from ._canvas._vispy import VispyViewerCanvas from ._dims_slider import DimsSliders @@ -13,12 +15,15 @@ from ._lut_control import LutControl if TYPE_CHECKING: - import numpy as np + from typing import Any, Callable, Hashable, Literal, TypeAlias - from ._dims_slider import DimensionKey, Indices, Sizes + from ._dims_slider import DimKey, Indices, Sizes from ._protocols import PCanvas, PImageHandle ColorMode = Literal["composite", "grayscale"] + ImgKey: TypeAlias = Hashable + # any mapping of dimensions to sizes + SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] GRAYS = cmap.Colormap("gray") @@ -42,127 +47,206 @@ def mode(self) -> ColorMode: class StackViewer(QWidget): - """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events.""" + """A viewer for ND arrays.""" def __init__( self, data: Any, *, parent: QWidget | None = None, - channel_axis: DimensionKey = 0, + channel_axis: DimKey | None = None, ): super().__init__(parent=parent) - self._channels: defaultdict[DimensionKey, list[PImageHandle]] = defaultdict( - list - ) - self._channel_controls: dict[DimensionKey, LutControl] = {} + # ATTRIBUTES ---------------------------------------------------- + # dimensions of the data in the datastore self._sizes: Sizes = {} + # mapping of key to a list of objects that control image nodes in the canvas + self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) + # mapping of same keys to the LutControl objects control image display props + self._lut_ctrls: dict[ImgKey, LutControl] = {} # the set of dimensions we are currently visualizing (e.g. XY) - self._visualized_dims: set[DimensionKey] = set() - + # this is used to control which dimensions have sliders and the behavior + # of isel when selecting data from the datastore + self._visualized_dims: set[DimKey] = set() + # the axis that represents the channels in the data self._channel_axis = channel_axis - - self._info_bar = QLabel("Info") - self._canvas: PCanvas = VispyViewerCanvas(self._info_bar.setText) - # self._canvas: PCanvas = QtViewerCanvas(self._info_bar.setText) - # self._canvas: PCanvas = PyGFXViewerCanvas(self._info_bar.setText) - self._dims_sliders = DimsSliders() + # colormaps that will be cycled through when displaying composite images + # TODO: allow user to set this self._cmaps = cycle(COLORMAPS) - self.set_channel_mode("grayscale") - self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) + # WIDGETS ---------------------------------------------------- + # the button that controls the display mode of the channels self._channel_mode_picker = ColorModeButton() self._channel_mode_picker.clicked.connect(self.set_channel_mode) + # button to reset the zoom of the canvas self._set_range_btn = QPushButton("reset zoom") - self._set_range_btn.clicked.connect(self._set_range_clicked) + self._set_range_btn.clicked.connect(self._on_set_range_clicked) - self.set_data(data) + # place to display arbitrary text + self._info_bar = QLabel("Info") + # the canvas that displays the images + self._canvas: PCanvas = VispyViewerCanvas(self._info_bar.setText) + # the sliders that control the index of the displayed image + self._dims_sliders = DimsSliders() + self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) + + self._lut_drop = QCollapsible("LUTs") + self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down")) + self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up")) + lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) + lut_layout.setContentsMargins(0, 1, 0, 1) + lut_layout.setSpacing(0) + if hasattr(self._lut_drop, "_content") and ( + layout := self._lut_drop._content.layout() + ): + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # LAYOUT ----------------------------------------------------- btns = QHBoxLayout() + btns.setContentsMargins(0, 0, 0, 0) + btns.setSpacing(0) btns.addStretch() btns.addWidget(self._channel_mode_picker) btns.addWidget(self._set_range_btn) layout = QVBoxLayout(self) - layout.setSpacing(4) - layout.addLayout(btns) + layout.setSpacing(3) + layout.setContentsMargins(6, 6, 6, 6) layout.addWidget(self._canvas.qwidget(), 1) layout.addWidget(self._info_bar) layout.addWidget(self._dims_sliders) + layout.addWidget(self._lut_drop) + layout.addLayout(btns) - def set_data(self, data: Any, sizes: Sizes | None = None) -> None: - if sizes is not None: - self._sizes = dict(sizes) - else: + # SETUP ------------------------------------------------------ + + self.set_data(data) + self.set_channel_mode("grayscale") + + # ------------------- PUBLIC API ---------------------------- + + def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: + """Set the datastore, and, optionally, the sizes of the data.""" + if sizes is None: if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): - self._sizes = sz + sizes = sz elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): - self._sizes = dict(enumerate(shp)) - else: - self._sizes = {} - + sizes = shp + self._sizes = _to_sizes(sizes) self._datastore = data - self.set_visible_dims(list(self._sizes)[-2:]) + if self._channel_axis is None: + self._channel_axis = self._guess_channel_axis(data) + self.set_visualized_dims(list(self._sizes)[-2:]) - def set_visible_dims(self, dims: Iterable[DimensionKey]) -> None: + def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: + """Set the dimensions that will be visualized. + + This dims will NOT have sliders associated with them. + """ self._visualized_dims = set(dims) for d in self._dims_sliders._sliders: self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) for d in self._visualized_dims: self._dims_sliders.set_dimension_visible(d, False) + @property + def dims_sliders(self) -> DimsSliders: + """Return the DimsSliders widget.""" + return self._dims_sliders + @property def sizes(self) -> Sizes: + """Return sizes {dimkey: int} of the dimensions in the datastore.""" return self._sizes - def update_slider_maxima( - self, sizes: tuple[int, ...] | Sizes | None = None - ) -> None: + def update_slider_maxima(self, sizes: SizesLike | None = None) -> None: + """Set the maximum values of the sliders. + + If `sizes` is not provided, sizes will be inferred from the datastore. + """ if sizes is None: - _sizes = self.sizes - elif isinstance(sizes, tuple): - _sizes = dict(enumerate(sizes)) - elif not isinstance(sizes, Mapping): - raise ValueError(f"Invalid shape {sizes}") + sizes = self.sizes + sizes = _to_sizes(sizes) + self._dims_sliders.setMaximum({k: v - 1 for k, v in sizes.items()}) - for dim in list(_sizes.values())[-2:]: + # FIXME: this needs to be moved and made user-controlled + for dim in list(sizes.values())[-2:]: self._dims_sliders.set_dimension_visible(dim, False) - self._dims_sliders.setMaximum({k: v - 1 for k, v in _sizes.items()}) - - def _set_range_clicked(self) -> None: - self._canvas.set_range() def set_channel_mode(self, mode: ColorMode | None = None) -> None: + """Set the mode for displaying the channels. + + In "composite" mode, the channels are displayed as a composite image, using + self._channel_axis as the channel axis. In "grayscale" mode, each channel is + displayed separately. (If mode is None, the current value of the + channel_mode_picker button is used) + """ if mode is None or isinstance(mode, bool): mode = self._channel_mode_picker.mode() if mode == getattr(self, "_channel_mode", None): return - self._cmaps = cycle(COLORMAPS) self._channel_mode = mode + # reset the colormap cycle + self._cmaps = cycle(COLORMAPS) + # set the visibility of the channel slider c_visible = mode != "composite" self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) - num_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 - value = self._dims_sliders.value() - if self._channels: - for handles in self._channels.values(): - for handle in handles: - handle.remove() - self._channels.clear() - for c in self._channel_controls.values(): - cast("QVBoxLayout", self.layout()).removeWidget(c) - c.deleteLater() - self._channel_controls.clear() - if c_visible: - self._update_data_for_index(value) - else: - for i in range(num_channels): - self._update_data_for_index({**value, self._channel_axis: i}) + + if not self._img_handles: + return + + # determine what needs to be updated + n_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 + value = self._dims_sliders.value() # get before clearing + self._clear_images() + indices = ( + [value] + if c_visible + else [{**value, self._channel_axis: i} for i in range(n_channels)] + ) + + # update the displayed images + for idx in indices: + self._update_data_for_index(idx) self._canvas.refresh() - def _image_key(self, index: Indices) -> Hashable: + def setIndex(self, index: Indices) -> None: + """Set the index of the displayed image.""" + self._dims_sliders.setValue(index) + + # ------------------- PRIVATE METHODS ---------------------------- + + def _guess_channel_axis(self, data: Any) -> DimKey: + """Guess the channel axis from the data.""" + if isinstance(data, np.ndarray): + # for numpy arrays, use the smallest dimension as the channel axis + return data.shape.index(min(data.shape)) + + return 0 + + def _clear_images(self) -> None: + """Remove all images from the canvas.""" + for handles in self._img_handles.values(): + for handle in handles: + handle.remove() + self._img_handles.clear() + + # clear the current LutControls as well + for c in self._lut_ctrls.values(): + cast("QVBoxLayout", self.layout()).removeWidget(c) + c.deleteLater() + self._lut_ctrls.clear() + + def _on_set_range_clicked(self) -> None: + self._canvas.set_range() + + def _image_key(self, index: Indices) -> ImgKey: + """Return the key for image handle(s) corresponding to `index`.""" if self._channel_mode == "composite": val = index.get(self._channel_axis, 0) if isinstance(val, slice): @@ -171,22 +255,19 @@ def _image_key(self, index: Indices) -> Hashable: return 0 def _isel(self, index: Indices) -> np.ndarray: + """Select data from the datastore using the given index.""" idx = {k: v for k, v in index.items() if k not in self._visualized_dims} try: return isel(self._datastore, idx) except Exception as e: raise type(e)(f"Failed to index data with {idx}: {e}") from e - def setIndex(self, index: Indices) -> None: - """Set the index of the displayed image.""" - self._dims_sliders.setValue(index) - def _on_dims_sliders_changed(self, index: Indices) -> None: """Update the displayed image when the sliders are changed.""" c = index.get(self._channel_axis, 0) indices: list[Indices] = [index] if self._channel_mode == "composite": - for i, handles in self._channels.items(): + for i, handles in self._img_handles.items(): if handles and c != i: # FIXME: type error is legit indices.append({**index, self._channel_axis: i}) @@ -196,17 +277,63 @@ def _on_dims_sliders_changed(self, index: Indices) -> None: self._canvas.refresh() def _update_data_for_index(self, index: Indices) -> None: - key = self._image_key(index) - data = self._isel(index) - if handles := self._channels[key]: + """Update the displayed image for the given index. + + This will pull the data from the datastore using the given index, and update + the image handle(s) with the new data. + """ + imkey = self._image_key(index) + data = self._isel(index).squeeze() + data = self._reduce_dims_for_display(data) + if handles := self._img_handles[imkey]: for handle in handles: handle.data = data - if ctrl := self._channel_controls.get(key, None): + if ctrl := self._lut_ctrls.get(imkey, None): ctrl.update_autoscale() else: cm = next(self._cmaps) if self._channel_mode == "composite" else GRAYS handles.append(self._canvas.add_image(data, cmap=cm)) - if key not in self._channel_controls: - channel_name = f"Channel {key}" - self._channel_controls[key] = c = LutControl(channel_name, handles) - cast("QVBoxLayout", self.layout()).addWidget(c) + if imkey not in self._lut_ctrls: + channel_name = f"Ch {imkey}" # TODO: get name from user + self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) + c.update_autoscale() + self._lut_drop.addWidget(c) + + def _reduce_dims_for_display( + self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max + ) -> np.ndarray: + """Reduce the number of dimensions in the data for display. + + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. + """ + # TODO + # - allow for 3d data + # - allow dimensions to control how they are reduced + # - for better way to determine which dims need to be reduced + visualized_dims = 2 + if extra_dims := data.ndim - visualized_dims: + shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) + smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) + return reductor(data, axis=smallest_dims) + return data + + +def _to_sizes(sizes: SizesLike | None) -> Sizes: + """Coerce `sizes` to a {dimKey -> int} mapping.""" + if sizes is None: + return {} + if isinstance(sizes, Mapping): + return {k: int(v) for k, v in sizes.items()} + if not isinstance(sizes, Iterable): + raise TypeError(f"SizeLike must be an iterable or mapping, not: {type(sizes)}") + _sizes: dict[Hashable, int] = {} + for i, val in enumerate(sizes): + if isinstance(val, int): + _sizes[i] = val + elif isinstance(val, Sequence) and len(val) == 2: + _sizes[val[0]] = int(val[1]) + else: + raise ValueError(f"Invalid size: {val}. Must be an int or a 2-tuple.") + return _sizes From fb66297e9a7763efcd6dca1c533f2a86c83f7d58 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 13:14:47 -0400 Subject: [PATCH 15/73] more progress for xarray and numpy --- examples/stack_viewer_numpy.py | 2 - examples/stack_viewer_xr.py | 10 +- .../_stack_viewer2/_canvas/_vispy.py | 8 +- .../_stack_viewer2/_dims_slider.py | 17 +-- .../_stack_viewer2/_indexing.py | 26 +++-- .../_stack_viewer2/_lut_control.py | 4 +- .../_stack_viewer2/_stack_viewer.py | 107 ++++++++++++------ 7 files changed, 111 insertions(+), 63 deletions(-) diff --git a/examples/stack_viewer_numpy.py b/examples/stack_viewer_numpy.py index 798331020..87b154a8b 100644 --- a/examples/stack_viewer_numpy.py +++ b/examples/stack_viewer_numpy.py @@ -55,6 +55,4 @@ def generate_5d_sine_wave( qapp = QtWidgets.QApplication([]) v = StackViewer(sine_wave_5d, channel_axis=2) v.show() - v.update_slider_maxima() - v.setIndex({0: 1, 1: 0, 2: 0}) qapp.exec() diff --git a/examples/stack_viewer_xr.py b/examples/stack_viewer_xr.py index 78f37fe63..115defc9a 100644 --- a/examples/stack_viewer_xr.py +++ b/examples/stack_viewer_xr.py @@ -1,19 +1,13 @@ from __future__ import annotations # from stack_viewer_numpy import generate_5d_sine_wave -import nd2 +import and2 from qtpy import QtWidgets from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer -# array_shape = (10, 5, 3, 512, 512) # Specify the desired dimensions -# sine_wave_5d = generate_5d_sine_wave(array_shape) -# data = xr.DataArray(sine_wave_5d, dims=["a", "f", "p", "y", "x"]) - -data = and2.imread("~/dev/self/nd2/tests/data/t3p3z5c3.and2", xarray=True) +data = and2.imread("/Users/talley/Downloads/6D_test.nd2", xarray=True, dask=True) qapp = QtWidgets.QApplication([]) v = StackViewer(data, channel_axis="C") v.show() -v.update_slider_maxima() -v.setIndex({}) qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py index 5cd1d62f8..962466230 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py @@ -112,6 +112,7 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: """Mouse moved on the canvas, display the pixel value and position.""" images = [] # Get the images the mouse is over + # FIXME: must be a better way to do this seen = set() while visual := self._canvas.visual_at(event.pos): if isinstance(visual, scene.visuals.Image): @@ -126,7 +127,10 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: tform = images[0].get_transform("canvas", "visual") px, py, *_ = (int(x) for x in tform.map(event.pos)) text = f"[{py}, {px}]" - for c, img in enumerate(images): + for c, img in enumerate(reversed(images)): with suppress(IndexError): - text += f" {c}: {round(img._data[py, px], 2)}" + value = img._data[py, px] + if isinstance(value, (np.floating, float)): + value = f"{value:.2f}" + text += f" {c}: {value}" self._set_info(text) diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index 410f21279..fe1ca33a4 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any, cast from warnings import warn -from PyQt6.QtGui import QResizeEvent from qtpy.QtCore import QPointF, QSize, Qt, Signal +from qtpy.QtGui import QResizeEvent from qtpy.QtWidgets import ( QDialog, QHBoxLayout, @@ -36,11 +36,11 @@ # mapping of dimension keys to the maximum value for that dimension Sizes: TypeAlias = Mapping[DimKey, int] -BAR_COLOR = "#24007AAB" +BAR_COLOR = "#2258575B" SS = """ QSlider::groove:horizontal { - height: 14px; + height: 15px; background: qlineargradient( x1:0, y1:0, x2:0, y2:1, stop:0 rgba(128, 128, 128, 0.25), @@ -50,12 +50,12 @@ } QSlider::handle:horizontal { + width: 38px; background: qlineargradient( x1:0, y1:0, x2:0, y2:1, stop:0 rgba(148, 148, 148, 1), stop:1 rgba(148, 148, 148, 1) ); - width: 32px; border-radius: 3px; } @@ -307,6 +307,7 @@ class DimsSliders(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) + self._locks_visible = True self._sliders: dict[DimKey, DimsSlider] = {} self._current_index: dict[DimKey, Index] = {} self._invisible_dims: set[DimKey] = set() @@ -341,17 +342,17 @@ def setMaximum(self, values: Sizes) -> None: self._sliders[name].setMaximum(max_val) def setLocksVisible(self, visible: bool | Mapping[DimKey, bool]) -> None: - self._lock_visible = visible + self._locks_visible = visible for dim, slider in self._sliders.items(): viz = visible if isinstance(visible, bool) else visible.get(dim, False) slider._lock_btn.setVisible(viz) def add_dimension(self, name: DimKey, val: Index | None = None) -> None: self._sliders[name] = slider = DimsSlider(dimension_key=name, parent=self) - if isinstance(self._lock_visible, dict) and name in self._lock_visible: - slider._lock_btn.setVisible(self._lock_visible[name]) + if isinstance(self._locks_visible, dict) and name in self._locks_visible: + slider._lock_btn.setVisible(self._locks_visible[name]) else: - slider._lock_btn.setVisible(bool(self._lock_visible)) + slider._lock_btn.setVisible(bool(self._locks_visible)) slider.setRange(0, 1) val = val if val is not None else 0 diff --git a/src/pymmcore_widgets/_stack_viewer2/_indexing.py b/src/pymmcore_widgets/_stack_viewer2/_indexing.py index 5281278fe..76f98d195 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer2/_indexing.py @@ -5,9 +5,6 @@ from typing import TYPE_CHECKING import numpy as np -from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter - -# from ._pygfx_canvas import PyGFXViewerCanvas if TYPE_CHECKING: from typing import Any, Protocol, TypeGuard @@ -15,6 +12,7 @@ import dask.array as da import numpy.typing as npt import xarray as xr + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase from ._dims_slider import Index, Indices @@ -24,7 +22,19 @@ def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... def shape(self) -> tuple[int, ...]: ... -def is_xr_dataarray(obj: Any) -> TypeGuard[xr.DataArray]: +def is_pymmcore_writer(obj: Any) -> TypeGuard[_5DWriterBase]: + try: + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase + except ImportError: + from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter + + _5DWriterBase = (OMETiffWriter, OMEZarrWriter) # type: ignore + if isinstance(obj, _5DWriterBase): + return True + return False + + +def is_xarray_dataarray(obj: Any) -> TypeGuard[xr.DataArray]: if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): return True return False @@ -53,9 +63,9 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: For any other duck-typed array, use numpy-style indexing, where indexers is a mapping of axis to slice objects or indices. """ - if isinstance(store, (OMEZarrWriter, OMETiffWriter)): + if is_pymmcore_writer(store): return isel_mmcore_5dbase(store, indexers) - if is_xr_dataarray(store): + if is_xarray_dataarray(store): return store.isel(indexers).to_numpy() if is_duck_array(store): return isel_np_array(store, indexers) @@ -67,9 +77,7 @@ def isel_np_array(data: SupportsIndexing, indexers: Indices) -> np.ndarray: return np.asarray(data[idx]) -def isel_mmcore_5dbase( - writer: OMEZarrWriter | OMETiffWriter, indexers: Indices -) -> np.ndarray: +def isel_mmcore_5dbase(writer: _5DWriterBase, indexers: Indices) -> np.ndarray: p_index = indexers.get("p", 0) if isinstance(p_index, slice): warnings.warn("Cannot slice over position index", stacklevel=2) # TODO diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py index 554ab9803..f03260a3f 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -20,8 +20,8 @@ class CmapCombo(QColormapComboBox): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent, allow_user_colormaps=True, add_colormap_text="Add...") - self.setMinimumSize(100, 22) - self.setStyleSheet("background-color: transparent;") + self.setMinimumSize(120, 21) + # self.setStyleSheet("background-color: transparent;") def showPopup(self) -> None: super().showPopup() diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index e7794432f..0dae237fe 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -1,26 +1,26 @@ from __future__ import annotations from collections import defaultdict +from enum import Enum from itertools import cycle from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast import cmap import numpy as np from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget -from superqt import QCollapsible, QIconifyIcon +from superqt import QCollapsible, QElidingLabel, QIconifyIcon from ._canvas._vispy import VispyViewerCanvas from ._dims_slider import DimsSliders -from ._indexing import isel +from ._indexing import is_xarray_dataarray, isel from ._lut_control import LutControl if TYPE_CHECKING: - from typing import Any, Callable, Hashable, Literal, TypeAlias + from typing import Any, Callable, Hashable, TypeAlias from ._dims_slider import DimKey, Indices, Sizes from ._protocols import PCanvas, PImageHandle - ColorMode = Literal["composite", "grayscale"] ImgKey: TypeAlias = Hashable # any mapping of dimensions to sizes SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] @@ -30,24 +30,38 @@ COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] -class ColorModeButton(QPushButton): +class ChannelMode(str, Enum): + COMPOSITE = "composite" + MONO = "mono" + + def __str__(self) -> str: + return self.value + + +class ChannelModeButton(QPushButton): def __init__(self, parent: QWidget | None = None): - modes = ["composite", "grayscale"] - self._modes = cycle(modes) - super().__init__(modes[-1], parent) - self.clicked.connect(self.next_mode) - self.next_mode() + super().__init__(parent) + self.setCheckable(True) + self.toggled.connect(self.next_mode) def next_mode(self) -> None: - self._mode = self.text() - self.setText(next(self._modes)) + if self.isChecked(): + self.setMode(ChannelMode.MONO) + else: + self.setMode(ChannelMode.COMPOSITE) + + def mode(self) -> ChannelMode: + return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE - def mode(self) -> ColorMode: - return self._mode # type: ignore + def setMode(self, mode: ChannelMode) -> None: + # we show the name of the next mode, not the current one + other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO + self.setText(str(other)) + self.setChecked(mode == ChannelMode.MONO) class StackViewer(QWidget): - """A viewer for ND arrays.""" + """A viewer for AND arrays.""" def __init__( self, @@ -55,6 +69,7 @@ def __init__( *, parent: QWidget | None = None, channel_axis: DimKey | None = None, + channel_mode: ChannelMode = ChannelMode.MONO, ): super().__init__(parent=parent) @@ -72,6 +87,7 @@ def __init__( self._visualized_dims: set[DimKey] = set() # the axis that represents the channels in the data self._channel_axis = channel_axis + self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode # colormaps that will be cycled through when displaying composite images # TODO: allow user to set this self._cmaps = cycle(COLORMAPS) @@ -79,16 +95,18 @@ def __init__( # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels - self._channel_mode_picker = ColorModeButton() - self._channel_mode_picker.clicked.connect(self.set_channel_mode) + self._channel_mode_btn = ChannelModeButton() + self._channel_mode_btn.clicked.connect(self.set_channel_mode) # button to reset the zoom of the canvas self._set_range_btn = QPushButton("reset zoom") self._set_range_btn.clicked.connect(self._on_set_range_clicked) + # place to display dataset summary + self._data_info = QElidingLabel("") # place to display arbitrary text - self._info_bar = QLabel("Info") + self._hover_info = QLabel("Info") # the canvas that displays the images - self._canvas: PCanvas = VispyViewerCanvas(self._info_bar.setText) + self._canvas: PCanvas = VispyViewerCanvas(self._hover_info.setText) # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders() self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) @@ -99,8 +117,9 @@ def __init__( lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) lut_layout.setContentsMargins(0, 1, 0, 1) lut_layout.setSpacing(0) - if hasattr(self._lut_drop, "_content") and ( - layout := self._lut_drop._content.layout() + if ( + hasattr(self._lut_drop, "_content") + and (layout := self._lut_drop._content.layout()) is not None ): layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) @@ -111,13 +130,15 @@ def __init__( btns.setContentsMargins(0, 0, 0, 0) btns.setSpacing(0) btns.addStretch() - btns.addWidget(self._channel_mode_picker) + btns.addWidget(self._channel_mode_btn) btns.addWidget(self._set_range_btn) + layout = QVBoxLayout(self) - layout.setSpacing(3) + layout.setSpacing(2) layout.setContentsMargins(6, 6, 6, 6) + layout.addWidget(self._data_info) layout.addWidget(self._canvas.qwidget(), 1) - layout.addWidget(self._info_bar) + layout.addWidget(self._hover_info) layout.addWidget(self._dims_sliders) layout.addWidget(self._lut_drop) layout.addLayout(btns) @@ -125,7 +146,7 @@ def __init__( # SETUP ------------------------------------------------------ self.set_data(data) - self.set_channel_mode("grayscale") + self.set_channel_mode(channel_mode) # ------------------- PUBLIC API ---------------------------- @@ -141,6 +162,19 @@ def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: if self._channel_axis is None: self._channel_axis = self._guess_channel_axis(data) self.set_visualized_dims(list(self._sizes)[-2:]) + self.update_slider_maxima() + self.setIndex({}) + + if all(isinstance(x, int) for x in self._sizes): + size_str = repr(tuple(self._sizes.values())) + else: + size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) + size_str = f"({size_str})" + dtype = getattr(data, "dtype", "") + nbytes = getattr(data, "nbytes", "") / 1e6 + self._data_info.setText( + f"{type(data).__name__}, {size_str}, {dtype}, {nbytes}MB" + ) def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: """Set the dimensions that will be visualized. @@ -177,7 +211,7 @@ def update_slider_maxima(self, sizes: SizesLike | None = None) -> None: for dim in list(sizes.values())[-2:]: self._dims_sliders.set_dimension_visible(dim, False) - def set_channel_mode(self, mode: ColorMode | None = None) -> None: + def set_channel_mode(self, mode: ChannelMode | None = None) -> None: """Set the mode for displaying the channels. In "composite" mode, the channels are displayed as a composite image, using @@ -186,7 +220,9 @@ def set_channel_mode(self, mode: ColorMode | None = None) -> None: channel_mode_picker button is used) """ if mode is None or isinstance(mode, bool): - mode = self._channel_mode_picker.mode() + mode = self._channel_mode_btn.mode() + else: + self._channel_mode_btn.setMode(mode) if mode == getattr(self, "_channel_mode", None): return @@ -194,7 +230,7 @@ def set_channel_mode(self, mode: ColorMode | None = None) -> None: # reset the colormap cycle self._cmaps = cycle(COLORMAPS) # set the visibility of the channel slider - c_visible = mode != "composite" + c_visible = mode != ChannelMode.COMPOSITE self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) if not self._img_handles: @@ -226,7 +262,10 @@ def _guess_channel_axis(self, data: Any) -> DimKey: if isinstance(data, np.ndarray): # for numpy arrays, use the smallest dimension as the channel axis return data.shape.index(min(data.shape)) - + if is_xarray_dataarray(data): + for d in data.dims: + if str(d).lower() in ("channel", "ch", "c"): + return d return 0 def _clear_images(self) -> None: @@ -247,7 +286,7 @@ def _on_set_range_clicked(self) -> None: def _image_key(self, index: Indices) -> ImgKey: """Return the key for image handle(s) corresponding to `index`.""" - if self._channel_mode == "composite": + if self._channel_mode == ChannelMode.COMPOSITE: val = index.get(self._channel_axis, 0) if isinstance(val, slice): return (val.start, val.stop) @@ -266,7 +305,7 @@ def _on_dims_sliders_changed(self, index: Indices) -> None: """Update the displayed image when the sliders are changed.""" c = index.get(self._channel_axis, 0) indices: list[Indices] = [index] - if self._channel_mode == "composite": + if self._channel_mode == ChannelMode.COMPOSITE: for i, handles in self._img_handles.items(): if handles and c != i: # FIXME: type error is legit @@ -291,7 +330,11 @@ def _update_data_for_index(self, index: Indices) -> None: if ctrl := self._lut_ctrls.get(imkey, None): ctrl.update_autoscale() else: - cm = next(self._cmaps) if self._channel_mode == "composite" else GRAYS + cm = ( + next(self._cmaps) + if self._channel_mode == ChannelMode.COMPOSITE + else GRAYS + ) handles.append(self._canvas.add_image(data, cmap=cm)) if imkey not in self._lut_ctrls: channel_name = f"Ch {imkey}" # TODO: get name from user From f63328ee868179190157ca46bed454fabbcfa93a Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 15:22:16 -0400 Subject: [PATCH 16/73] better pygfx --- examples/stack_viewer2.py | 4 +- examples/stack_viewer_xr.py | 4 +- pyproject.toml | 13 +++- .../_stack_viewer2/_canvas/__init__.py | 38 ++++++++++++ .../_stack_viewer2/_canvas/_pygfx.py | 62 ++++++++++++++----- .../_stack_viewer2/_canvas/_vispy.py | 4 -- .../_stack_viewer2/_dims_slider.py | 4 +- .../_stack_viewer2/_mda_viewer.py | 3 +- .../_stack_viewer2/_protocols.py | 3 +- .../_stack_viewer2/_stack_viewer.py | 40 +++++++----- x.py | 21 ------- 11 files changed, 131 insertions(+), 65 deletions(-) delete mode 100644 x.py diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index e6a81d74a..0027caa63 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -28,8 +28,8 @@ qapp = QtWidgets.QApplication([]) v = MDAViewer() -v.dims_sliders.setLocksVisible(False) +v.dims_sliders.set_locks_visible(False) v.show() -mmcore.run_mda(sequence, output=v._datastore) +mmcore.run_mda(sequence, output=v._data) qapp.exec() diff --git a/examples/stack_viewer_xr.py b/examples/stack_viewer_xr.py index 115defc9a..f47922876 100644 --- a/examples/stack_viewer_xr.py +++ b/examples/stack_viewer_xr.py @@ -1,12 +1,12 @@ from __future__ import annotations # from stack_viewer_numpy import generate_5d_sine_wave -import and2 +import nd2 from qtpy import QtWidgets from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer -data = and2.imread("/Users/talley/Downloads/6D_test.nd2", xarray=True, dask=True) +data = nd2.imread("/Users/talley/Downloads/6D_test.nd2", xarray=True, dask=True) qapp = QtWidgets.QApplication([]) v = StackViewer(data, channel_axis="C") v.show() diff --git a/pyproject.toml b/pyproject.toml index 7cb1de774..30e103986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,15 @@ dependencies = [ # extras # https://peps.python.org/pep-0621/#dependencies-optional-dependencies [project.optional-dependencies] -test = ["pytest>=6.0", "pytest-cov", "pytest-qt", "PyYAML", "vispy", "cmap", "zarr"] +test = [ + "pytest>=6.0", + "pytest-cov", + "pytest-qt", + "PyYAML", + "vispy", + "cmap", + "zarr", +] pyqt5 = ["PyQt5"] pyside2 = ["PySide2"] pyqt6 = ["PyQt6"] @@ -172,3 +180,6 @@ ignore = [ "examples/**/*", "CHANGELOG.md", ] + +[tool.typos.default] +extend-ignore-identifiers-re = ["(?i)nd2?.*", "(?i)ome"] diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py index e69de29bb..0337b7772 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import importlib +import importlib.util +import os +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pymmcore_widgets._stack_viewer2._protocols import PCanvas + + +ENV_BACKEND = os.getenv("CANVAS_BACKEND", None) + + +def get_canvas(backend: str | None = ENV_BACKEND) -> type[PCanvas]: + if backend == "vispy" or (backend is None and "vispy" in sys.modules): + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if backend == "pygfx" or (backend is None and "pygfx" in sys.modules): + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + if backend is None: + if importlib.util.find_spec("vispy") is not None: + from ._vispy import VispyViewerCanvas + + return VispyViewerCanvas + + if importlib.util.find_spec("pygfx") is not None: + from ._pygfx import PyGFXViewerCanvas + + return PyGFXViewerCanvas + + raise RuntimeError("No canvas backend found") diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py index 25529a848..139237d05 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py @@ -2,30 +2,33 @@ from typing import TYPE_CHECKING, Any, Callable, cast +import numpy as np import pygfx -import pygfx.geometries -import pygfx.materials +from qtpy.QtCore import QSize from wgpu.gui.qt import QWgpuCanvas if TYPE_CHECKING: import cmap - import numpy as np + from pygfx.materials import ImageBasicMaterial + from pygfx.resources import Texture from qtpy.QtWidgets import QWidget class PyGFXImageHandle: - def __init__(self, image: pygfx.Image) -> None: + def __init__(self, image: pygfx.Image, render: Callable) -> None: self._image = image - self._geom = cast("pygfx.geometries.Geometry", image.geometry.grid) - self._material = cast("pygfx.materials.ImageBasicMaterial", image.material) + self._render = render + self._grid = cast("Texture", image.geometry.grid) + self._material = cast("ImageBasicMaterial", image.material) @property def data(self) -> np.ndarray: - return self._geom._data # type: ignore + return self._grid.data # type: ignore @data.setter def data(self, data: np.ndarray) -> None: - self._geom.grid = pygfx.Texture(data, dim=2) + self._grid.data[:] = data + self._grid.update_range((0, 0, 0), self._grid.size) @property def visible(self) -> bool: @@ -34,6 +37,7 @@ def visible(self) -> bool: @visible.setter def visible(self, visible: bool) -> None: self._image.visible = visible + self._render() @property def clim(self) -> Any: @@ -42,6 +46,7 @@ def clim(self) -> Any: @clim.setter def clim(self, clims: tuple[float, float]) -> None: self._material.clim = clims + self._render() @property def cmap(self) -> cmap.Colormap: @@ -51,12 +56,18 @@ def cmap(self) -> cmap.Colormap: def cmap(self, cmap: cmap.Colormap) -> None: self._cmap = cmap self._material.map = cmap.to_pygfx() + self._render() def remove(self) -> None: if (par := self._image.parent) is not None: par.remove(self._image) +class _QWgpuCanvas(QWgpuCanvas): + def sizeHint(self) -> QSize: + return QSize(512, 512) + + class PyGFXViewerCanvas: """Vispy-based viewer for data. @@ -67,16 +78,17 @@ class PyGFXViewerCanvas: def __init__(self, set_info: Callable[[str], None]) -> None: self._set_info = set_info - self._canvas = QWgpuCanvas(size=(512, 512)) + self._canvas = _QWgpuCanvas(size=(512, 512)) self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) + self._renderer.blend_mode = "weighted" self._scene = pygfx.Scene() self._camera = cam = pygfx.OrthographicCamera(512, 512) + cam.local.scale_y = -1 cam.local.position = (256, 256, 0) - cam.local.scale_y = -1 - controller = pygfx.PanZoomController(cam, register_events=self._renderer) + self._controller = pygfx.PanZoomController(cam, register_events=self._renderer) # increase zoom wheel gain - controller.controls.update({"wheel": ("zoom_to_point", "push", -0.005)}) + self._controller.controls.update({"wheel": ("zoom_to_point", "push", -0.005)}) def qwidget(self) -> QWidget: return cast("QWidget", self._canvas) @@ -86,7 +98,6 @@ def refresh(self) -> None: self._canvas.request_draw(self._animate) def _animate(self) -> None: - print("animate") self._renderer.render(self._scene, self._camera) def add_image( @@ -95,10 +106,13 @@ def add_image( """Add a new Image node to the scene.""" image = pygfx.Image( pygfx.Geometry(grid=pygfx.Texture(data, dim=2)), - pygfx.ImageBasicMaterial(), + # depth_test=False for additive-like blending + pygfx.ImageBasicMaterial(depth_test=False), ) self._scene.add(image) - handle = PyGFXImageHandle(image) + # FIXME: I suspect there are more performant ways to refresh the canvas + # look into it. + handle = PyGFXImageHandle(image, self.refresh) if cmap is not None: handle.cmap = cmap return handle @@ -107,13 +121,27 @@ def set_range( self, x: tuple[float, float] | None = None, y: tuple[float, float] | None = None, - margin: float | None = 0.05, + margin: float = 0.05, ) -> None: """Update the range of the PanZoomCamera. When called with no arguments, the range is set to the full extent of the data. """ - # self._camera.set_range(x=x, y=y, margin=margin) + if not self._scene.children: + return + + cam = self._camera + cam.show_object(self._scene) + + width, height, depth = np.ptp(self._scene.get_world_bounding_box(), axis=0) + if width < 0.01: + width = 1 + if height < 0.01: + height = 1 + cam.width = width + cam.height = height + cam.zoom = 1 - margin + self.refresh() # def _on_mouse_move(self, event: SceneMouseEvent) -> None: # """Mouse moved on the canvas, display the pixel value and position.""" diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py index 962466230..8c60b4fb9 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py @@ -22,8 +22,6 @@ def data(self) -> np.ndarray: @data.setter def data(self, data: np.ndarray) -> None: - if data.dtype == np.float64: - data = data.astype(np.float32) self._image.set_data(data) @property @@ -83,8 +81,6 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: """Add a new Image node to the scene.""" - if data is not None and data.dtype == np.float64: - data = data.astype(np.float32) img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index fe1ca33a4..0ca5aff63 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -307,7 +307,7 @@ class DimsSliders(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self._locks_visible = True + self._locks_visible: bool | Mapping[DimKey, bool] = False self._sliders: dict[DimKey, DimsSlider] = {} self._current_index: dict[DimKey, Index] = {} self._invisible_dims: set[DimKey] = set() @@ -341,7 +341,7 @@ def setMaximum(self, values: Sizes) -> None: self.add_dimension(name) self._sliders[name].setMaximum(max_val) - def setLocksVisible(self, visible: bool | Mapping[DimKey, bool]) -> None: + def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: self._locks_visible = visible for dim, slider in self._sliders.items(): viz = visible if isinstance(visible, bool) else visible.get(dim, False) diff --git a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py index a1a5f63d4..9d001db8f 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py @@ -28,7 +28,8 @@ class MDAViewer(StackViewer): def __init__(self, *, parent: QWidget | None = None): super().__init__(DataStore(), parent=parent, channel_axis="c") - self._datastore.frame_ready.connect(self.on_frame_ready) + self._data.frame_ready.connect(self.on_frame_ready) + self.dims_sliders.set_locks_visible(True) @superqt.ensure_main_thread def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: diff --git a/src/pymmcore_widgets/_stack_viewer2/_protocols.py b/src/pymmcore_widgets/_stack_viewer2/_protocols.py index 8c34543f5..8271431fe 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_protocols.py +++ b/src/pymmcore_widgets/_stack_viewer2/_protocols.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Callable, Protocol if TYPE_CHECKING: import cmap @@ -29,6 +29,7 @@ def remove(self) -> None: ... class PCanvas(Protocol): + def __init__(self, set_text: Callable[[str], Any]) -> None: ... def set_range( self, x: tuple[float, float] | None = None, diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 0dae237fe..4bea69548 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -10,7 +10,7 @@ from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget from superqt import QCollapsible, QElidingLabel, QIconifyIcon -from ._canvas._vispy import VispyViewerCanvas +from ._canvas import get_canvas from ._dims_slider import DimsSliders from ._indexing import is_xarray_dataarray, isel from ._lut_control import LutControl @@ -106,7 +106,7 @@ def __init__( # place to display arbitrary text self._hover_info = QLabel("Info") # the canvas that displays the images - self._canvas: PCanvas = VispyViewerCanvas(self._hover_info.setText) + self._canvas: PCanvas = get_canvas()(self._hover_info.setText) # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders() self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) @@ -150,6 +150,11 @@ def __init__( # ------------------- PUBLIC API ---------------------------- + @property + def data(self) -> Any: + """Return the data backing the view.""" + return self._data + def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: """Set the datastore, and, optionally, the sizes of the data.""" if sizes is None: @@ -158,23 +163,27 @@ def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): sizes = shp self._sizes = _to_sizes(sizes) - self._datastore = data + self._data = data if self._channel_axis is None: self._channel_axis = self._guess_channel_axis(data) self.set_visualized_dims(list(self._sizes)[-2:]) self.update_slider_maxima() self.setIndex({}) - if all(isinstance(x, int) for x in self._sizes): - size_str = repr(tuple(self._sizes.values())) - else: - size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) - size_str = f"({size_str})" - dtype = getattr(data, "dtype", "") - nbytes = getattr(data, "nbytes", "") / 1e6 - self._data_info.setText( - f"{type(data).__name__}, {size_str}, {dtype}, {nbytes}MB" - ) + info = f"{getattr(type(data), '__qualname__', '')}" + + if self._sizes: + if all(isinstance(x, int) for x in self._sizes): + size_str = repr(tuple(self._sizes.values())) + else: + size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + self._data_info.setText(info) def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: """Set the dimensions that will be visualized. @@ -297,7 +306,7 @@ def _isel(self, index: Indices) -> np.ndarray: """Select data from the datastore using the given index.""" idx = {k: v for k, v in index.items() if k not in self._visualized_dims} try: - return isel(self._datastore, idx) + return isel(self._data, idx) except Exception as e: raise type(e)(f"Failed to index data with {idx}: {e}") from e @@ -360,6 +369,9 @@ def _reduce_dims_for_display( shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) return reductor(data, axis=smallest_dims) + + if data.dtype == np.float64: + data = data.astype(np.float32) return data diff --git a/x.py b/x.py deleted file mode 100644 index 28c8d632f..000000000 --- a/x.py +++ /dev/null @@ -1,21 +0,0 @@ -import useq -from rich import print - -seq = useq.MDASequence( - channels=["DAPI", "FITC"], - stage_positions=[ - (1, 2, 3), - { - "x": 4, - "y": 5, - "z": 6, - "sequence": useq.MDASequence(grid_plan={"rows": 2, "columns": 1}), - }, - ], - time_plan={"interval": 0, "loops": 3}, - z_plan={"range": 2, "step": 0.7}, -) - -print("main", seq.sizes) -print("p0", seq.stage_positions[0].sequence) -print("p1", seq.stage_positions[1].sequence.sizes) From a7cb3e015d1bab2c816268be71ad4aaf267168c5 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 15:45:18 -0400 Subject: [PATCH 17/73] linting and cleanup --- .pre-commit-config.yaml | 2 +- examples/stack_viewer_numpy.py | 1 + pyproject.toml | 2 +- .../{_canvas => _backends}/__init__.py | 6 +- .../{_canvas => _backends}/_pygfx.py | 6 +- .../{_canvas => _backends}/_qt.py | 0 .../{_canvas => _backends}/_vispy.py | 2 +- .../_stack_viewer2/_canvas/gl.py | 94 -------------- .../_stack_viewer2/_dims_slider.py | 6 +- .../_stack_viewer2/_indexing.py | 5 +- .../_stack_viewer2/_lut_control.py | 4 +- .../_stack_viewer2/_mda_viewer.py | 4 +- .../_stack_viewer2/_protocols.py | 10 +- .../_stack_viewer2/_stack_viewer.py | 13 +- y.py | 35 ------ zz.py | 115 ------------------ 16 files changed, 30 insertions(+), 275 deletions(-) rename src/pymmcore_widgets/_stack_viewer2/{_canvas => _backends}/__init__.py (88%) rename src/pymmcore_widgets/_stack_viewer2/{_canvas => _backends}/_pygfx.py (96%) rename src/pymmcore_widgets/_stack_viewer2/{_canvas => _backends}/_qt.py (100%) rename src/pymmcore_widgets/_stack_viewer2/{_canvas => _backends}/_vispy.py (99%) delete mode 100644 src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py delete mode 100644 y.py delete mode 100644 zz.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0f838337..746da736f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,5 +27,5 @@ repos: - id: mypy files: "^src/" additional_dependencies: - - pymmcore-plus >=0.9.0 + - pymmcore-plus >=0.9.5 - useq-schema >=0.4.7 diff --git a/examples/stack_viewer_numpy.py b/examples/stack_viewer_numpy.py index 87b154a8b..0b283e5e7 100644 --- a/examples/stack_viewer_numpy.py +++ b/examples/stack_viewer_numpy.py @@ -11,6 +11,7 @@ def generate_5d_sine_wave( amplitude: float = 240, base_frequency: float = 5, ) -> np.ndarray: + """5D dataset.""" # Unpack the dimensions angle_dim, freq_dim, phase_dim, ny, nx = shape diff --git a/pyproject.toml b/pyproject.toml index 30e103986..2ce5d6e07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ classifiers = [ dynamic = ["version"] dependencies = [ 'fonticon-materialdesignicons6', - 'pymmcore-plus[cli] >=0.9.0', + 'pymmcore-plus[cli] >=0.9.5', 'qtpy >=2.0', 'superqt[quantity] >=0.5.3', 'useq-schema >=0.4.7', diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py b/src/pymmcore_widgets/_stack_viewer2/_backends/__init__.py similarity index 88% rename from src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py rename to src/pymmcore_widgets/_stack_viewer2/_backends/__init__.py index 0337b7772..7683e8cff 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/__init__.py +++ b/src/pymmcore_widgets/_stack_viewer2/_backends/__init__.py @@ -10,10 +10,8 @@ from pymmcore_widgets._stack_viewer2._protocols import PCanvas -ENV_BACKEND = os.getenv("CANVAS_BACKEND", None) - - -def get_canvas(backend: str | None = ENV_BACKEND) -> type[PCanvas]: +def get_canvas(backend: str | None = None) -> type[PCanvas]: + backend = backend or os.getenv("CANVAS_BACKEND", None) if backend == "vispy" or (backend is None and "vispy" in sys.modules): from ._vispy import VispyViewerCanvas diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py b/src/pymmcore_widgets/_stack_viewer2/_backends/_pygfx.py similarity index 96% rename from src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py rename to src/pymmcore_widgets/_stack_viewer2/_backends/_pygfx.py index 139237d05..bbd1eedab 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_pygfx.py +++ b/src/pymmcore_widgets/_stack_viewer2/_backends/_pygfx.py @@ -69,11 +69,7 @@ def sizeHint(self) -> QSize: class PyGFXViewerCanvas: - """Vispy-based viewer for data. - - All vispy-specific code is encapsulated in this class (and non-vispy canvases - could be swapped in if needed as long as they implement the same interface). - """ + """pygfx-based canvas wrapper.""" def __init__(self, set_info: Callable[[str], None]) -> None: self._set_info = set_info diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_qt.py b/src/pymmcore_widgets/_stack_viewer2/_backends/_qt.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer2/_canvas/_qt.py rename to src/pymmcore_widgets/_stack_viewer2/_backends/_qt.py diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py b/src/pymmcore_widgets/_stack_viewer2/_backends/_vispy.py similarity index 99% rename from src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py rename to src/pymmcore_widgets/_stack_viewer2/_backends/_vispy.py index 8c60b4fb9..d7eee7533 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer2/_backends/_vispy.py @@ -96,7 +96,7 @@ def set_range( self, x: tuple[float, float] | None = None, y: tuple[float, float] | None = None, - margin: float | None = 0.01, + margin: float = 0.01, ) -> None: """Update the range of the PanZoomCamera. diff --git a/src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py b/src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py deleted file mode 100644 index a98b5119b..000000000 --- a/src/pymmcore_widgets/_stack_viewer2/_canvas/gl.py +++ /dev/null @@ -1,94 +0,0 @@ -import sys -from itertools import cycle - -import numpy as np -from OpenGL.GL import * # noqa -from qtpy.QtCore import QTimer -from qtpy.QtWidgets import QApplication, QMainWindow, QOpenGLWidget - -shape = (1024, 1024) -images = cycle((np.random.rand(100, *shape, 3) * 255).astype(np.uint8)) - - -class GLWidget(QOpenGLWidget): - def __init__(self, parent=None) -> None: - super().__init__(parent) - self.image_data = next(images) - - def initializeGL(self) -> None: - glClearColor(0, 0, 0, 1) - glEnable(GL_TEXTURE_2D) - self.texture = glGenTextures(1) - glBindTexture(GL_TEXTURE_2D, self.texture) - glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST) - glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST) - # Set unpack alignment to 1 (important for images with width not multiple of 4) - glPixelStorei(GL_UNPACK_ALIGNMENT, 1) - - def paintGL(self): - glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) - glBindTexture(GL_TEXTURE_2D, self.texture) - glTexImage2D( - GL_TEXTURE_2D, - 0, - GL_RGB, - *shape, - 0, - GL_RGB, - GL_UNSIGNED_BYTE, - self.image_data, - ) - - # Calculate aspect ratio of the window - width = self.width() - height = self.height() - aspect_ratio = width / height - - # Adjust vertices to maintain 1:1 aspect ratio in the center of the viewport - if aspect_ratio > 1: - # Wider than tall: limit width to match height - scale = height / width - x0, x1 = -scale, scale - y0, y1 = -1, 1 - else: - # Taller than wide: limit height to match width - scale = width / height - x0, x1 = -1, 1 - y0, y1 = -scale, scale - - glBegin(GL_QUADS) - glTexCoord2f(0, 0) - glVertex2f(x0, y0) - glTexCoord2f(1, 0) - glVertex2f(x1, y0) - glTexCoord2f(1, 1) - glVertex2f(x1, y1) - glTexCoord2f(0, 1) - glVertex2f(x0, y1) - glEnd() - - def update_image(self, new_image: np.ndarray) -> None: - self.image_data = new_image - self.update() # Request a repaint - - -class MainWindow(QMainWindow): - def __init__(self) -> None: - super().__init__() - self.gl_widget = GLWidget(self) - self.setCentralWidget(self.gl_widget) - self.timer = QTimer() - self.timer.timeout.connect(self.on_timer) - self.timer.start(1) # Update image every 100 ms - - def on_timer(self) -> None: - # Generate a new random image - new_image = (next(images)).astype(np.uint8) - self.gl_widget.update_image(new_image) - - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = MainWindow() - window.show() - sys.exit(app.exec_()) diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py index 0ca5aff63..c195fd0d4 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py @@ -204,7 +204,7 @@ def resizeEvent(self, a0: QResizeEvent | None) -> None: def mouseDoubleClickEvent(self, a0: Any) -> None: self._set_slice_mode(not self._slice_mode) - return super().mouseDoubleClickEvent(a0) + super().mouseDoubleClickEvent(a0) def setMaximum(self, max_val: int) -> None: if max_val > self._int_slider.maximum(): @@ -218,7 +218,7 @@ def setRange(self, min_val: int, max_val: int) -> None: def value(self) -> Index: if not self._slice_mode: - return self._int_slider.value() + return self._int_slider.value() # type: ignore start, *_, stop = cast("tuple[int, ...]", self._slice_slider.value()) if start == stop: return start @@ -397,7 +397,7 @@ def resizeEvent(self, a0: QResizeEvent | None) -> None: for s in sliders: getattr(s, lbl).setFixedWidth(lbl_width) - return super().resizeEvent(a0) + super().resizeEvent(a0) if __name__ == "__main__": diff --git a/src/pymmcore_widgets/_stack_viewer2/_indexing.py b/src/pymmcore_widgets/_stack_viewer2/_indexing.py index 76f98d195..f3df119ea 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer2/_indexing.py @@ -2,7 +2,7 @@ import sys import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np @@ -66,7 +66,7 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: if is_pymmcore_writer(store): return isel_mmcore_5dbase(store, indexers) if is_xarray_dataarray(store): - return store.isel(indexers).to_numpy() + return cast("np.ndarray", store.isel(indexers).to_numpy()) if is_duck_array(store): return isel_np_array(store, indexers) raise NotImplementedError(f"Don't know how to index into type {type(store)}") @@ -82,6 +82,7 @@ def isel_mmcore_5dbase(writer: _5DWriterBase, indexers: Indices) -> np.ndarray: if isinstance(p_index, slice): warnings.warn("Cannot slice over position index", stacklevel=2) # TODO p_index = p_index.start + p_index = cast(int, p_index) try: sizes = [*list(writer.position_sizes[p_index]), "y", "x"] diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py index f03260a3f..fc1a78535 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer2/_lut_control.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Iterable, cast import numpy as np from qtpy.QtCore import Qt @@ -77,7 +77,7 @@ def __init__( layout.addWidget(self._auto_clim) def autoscaleChecked(self) -> bool: - return self._auto_clim.isChecked() + return cast("bool", self._auto_clim.isChecked()) def _on_clims_changed(self, clims: tuple[float, float]) -> None: self._auto_clim.setChecked(False) diff --git a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py index 9d001db8f..3d182a580 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py @@ -31,6 +31,6 @@ def __init__(self, *, parent: QWidget | None = None): self._data.frame_ready.connect(self.on_frame_ready) self.dims_sliders.set_locks_visible(True) - @superqt.ensure_main_thread + @superqt.ensure_main_thread # type: ignore def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: - self.setIndex(event.index) + self.setIndex(event.index) # type: ignore diff --git a/src/pymmcore_widgets/_stack_viewer2/_protocols.py b/src/pymmcore_widgets/_stack_viewer2/_protocols.py index 8271431fe..8b8d5d67a 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_protocols.py +++ b/src/pymmcore_widgets/_stack_viewer2/_protocols.py @@ -29,15 +29,15 @@ def remove(self) -> None: ... class PCanvas(Protocol): - def __init__(self, set_text: Callable[[str], Any]) -> None: ... + def __init__(self, set_info: Callable[[str], None]) -> None: ... def set_range( self, - x: tuple[float, float] | None = None, - y: tuple[float, float] | None = None, - margin: float | None = 0.05, + x: tuple[float, float] | None = ..., + y: tuple[float, float] | None = ..., + margin: float = ..., ) -> None: ... def refresh(self) -> None: ... def qwidget(self) -> QWidget: ... def add_image( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... ) -> PImageHandle: ... diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 4bea69548..7ddf4a92a 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections import defaultdict from enum import Enum from itertools import cycle @@ -10,7 +11,7 @@ from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget from superqt import QCollapsible, QElidingLabel, QIconifyIcon -from ._canvas import get_canvas +from ._backends import get_canvas from ._dims_slider import DimsSliders from ._indexing import is_xarray_dataarray, isel from ._lut_control import LutControl @@ -274,7 +275,7 @@ def _guess_channel_axis(self, data: Any) -> DimKey: if is_xarray_dataarray(data): for d in data.dims: if str(d).lower() in ("channel", "ch", "c"): - return d + return cast("DimKey", d) return 0 def _clear_images(self) -> None: @@ -316,9 +317,11 @@ def _on_dims_sliders_changed(self, index: Indices) -> None: indices: list[Indices] = [index] if self._channel_mode == ChannelMode.COMPOSITE: for i, handles in self._img_handles.items(): - if handles and c != i: - # FIXME: type error is legit - indices.append({**index, self._channel_axis: i}) + if isinstance(i, (int, slice)): + if handles and c != i: + indices.append({**index, self._channel_axis: i}) + else: # pragma: no cover + warnings.warn(f"Invalid key for composite image: {i}", stacklevel=2) for idx in indices: self._update_data_for_index(idx) diff --git a/y.py b/y.py deleted file mode 100644 index 3787a15bb..000000000 --- a/y.py +++ /dev/null @@ -1,35 +0,0 @@ -import numpy as np -import xarray as xr -import zarr -from zarr.storage import KVStore - -# Create an in-memory store -memory_store = KVStore({}) - -# Create some data -data = np.random.randn(3, 2, 512, 512) # Shape corresponding to (t, c, y, x) - -# Create a Zarr group in the memory store -root = zarr.group(store=memory_store, overwrite=True) - -# Add dimensions and coordinates -t = np.array([0, 1, 2]) # Time coordinates -c = np.array(["DAPI", "FITC"]) # Channel labels - -# Create the dataset within the group -dset = root.create_dataset("data", data=data, chunks=(1, 1, 256, 256), dtype="float32") - -# Add attributes for xarray compatibility -dset.attrs["_ARRAY_DIMENSIONS"] = ["t", "c", "y", "x"] - -# Create coordinate datasets -root["t"] = t -# root['c'] = c -root["t"].attrs["_ARRAY_DIMENSIONS"] = ["t"] -# root['c'].attrs['_ARRAY_DIMENSIONS'] = ['c'] - -# Open the Zarr group with xarray directly using the in-memory store -ds = xr.open_zarr(memory_store, consolidated=False) - -# Print the xarray dataset -print(ds["data"]) diff --git a/zz.py b/zz.py deleted file mode 100644 index fde0470c7..000000000 --- a/zz.py +++ /dev/null @@ -1,115 +0,0 @@ -import sys -from itertools import cycle -from typing import Any - -import numpy as np -from qtpy.QtCore import Qt, QTimer -from qtpy.QtGui import QImage, QPixmap -from qtpy.QtWidgets import ( - QApplication, - QGraphicsPixmapItem, - QGraphicsScene, - QGraphicsView, - QVBoxLayout, - QWidget, -) - -shape = (512, 512) -images = cycle((np.random.rand(100, *shape) * 255).astype(np.uint8)) - - -def np2qimg(data: np.ndarray) -> QImage: - if np.ndim(data) == 2: - data = data[..., None] - elif np.ndim(data) != 3: - raise ValueError("data must be 2D or 3D") - if data.shape[-1] not in (1, 2, 3, 4): - raise ValueError( - "Last dimension must contain one (scalar/gray), two (gray+alpha), " - "three (R,G,B), or four (R,G,B,A) channels" - ) - h, w, nc = data.shape - np_dtype = data.dtype - hasAlpha = nc in (2, 4) - isRGB = nc in (3, 4) - if np_dtype == np.uint8: - if hasAlpha: - fmt = QImage.Format.Format_RGBA8888 - elif isRGB: - fmt = QImage.Format.Format_RGB888 - else: - fmt = QImage.Format.Format_Grayscale8 - elif np_dtype == np.uint16: - if hasAlpha: - fmt = QImage.Format.Format_RGBA64 - elif isRGB: - fmt = QImage.Format.Format_RGB16 - else: - fmt = QImage.Format.Format_Grayscale16 - elif np_dtype == np.float32: - if hasAlpha: - fmt = QImage.Format.Format_RGBA32FPx4 - elif isRGB: - fmt = QImage.Format.Format_RGBA16FPx4 - else: - dmin = data.min() - data = ((data - dmin) / (data.max() - dmin) * 255).astype(np.uint8) - fmt = QImage.Format.Format_Grayscale8 - qimage = QImage(data, w, h, fmt) - return qimage - - -class ImageWindow(QWidget): - def __init__(self) -> None: - super().__init__() - - # Create a QGraphicsScene which holds the graphics items - self.scene = QGraphicsScene() - - # Create a QGraphicsView which provides a widget for displaying the contents of a QGraphicsScene - self.view = QGraphicsView(self.scene, self) - self.view.setBackgroundBrush(Qt.GlobalColor.black) - - # make baground of this widget black - self.setStyleSheet("background-color: black;") - - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self.view) - - # Create a QImage from random data - self.add_image() - - # Use a timer to update the image - self.timer = QTimer() - self.timer.timeout.connect(self.update_image) - self.timer.start(10) - - def add_image(self) -> None: - self.image_data = next(images) - qimage = np2qimg(self.image_data) - - # Convert QImage to QPixmap and add it to the scene using QGraphicsPixmapItem - self.pixmap_item = QGraphicsPixmapItem(QPixmap.fromImage(qimage)) - self.scene.addItem(self.pixmap_item) - - def resizeEvent(self, event: Any) -> None: - self.fitInView() - - def fitInView(self) -> None: - # Scale view to fit the pixmap preserving the aspect ratio - if not self.pixmap_item.pixmap().isNull(): - self.view.fitInView(self.pixmap_item, Qt.AspectRatioMode.KeepAspectRatio) - - def update_image(self) -> None: - # Update the image with new random data - self.image_data = next(images) - qimage = np2qimg(self.image_data) - self.pixmap_item.setPixmap(QPixmap.fromImage(qimage)) - - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = ImageWindow() - window.show() - sys.exit(app.exec()) From 2215d50d3deac81c6773e60e446d51376626205d Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 16:16:22 -0400 Subject: [PATCH 18/73] save btn --- .../_stack_viewer2/_mda_viewer.py | 3 ++ .../_stack_viewer2/_save_button.py | 50 +++++++++++++++++++ .../_stack_viewer2/_stack_viewer.py | 6 ++- 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 src/pymmcore_widgets/_stack_viewer2/_save_button.py diff --git a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py index 3d182a580..7999dea7b 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py @@ -7,6 +7,7 @@ from psygnal import Signal as psygnalSignal from pymmcore_plus.mda.handlers import OMEZarrWriter +from ._save_button import SaveButton from ._stack_viewer import StackViewer if TYPE_CHECKING: @@ -28,6 +29,8 @@ class MDAViewer(StackViewer): def __init__(self, *, parent: QWidget | None = None): super().__init__(DataStore(), parent=parent, channel_axis="c") + self._save_btn = SaveButton(self.data) + self._btns.addWidget(self._save_btn) self._data.frame_ready.connect(self.on_frame_ready) self.dims_sliders.set_locks_visible(True) diff --git a/src/pymmcore_widgets/_stack_viewer2/_save_button.py b/src/pymmcore_widgets/_stack_viewer2/_save_button.py new file mode 100644 index 000000000..c526ab258 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer2/_save_button.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np +from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget +from superqt.iconify import QIconifyIcon + +from ._indexing import is_xarray_dataarray + + +class SaveButton(QPushButton): + def __init__( + self, + datastore: Any, + parent: QWidget | None = None, + ): + super().__init__(parent=parent) + self.setIcon(QIconifyIcon("mdi:content-save")) + self.clicked.connect(self._on_click) + + self._data = datastore + self._last_loc = str(Path.home()) + + def _on_click(self) -> None: + self._last_loc, _ = QFileDialog.getSaveFileName( + self, "Choose destination", str(self._last_loc), "" + ) + suffix = Path(self._last_loc).suffix + if suffix in (".zarr", ".ome.zarr", ""): + _save_as_zarr(self._last_loc, self._data) + else: + raise ValueError(f"Unsupported file format: {self._last_loc}") + + +def _save_as_zarr(save_loc: str | Path, data: Any) -> None: + import zarr + from pymmcore_plus.mda.handlers import OMEZarrWriter + + if isinstance(data, OMEZarrWriter): + zarr.copy_store(data.group.store, zarr.DirectoryStore(save_loc)) + elif isinstance(data, zarr.Array): + data.store = zarr.DirectoryStore(save_loc) + elif isinstance(data, np.ndarray): + zarr.save(str(save_loc), data) + elif is_xarray_dataarray(data): + data.to_zarr(save_loc) + else: + raise ValueError(f"Cannot save data of type {type(data)} to Zarr format.") diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py index 7ddf4a92a..9410fde70 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py @@ -99,7 +99,9 @@ def __init__( self._channel_mode_btn = ChannelModeButton() self._channel_mode_btn.clicked.connect(self.set_channel_mode) # button to reset the zoom of the canvas - self._set_range_btn = QPushButton("reset zoom") + self._set_range_btn = QPushButton( + QIconifyIcon("fluent:full-screen-maximize-24-filled"), "" + ) self._set_range_btn.clicked.connect(self._on_set_range_clicked) # place to display dataset summary @@ -127,7 +129,7 @@ def __init__( # LAYOUT ----------------------------------------------------- - btns = QHBoxLayout() + self._btns = btns = QHBoxLayout() btns.setContentsMargins(0, 0, 0, 0) btns.setSpacing(0) btns.addStretch() From 417274144b04f39258678c1107cfb692cfe805a7 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 16:16:52 -0400 Subject: [PATCH 19/73] remove qt backend --- .../_stack_viewer2/_backends/_qt.py | 243 ------------------ 1 file changed, 243 deletions(-) delete mode 100644 src/pymmcore_widgets/_stack_viewer2/_backends/_qt.py diff --git a/src/pymmcore_widgets/_stack_viewer2/_backends/_qt.py b/src/pymmcore_widgets/_stack_viewer2/_backends/_qt.py deleted file mode 100644 index e72ebcd36..000000000 --- a/src/pymmcore_widgets/_stack_viewer2/_backends/_qt.py +++ /dev/null @@ -1,243 +0,0 @@ -from __future__ import annotations - -import sys -from typing import Any, Callable - -import cmap -import numpy as np -from qtpy.QtCore import Qt, QTimer -from qtpy.QtGui import QImage, QPixmap -from qtpy.QtWidgets import ( - QApplication, - QGraphicsPixmapItem, - QGraphicsScene, - QGraphicsView, - QVBoxLayout, - QWidget, -) - -_FORMATS: dict[tuple[np.dtype, int], QImage.Format] = { - (np.dtype(np.uint8), 1): QImage.Format.Format_Grayscale8, - (np.dtype(np.uint8), 3): QImage.Format.Format_RGB888, - (np.dtype(np.uint8), 4): QImage.Format.Format_RGBA8888, - (np.dtype(np.uint16), 1): QImage.Format.Format_Grayscale16, - (np.dtype(np.uint16), 3): QImage.Format.Format_RGB16, - (np.dtype(np.uint16), 4): QImage.Format.Format_RGBA64, - (np.dtype(np.float32), 1): QImage.Format.Format_Grayscale8, - (np.dtype(np.float32), 3): QImage.Format.Format_RGBA16FPx4, - (np.dtype(np.float32), 4): QImage.Format.Format_RGBA32FPx4, -} - - -def _normalize255( - array: np.ndarray, - normalize: tuple[bool, bool] | bool, - clip: tuple[int, int] = (0, 255), -) -> np.ndarray: - # by default, we do not want to clip in-place - # (the input array should not be modified): - clip_target = None - - if normalize: - if normalize is True: - if array.dtype == bool: - normalize = (False, True) - else: - normalize = array.min(), array.max() - if clip == (0, 255): - clip = None - elif np.isscalar(normalize): - normalize = (0, normalize) - - nmin, nmax = normalize - - if nmin: - array = array - nmin - clip_target = array - - if nmax != nmin: - if array.dtype == bool: - scale = 255.0 - else: - scale = 255.0 / (nmax - nmin) - - if scale != 1.0: - array = array * scale - clip_target = array - - if clip: - low, high = clip - array = np.clip(array, low, high, clip_target) - - return array - - -def np2qimg(data: np.ndarray) -> QImage: - if np.ndim(data) == 2: - data = data[..., None] - elif np.ndim(data) != 3: - raise ValueError("data must be 2D or 3D") - if data.shape[-1] not in (1, 3, 4): - raise ValueError( - "Last dimension must contain one (scalar/gray), " - "three (R,G,B), or four (R,G,B,A) channels" - ) - h, w, nc = data.shape - - fmt = _FORMATS.get((data.dtype, data.shape[-1])) - if fmt is None: - raise ValueError(f"Unsupported data type {data.dtype} with {nc} channels") - - if data.dtype == np.float32 and data.shape[-1] == 1: - dmin = data.min() - data = ((data - dmin) / (data.max() - dmin) * 255).astype(np.uint8) - fmt = QImage.Format.Format_Grayscale8 - print(data.shape, w, h, fmt, data.min(), data.max()) - qimage = QImage(data, w, h, fmt) - return qimage - - -class QtImageHandle: - def __init__(self, item: QGraphicsPixmapItem, data: np.ndarray) -> None: - self._data = data - self._item = item - - @property - def data(self) -> np.ndarray: - return self._data - - @data.setter - def data(self, data: np.ndarray) -> None: - self._data = data.squeeze() - self._item.setPixmap(QPixmap.fromImage(np2qimg(self._data))) - - @property - def visible(self) -> bool: - return self._item.isVisible() - - @visible.setter - def visible(self, visible: bool) -> None: - self._item.setVisible(visible) - - @property - def clim(self) -> Any: - return (0, 255) - - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: - pass - - @property - def cmap(self) -> cmap.Colormap: - return cmap.Colormap("viridis") - - @cmap.setter - def cmap(self, cmap: cmap.Colormap) -> None: - pass - - def remove(self) -> None: - """Remove the image from the scene.""" - if scene := self._item.scene(): - scene.removeItem(self._item) - - -class QtViewerCanvas(QWidget): - """Vispy-based viewer for data. - - All vispy-specific code is encapsulated in this class (and non-vispy canvases - could be swapped in if needed as long as they implement the same interface). - """ - - def __init__(self, set_info: Callable[[str], None]) -> None: - super().__init__() - - # Create a QGraphicsScene which holds the graphics items - self.scene = QGraphicsScene() - self.view = QGraphicsView(self.scene, self) - self.view.setBackgroundBrush(Qt.GlobalColor.black) - - # make baground of this widget black - self.setStyleSheet("background-color: black;") - - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self.view) - - def qwidget(self) -> QWidget: - return self - - def refresh(self) -> None: - """Refresh the canvas.""" - self.update() - - def add_image( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> QtImageHandle: - """Add a new Image node to the scene.""" - item = QGraphicsPixmapItem(QPixmap.fromImage(np2qimg(data))) - self.scene.addItem(item) - return QtImageHandle(item, data) - - def set_range( - self, - x: tuple[float, float] | None = None, - y: tuple[float, float] | None = None, - margin: float | None = 0.01, - ) -> None: - """Update the range of the PanZoomCamera. - - When called with no arguments, the range is set to the full extent of the data. - """ - - -class ImageWindow(QWidget): - def __init__(self) -> None: - super().__init__() - - # Create a QGraphicsScene which holds the graphics items - self.scene = QGraphicsScene() - - # Create a QGraphicsView which provides a widget for displaying the contents of a QGraphicsScene - self.view = QGraphicsView(self.scene, self) - self.view.setBackgroundBrush(Qt.GlobalColor.black) - - # make baground of this widget black - self.setStyleSheet("background-color: black;") - - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self.view) - - # Create a QImage from random data - self.image_data = next(images) - qimage = QImage(self.image_data, *shape, QImage.Format.Format_RGB888) - - # Convert QImage to QPixmap and add it to the scene using QGraphicsPixmapItem - self.pixmap_item = QGraphicsPixmapItem(QPixmap.fromImage(qimage)) - self.scene.addItem(self.pixmap_item) - - # Use a timer to update the image - self.timer = QTimer() - self.timer.timeout.connect(self.update_image) - self.timer.start(10) - - def resizeEvent(self, event: Any) -> None: - self.fitInView() - - def fitInView(self) -> None: - # Scale view to fit the pixmap preserving the aspect ratio - if not self.pixmap_item.pixmap().isNull(): - self.view.fitInView(self.pixmap_item, Qt.AspectRatioMode.KeepAspectRatio) - - def update_image(self) -> None: - # Update the image with new random data - self.image_data = next(images) - qimage = QImage(self.image_data, *shape, QImage.Format.Format_RGB888) - self.pixmap_item.setPixmap(QPixmap.fromImage(qimage)) - - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = ImageWindow() - window.show() - sys.exit(app.exec()) From 63e61d9c12bbcb1acc9dc4730a2694bc03b1603e Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 16:31:41 -0400 Subject: [PATCH 20/73] move and rename --- examples/stack_viewer2.py | 2 +- examples/stack_viewer_numpy.py | 2 +- examples/stack_viewer_xr.py | 2 +- .../_stack_viewer/__init__.py | 5 +- .../_backends/__init__.py | 2 +- .../_backends/_pygfx.py | 0 .../_backends/_vispy.py | 0 .../_dims_slider.py | 0 .../_indexing.py | 0 .../_lut_control.py | 0 .../_stack_viewer/_mda_viewer.py | 49 + .../_protocols.py | 0 .../_stack_viewer/_save_button.py | 83 +- .../_stack_viewer/_stack_viewer.py | 838 ++++++++---------- .../_stack_viewer2/__init__.py | 0 .../_stack_viewer2/_mda_viewer.py | 39 - .../_stack_viewer2/_save_button.py | 50 -- .../_stack_viewer2/_stack_viewer.py | 399 --------- .../_stack_viewer_v1/__init__.py | 5 + .../_channel_row.py | 0 .../_datastore.py | 0 .../_labeled_slider.py | 0 .../_stack_viewer_v1/_save_button.py | 69 ++ .../_stack_viewer_v1/_stack_viewer.py | 489 ++++++++++ src/pymmcore_widgets/experimental.py | 5 +- tests/test_datastore.py | 26 - ..._stack_viewer.py => test_stack_viewer1.py} | 25 +- 27 files changed, 1051 insertions(+), 1039 deletions(-) rename src/pymmcore_widgets/{_stack_viewer2 => _stack_viewer}/_backends/__init__.py (93%) rename src/pymmcore_widgets/{_stack_viewer2 => _stack_viewer}/_backends/_pygfx.py (100%) rename src/pymmcore_widgets/{_stack_viewer2 => _stack_viewer}/_backends/_vispy.py (100%) rename src/pymmcore_widgets/{_stack_viewer2 => _stack_viewer}/_dims_slider.py (100%) rename src/pymmcore_widgets/{_stack_viewer2 => _stack_viewer}/_indexing.py (100%) rename src/pymmcore_widgets/{_stack_viewer2 => _stack_viewer}/_lut_control.py (100%) create mode 100644 src/pymmcore_widgets/_stack_viewer/_mda_viewer.py rename src/pymmcore_widgets/{_stack_viewer2 => _stack_viewer}/_protocols.py (100%) delete mode 100644 src/pymmcore_widgets/_stack_viewer2/__init__.py delete mode 100644 src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py delete mode 100644 src/pymmcore_widgets/_stack_viewer2/_save_button.py delete mode 100644 src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py create mode 100644 src/pymmcore_widgets/_stack_viewer_v1/__init__.py rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v1}/_channel_row.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v1}/_datastore.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v1}/_labeled_slider.py (100%) create mode 100644 src/pymmcore_widgets/_stack_viewer_v1/_save_button.py create mode 100644 src/pymmcore_widgets/_stack_viewer_v1/_stack_viewer.py delete mode 100644 tests/test_datastore.py rename tests/{test_stack_viewer.py => test_stack_viewer1.py} (86%) diff --git a/examples/stack_viewer2.py b/examples/stack_viewer2.py index 0027caa63..27d9020eb 100644 --- a/examples/stack_viewer2.py +++ b/examples/stack_viewer2.py @@ -4,7 +4,7 @@ from qtpy import QtWidgets from useq import MDASequence -from pymmcore_widgets._stack_viewer2._mda_viewer import MDAViewer +from pymmcore_widgets._stack_viewer._mda_viewer import MDAViewer configure_logging(stderr_level="WARNING") diff --git a/examples/stack_viewer_numpy.py b/examples/stack_viewer_numpy.py index 0b283e5e7..e49ed2163 100644 --- a/examples/stack_viewer_numpy.py +++ b/examples/stack_viewer_numpy.py @@ -3,7 +3,7 @@ import numpy as np from qtpy import QtWidgets -from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer +from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer def generate_5d_sine_wave( diff --git a/examples/stack_viewer_xr.py b/examples/stack_viewer_xr.py index f47922876..899a2d6c7 100644 --- a/examples/stack_viewer_xr.py +++ b/examples/stack_viewer_xr.py @@ -4,7 +4,7 @@ import nd2 from qtpy import QtWidgets -from pymmcore_widgets._stack_viewer2._stack_viewer import StackViewer +from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer data = nd2.imread("/Users/talley/Downloads/6D_test.nd2", xarray=True, dask=True) qapp = QtWidgets.QApplication([]) diff --git a/src/pymmcore_widgets/_stack_viewer/__init__.py b/src/pymmcore_widgets/_stack_viewer/__init__.py index 2c4beb6a5..d144dff42 100644 --- a/src/pymmcore_widgets/_stack_viewer/__init__.py +++ b/src/pymmcore_widgets/_stack_viewer/__init__.py @@ -1,5 +1,4 @@ -from ._channel_row import CMAPS -from ._datastore import QOMEZarrDatastore +from ._mda_viewer import MDAViewer from ._stack_viewer import StackViewer -__all__ = ["StackViewer", "CMAPS", "QOMEZarrDatastore"] +__all__ = ["StackViewer", "MDAViewer"] diff --git a/src/pymmcore_widgets/_stack_viewer2/_backends/__init__.py b/src/pymmcore_widgets/_stack_viewer/_backends/__init__.py similarity index 93% rename from src/pymmcore_widgets/_stack_viewer2/_backends/__init__.py rename to src/pymmcore_widgets/_stack_viewer/_backends/__init__.py index 7683e8cff..045d85fd1 100644 --- a/src/pymmcore_widgets/_stack_viewer2/_backends/__init__.py +++ b/src/pymmcore_widgets/_stack_viewer/_backends/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from pymmcore_widgets._stack_viewer2._protocols import PCanvas + from pymmcore_widgets._stack_viewer._protocols import PCanvas def get_canvas(backend: str | None = None) -> type[PCanvas]: diff --git a/src/pymmcore_widgets/_stack_viewer2/_backends/_pygfx.py b/src/pymmcore_widgets/_stack_viewer/_backends/_pygfx.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer2/_backends/_pygfx.py rename to src/pymmcore_widgets/_stack_viewer/_backends/_pygfx.py diff --git a/src/pymmcore_widgets/_stack_viewer2/_backends/_vispy.py b/src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer2/_backends/_vispy.py rename to src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py diff --git a/src/pymmcore_widgets/_stack_viewer2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer/_dims_slider.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer2/_dims_slider.py rename to src/pymmcore_widgets/_stack_viewer/_dims_slider.py diff --git a/src/pymmcore_widgets/_stack_viewer2/_indexing.py b/src/pymmcore_widgets/_stack_viewer/_indexing.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer2/_indexing.py rename to src/pymmcore_widgets/_stack_viewer/_indexing.py diff --git a/src/pymmcore_widgets/_stack_viewer2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer/_lut_control.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer2/_lut_control.py rename to src/pymmcore_widgets/_stack_viewer/_lut_control.py diff --git a/src/pymmcore_widgets/_stack_viewer/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer/_mda_viewer.py new file mode 100644 index 000000000..5d73e84be --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer/_mda_viewer.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any + +import superqt +import useq + +from ._save_button import SaveButton +from ._stack_viewer import StackViewer + +if TYPE_CHECKING: + from qtpy.QtWidgets import QWidget + + +class MDAViewer(StackViewer): + """StackViewer specialized for pymmcore-plus MDA acquisitions.""" + + def __init__(self, datastore: Any = None, *, parent: QWidget | None = None): + if datastore is None: + from pymmcore_plus.mda.handlers import OMEZarrWriter + + datastore = OMEZarrWriter() + + # patch the frameReady method to call the superframeReady method + # AFTER handling the event + self._superframeReady = getattr(datastore, "frameReady", None) + if callable(self._superframeReady): + datastore.frameReady = self._patched_frame_ready + else: # pragma: no cover + warnings.warn( + "MDAViewer: datastore does not have a frameReady method to patch, " + "are you sure this is a valid data handler?", + stacklevel=2, + ) + + super().__init__(datastore, parent=parent, channel_axis="c") + self._save_btn = SaveButton(self.data) + self._btns.addWidget(self._save_btn) + self.dims_sliders.set_locks_visible(True) + + def _patched_frame_ready(self, *args: Any) -> None: + self._superframeReady(*args) # type: ignore + if len(args) >= 2 and isinstance(e := args[1], useq.MDAEvent): + self._on_frame_ready(e) + + @superqt.ensure_main_thread # type: ignore + def _on_frame_ready(self, event: useq.MDAEvent) -> None: + self.setIndex(event.index) # type: ignore diff --git a/src/pymmcore_widgets/_stack_viewer2/_protocols.py b/src/pymmcore_widgets/_stack_viewer/_protocols.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer2/_protocols.py rename to src/pymmcore_widgets/_stack_viewer/_protocols.py diff --git a/src/pymmcore_widgets/_stack_viewer/_save_button.py b/src/pymmcore_widgets/_stack_viewer/_save_button.py index ce686d3bf..c526ab258 100644 --- a/src/pymmcore_widgets/_stack_viewer/_save_button.py +++ b/src/pymmcore_widgets/_stack_viewer/_save_button.py @@ -1,69 +1,50 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING +from typing import Any -import zarr -from fonticon_mdi6 import MDI6 -from qtpy.QtCore import QSize +import numpy as np from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget -from superqt import fonticon +from superqt.iconify import QIconifyIcon -from ._datastore import QOMEZarrDatastore - -if TYPE_CHECKING: - from qtpy.QtGui import QCloseEvent +from ._indexing import is_xarray_dataarray class SaveButton(QPushButton): def __init__( self, - datastore: QOMEZarrDatastore, + datastore: Any, parent: QWidget | None = None, ): super().__init__(parent=parent) - # self.setFont(QFont('Arial', 50)) - # self.setMinimumHeight(30) - self.setIcon(fonticon.icon(MDI6.content_save_outline, color="gray")) - self.setIconSize(QSize(25, 25)) - self.setFixedSize(30, 30) + self.setIcon(QIconifyIcon("mdi:content-save")) self.clicked.connect(self._on_click) - self.datastore = datastore - self.save_loc = Path.home() + self._data = datastore + self._last_loc = str(Path.home()) def _on_click(self) -> None: - self.save_loc, _ = QFileDialog.getSaveFileName(directory=str(self.save_loc)) - if self.save_loc: - self._save_as_zarr(self.save_loc) - - def _save_as_zarr(self, save_loc: str | Path) -> None: - dir_store = zarr.DirectoryStore(save_loc) - zarr.copy_store(self.datastore._group.attrs.store, dir_store) - - def closeEvent(self, a0: QCloseEvent | None) -> None: - super().closeEvent(a0) - - -if __name__ == "__main__": - from pymmcore_plus import CMMCorePlus - from qtpy.QtWidgets import QApplication - from useq import MDASequence - - mmc = CMMCorePlus() - mmc.loadSystemConfiguration() - - app = QApplication([]) - seq = MDASequence( - time_plan={"interval": 0.01, "loops": 10}, - z_plan={"range": 5, "step": 1}, - channels=[{"config": "DAPI", "exposure": 1}, {"config": "FITC", "exposure": 1}], - ) - datastore = QOMEZarrDatastore() - mmc.mda.events.sequenceStarted.connect(datastore.sequenceStarted) - mmc.mda.events.frameReady.connect(datastore.frameReady) - - widget = SaveButton(datastore) - mmc.run_mda(seq) - widget.show() - app.exec_() + self._last_loc, _ = QFileDialog.getSaveFileName( + self, "Choose destination", str(self._last_loc), "" + ) + suffix = Path(self._last_loc).suffix + if suffix in (".zarr", ".ome.zarr", ""): + _save_as_zarr(self._last_loc, self._data) + else: + raise ValueError(f"Unsupported file format: {self._last_loc}") + + +def _save_as_zarr(save_loc: str | Path, data: Any) -> None: + import zarr + from pymmcore_plus.mda.handlers import OMEZarrWriter + + if isinstance(data, OMEZarrWriter): + zarr.copy_store(data.group.store, zarr.DirectoryStore(save_loc)) + elif isinstance(data, zarr.Array): + data.store = zarr.DirectoryStore(save_loc) + elif isinstance(data, np.ndarray): + zarr.save(str(save_loc), data) + elif is_xarray_dataarray(data): + data.to_zarr(save_loc) + else: + raise ValueError(f"Cannot save data of type {type(data)} to Zarr format.") diff --git a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py index 9b9211736..9410fde70 100644 --- a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py @@ -1,489 +1,399 @@ from __future__ import annotations -import copy import warnings -from typing import TYPE_CHECKING, cast +from collections import defaultdict +from enum import Enum +from itertools import cycle +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast +import cmap import numpy as np -import superqt -from fonticon_mdi6 import MDI6 -from qtpy import QtCore, QtWidgets -from qtpy.QtCore import QTimer -from superqt import fonticon -from useq import MDAEvent, MDASequence, _channel - -from ._channel_row import ChannelRow, try_cast_colormap -from ._datastore import QOMEZarrDatastore -from ._labeled_slider import LabeledVisibilitySlider -from ._save_button import SaveButton - -DIMENSIONS = ["t", "z", "c", "p", "g"] -AUTOCLIM_RATE = 1 # Hz 0 = inf - -try: - from vispy import scene - from vispy.visuals.transforms import MatrixTransform -except ImportError as e: - raise ImportError( - "vispy is required for StackViewer. " - "Please run `pip install pymmcore-widgets[image]`" - ) from e +from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from superqt import QCollapsible, QElidingLabel, QIconifyIcon + +from ._backends import get_canvas +from ._dims_slider import DimsSliders +from ._indexing import is_xarray_dataarray, isel +from ._lut_control import LutControl if TYPE_CHECKING: - import cmap - from pymmcore_plus import CMMCorePlus - from qtpy.QtCore import QCloseEvent - from qtpy.QtWidgets import QWidget - from vispy.scene.events import SceneMouseEvent + from typing import Any, Callable, Hashable, TypeAlias + + from ._dims_slider import DimKey, Indices, Sizes + from ._protocols import PCanvas, PImageHandle + + ImgKey: TypeAlias = Hashable + # any mapping of dimensions to sizes + SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] + + +GRAYS = cmap.Colormap("gray") +COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] + + +class ChannelMode(str, Enum): + COMPOSITE = "composite" + MONO = "mono" + + def __str__(self) -> str: + return self.value + + +class ChannelModeButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self.setCheckable(True) + self.toggled.connect(self.next_mode) + def next_mode(self) -> None: + if self.isChecked(): + self.setMode(ChannelMode.MONO) + else: + self.setMode(ChannelMode.COMPOSITE) + + def mode(self) -> ChannelMode: + return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE + + def setMode(self, mode: ChannelMode) -> None: + # we show the name of the next mode, not the current one + other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO + self.setText(str(other)) + self.setChecked(mode == ChannelMode.MONO) -class StackViewer(QtWidgets.QWidget): - """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events. - Parameters - ---------- - transform: (int, bool, bool) rotation mirror_x mirror_y. - """ +class StackViewer(QWidget): + """A viewer for AND arrays.""" def __init__( self, - datastore: QOMEZarrDatastore | None = None, - sequence: MDASequence | None = None, - mmcore: CMMCorePlus | None = None, + data: Any, + *, parent: QWidget | None = None, - size: tuple[int, int] | None = None, - transform: tuple[int, bool, bool] = (0, True, False), - save_button: bool = True, + channel_axis: DimKey | None = None, + channel_mode: ChannelMode = ChannelMode.MONO, ): super().__init__(parent=parent) - self._reload_position() - self.sequence = sequence - self.canvas_size = size - self.transform = transform - self._mmc = mmcore - self._clim = "auto" - self.cmaps = [ - cm for x in self.cmap_names if (cm := try_cast_colormap(x)) is not None - ] - self.display_index = {dim: 0 for dim in DIMENSIONS} - - self.main_layout = QtWidgets.QVBoxLayout() - self.setLayout(self.main_layout) - self.construct_canvas() - self.main_layout.addWidget(self._canvas.native) - - self.info_bar = QtWidgets.QLabel() - self.info_bar.setSizePolicy( - QtWidgets.QSizePolicy.Policy.Fixed, QtWidgets.QSizePolicy.Policy.Fixed + + # ATTRIBUTES ---------------------------------------------------- + + # dimensions of the data in the datastore + self._sizes: Sizes = {} + # mapping of key to a list of objects that control image nodes in the canvas + self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) + # mapping of same keys to the LutControl objects control image display props + self._lut_ctrls: dict[ImgKey, LutControl] = {} + # the set of dimensions we are currently visualizing (e.g. XY) + # this is used to control which dimensions have sliders and the behavior + # of isel when selecting data from the datastore + self._visualized_dims: set[DimKey] = set() + # the axis that represents the channels in the data + self._channel_axis = channel_axis + self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode + # colormaps that will be cycled through when displaying composite images + # TODO: allow user to set this + self._cmaps = cycle(COLORMAPS) + + # WIDGETS ---------------------------------------------------- + + # the button that controls the display mode of the channels + self._channel_mode_btn = ChannelModeButton() + self._channel_mode_btn.clicked.connect(self.set_channel_mode) + # button to reset the zoom of the canvas + self._set_range_btn = QPushButton( + QIconifyIcon("fluent:full-screen-maximize-24-filled"), "" ) - self.main_layout.addWidget(self.info_bar) - - self._create_sliders(sequence) - - self.datastore = datastore or QOMEZarrDatastore() - self.datastore.frame_ready.connect(self.frameReady) - if not datastore: - if self._mmc: - self._mmc.mda.events.frameReady.connect(self.datastore.frameReady) - self._mmc.mda.events.sequenceFinished.connect( - self.datastore.sequenceFinished - ) - self._mmc.mda.events.sequenceStarted.connect( - self.datastore.sequenceStarted - ) - else: - warnings.warn( - "No datastore or mmcore provided, connect manually.", stacklevel=2 - ) - - if self._mmc: - # Otherwise connect via listeners_connected or manually - self._mmc.mda.events.sequenceStarted.connect(self.sequenceStarted) - - self.images: dict[tuple, scene.visuals.Image] = {} - self.frame = 0 - self.ready = False - self.current_channel = 0 - self.pixel_size = 1.0 - self.missed_events: list[MDAEvent] = [] - - self.destroyed.connect(self._disconnect) - - self.collapse_btn = QtWidgets.QPushButton() - self.collapse_btn.setIcon(fonticon.icon(MDI6.arrow_collapse_all)) - self.collapse_btn.clicked.connect(self._collapse_view) - - self.bottom_buttons = QtWidgets.QHBoxLayout() - self.bottom_buttons.addWidget(self.collapse_btn) - if save_button: - self.save_btn = SaveButton(self.datastore) - self.bottom_buttons.addWidget(self.save_btn) - self.main_layout.addLayout(self.bottom_buttons) - - if sequence: - self.sequenceStarted(sequence) - - def construct_canvas(self) -> None: - if self.canvas_size: - self.img_size = self.canvas_size - elif ( - self._mmc - and (h := self._mmc.getImageHeight()) - and (w := self._mmc.getImageWidth()) + self._set_range_btn.clicked.connect(self._on_set_range_clicked) + + # place to display dataset summary + self._data_info = QElidingLabel("") + # place to display arbitrary text + self._hover_info = QLabel("Info") + # the canvas that displays the images + self._canvas: PCanvas = get_canvas()(self._hover_info.setText) + # the sliders that control the index of the displayed image + self._dims_sliders = DimsSliders() + self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) + + self._lut_drop = QCollapsible("LUTs") + self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down")) + self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up")) + lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) + lut_layout.setContentsMargins(0, 1, 0, 1) + lut_layout.setSpacing(0) + if ( + hasattr(self._lut_drop, "_content") + and (layout := self._lut_drop._content.layout()) is not None ): - self.img_size = (h, w) - else: - self.img_size = (512, 512) - if any(x < 1 for x in self.img_size): - raise ValueError("Image size must be greater than 0.") - self._canvas = scene.SceneCanvas( - size=self.img_size, parent=self, keys="interactive" - ) - self._canvas._send_hover_events = True - self._canvas.events.mouse_move.connect(self.on_mouse_move) - self.view = self._canvas.central_widget.add_view() - self.view.camera = scene.PanZoomCamera(aspect=1) - self.view.camera.flip = (self.transform[1], self.transform[2], False) - self.view.camera.set_range( - (0, self.img_size[0]), (0, self.img_size[1]), margin=0 - ) - rect = self.view.camera.rect - self.view_rect = (rect.pos, rect.size) - self.view.camera.aspect = 1 - - def _create_sliders(self, sequence: MDASequence | None = None) -> None: - self.channel_row: ChannelRow = ChannelRow(parent=self) - self.channel_row.visible.connect(self._handle_channel_visibility) - self.channel_row.autoscale.connect(self._handle_channel_autoscale) - self.channel_row.new_clims.connect(self._handle_channel_clim) - self.channel_row.new_cmap.connect(self._handle_channel_cmap) - self.channel_row.selected.connect(self._handle_channel_choice) - self.layout().addWidget(self.channel_row) - - self.slider_layout = QtWidgets.QVBoxLayout() - self.layout().addLayout(self.slider_layout) - self.sliders: dict[str, LabeledVisibilitySlider] = {} - - @superqt.ensure_main_thread # type: ignore - def add_slider(self, dim: str) -> None: - slider = LabeledVisibilitySlider( - dim, orientation=QtCore.Qt.Orientation.Horizontal - ) - slider.sliderMoved.connect(self.on_display_timer) - slider.setRange(0, 1) - self.slider_layout.addWidget(slider) - self.sliders[dim] = slider - - @superqt.ensure_main_thread # type: ignore - def add_image(self, event: MDAEvent) -> None: - image = scene.visuals.Image( - np.zeros(self._canvas.size).astype(np.uint16), - parent=self.view.scene, - cmap=self.cmaps[event.index.get("c", 0)].to_vispy(), - clim=(0, 1), - ) - trans = MatrixTransform() - trans.rotate(self.transform[0], (0, 0, 1)) - image.transform = self._get_image_position(trans, event) - image.interactive = True - if event.index.get("c", 0) > 0: - image.set_gl_state("additive", depth_test=False) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + + # LAYOUT ----------------------------------------------------- + + self._btns = btns = QHBoxLayout() + btns.setContentsMargins(0, 0, 0, 0) + btns.setSpacing(0) + btns.addStretch() + btns.addWidget(self._channel_mode_btn) + btns.addWidget(self._set_range_btn) + + layout = QVBoxLayout(self) + layout.setSpacing(2) + layout.setContentsMargins(6, 6, 6, 6) + layout.addWidget(self._data_info) + layout.addWidget(self._canvas.qwidget(), 1) + layout.addWidget(self._hover_info) + layout.addWidget(self._dims_sliders) + layout.addWidget(self._lut_drop) + layout.addLayout(btns) + + # SETUP ------------------------------------------------------ + + self.set_data(data) + self.set_channel_mode(channel_mode) + + # ------------------- PUBLIC API ---------------------------- + + @property + def data(self) -> Any: + """Return the data backing the view.""" + return self._data + + def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: + """Set the datastore, and, optionally, the sizes of the data.""" + if sizes is None: + if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): + sizes = sz + elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): + sizes = shp + self._sizes = _to_sizes(sizes) + self._data = data + if self._channel_axis is None: + self._channel_axis = self._guess_channel_axis(data) + self.set_visualized_dims(list(self._sizes)[-2:]) + self.update_slider_maxima() + self.setIndex({}) + + info = f"{getattr(type(data), '__qualname__', '')}" + + if self._sizes: + if all(isinstance(x, int) for x in self._sizes): + size_str = repr(tuple(self._sizes.values())) + else: + size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + self._data_info.setText(info) + + def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: + """Set the dimensions that will be visualized. + + This dims will NOT have sliders associated with them. + """ + self._visualized_dims = set(dims) + for d in self._dims_sliders._sliders: + self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) + for d in self._visualized_dims: + self._dims_sliders.set_dimension_visible(d, False) + + @property + def dims_sliders(self) -> DimsSliders: + """Return the DimsSliders widget.""" + return self._dims_sliders + + @property + def sizes(self) -> Sizes: + """Return sizes {dimkey: int} of the dimensions in the datastore.""" + return self._sizes + + def update_slider_maxima(self, sizes: SizesLike | None = None) -> None: + """Set the maximum values of the sliders. + + If `sizes` is not provided, sizes will be inferred from the datastore. + """ + if sizes is None: + sizes = self.sizes + sizes = _to_sizes(sizes) + self._dims_sliders.setMaximum({k: v - 1 for k, v in sizes.items()}) + + # FIXME: this needs to be moved and made user-controlled + for dim in list(sizes.values())[-2:]: + self._dims_sliders.set_dimension_visible(dim, False) + + def set_channel_mode(self, mode: ChannelMode | None = None) -> None: + """Set the mode for displaying the channels. + + In "composite" mode, the channels are displayed as a composite image, using + self._channel_axis as the channel axis. In "grayscale" mode, each channel is + displayed separately. (If mode is None, the current value of the + channel_mode_picker button is used) + """ + if mode is None or isinstance(mode, bool): + mode = self._channel_mode_btn.mode() else: - image.set_gl_state(depth_test=False) - c = event.index.get("c", 0) - g = event.index.get("g", 0) - self.images[(("c", c), ("g", g))] = image - - def sequenceStarted(self, sequence: MDASequence) -> None: - """Sequence started by the mmcore. Adjust our settings, make layers etc.""" - self.ready = False - self.sequence = sequence - self.pixel_size = self._mmc.getPixelSizeUm() if self._mmc else self.pixel_size - - self.ng = max(sequence.sizes.get("g", 1), 1) - self.current_channel = 0 - - self._collapse_view() - self.ready = True - - def frameReady(self, event: MDAEvent) -> None: - """Frame received from acquisition, display the image, update sliders etc.""" - if not self.ready: - self._redisplay(event) + self._channel_mode_btn.setMode(mode) + if mode == getattr(self, "_channel_mode", None): return - indices = dict(event.index) - img = self.datastore.get_frame(event) - # Update display - try: - display_indices = self._set_sliders(indices) - except KeyError as e: - self.add_slider(e.args[0]) - self._redisplay(event) + + self._channel_mode = mode + # reset the colormap cycle + self._cmaps = cycle(COLORMAPS) + # set the visibility of the channel slider + c_visible = mode != ChannelMode.COMPOSITE + self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) + + if not self._img_handles: return - if display_indices == indices: - # Get controls - try: - clim_slider = self.channel_row.boxes[indices.get("c", 0)].slider - except KeyError: - this_channel = cast(_channel.Channel, event.channel) - self.channel_row.add_channel(this_channel, indices.get("c", 0)) - self._redisplay(event) - return - try: - self.display_image(img, indices.get("c", 0), indices.get("g", 0)) - except KeyError: - self.add_image(event) - self._redisplay(event) - return - - # Handle autoscaling - clim_slider.setRange( - min(clim_slider.minimum(), img.min()), - max(clim_slider.maximum(), img.max()), - ) - if self.channel_row.boxes[indices.get("c", 0)].autoscale_chbx.isChecked(): - clim_slider.setValue( - [ - min(clim_slider.minimum(), img.min()), - max(clim_slider.maximum(), img.max()), - ] - ) - try: - self.on_clim_timer(indices.get("c", 0)) - except KeyError: - return - if sum([event.index.get("t", 0), event.index.get("z", 0)]) == 0: - self._collapse_view() - - def _handle_channel_clim( - self, values: tuple[int, int], channel: int, set_autoscale: bool = True - ) -> None: - for g in range(self.ng): - self.images[(("c", channel), ("g", g))].clim = values - if self.channel_row.boxes[channel].autoscale_chbx.isChecked() and set_autoscale: - self.channel_row.boxes[channel].autoscale_chbx.setCheckState( - QtCore.Qt.CheckState.Unchecked - ) - self._canvas.update() - - def _handle_channel_cmap(self, colormap: cmap.Colormap, channel: int) -> None: - for g in range(self.ng): - try: - self.images[(("c", channel), ("g", g))].cmap = colormap.to_vispy() - except KeyError: - return - if colormap.name not in self.cmap_names: - self.cmap_names.append(self.cmap_names[channel]) - self.cmap_names[channel] = colormap.name - self._canvas.update() - - def _handle_channel_visibility(self, state: bool, channel: int) -> None: - for g in range(self.ng): - checked = self.channel_row.boxes[channel].show_channel.isChecked() - self.images[(("c", channel), ("g", g))].visible = checked - if self.current_channel == channel: - channel_to_set = channel - 1 if channel > 0 else channel + 1 - channel_to_set = 0 if len(self.channel_row.boxes) == 1 else channel_to_set - self.channel_row._handle_channel_choice( - self.channel_row.boxes[channel_to_set].channel - ) - self._canvas.update() - def _handle_channel_autoscale(self, state: bool, channel: int) -> None: - slider = self.channel_row.boxes[channel].slider - if state == 0: - self._handle_channel_clim(slider.value(), channel, set_autoscale=False) + # determine what needs to be updated + n_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 + value = self._dims_sliders.value() # get before clearing + self._clear_images() + indices = ( + [value] + if c_visible + else [{**value, self._channel_axis: i} for i in range(n_channels)] + ) + + # update the displayed images + for idx in indices: + self._update_data_for_index(idx) + self._canvas.refresh() + + def setIndex(self, index: Indices) -> None: + """Set the index of the displayed image.""" + self._dims_sliders.setValue(index) + + # ------------------- PRIVATE METHODS ---------------------------- + + def _guess_channel_axis(self, data: Any) -> DimKey: + """Guess the channel axis from the data.""" + if isinstance(data, np.ndarray): + # for numpy arrays, use the smallest dimension as the channel axis + return data.shape.index(min(data.shape)) + if is_xarray_dataarray(data): + for d in data.dims: + if str(d).lower() in ("channel", "ch", "c"): + return cast("DimKey", d) + return 0 + + def _clear_images(self) -> None: + """Remove all images from the canvas.""" + for handles in self._img_handles.values(): + for handle in handles: + handle.remove() + self._img_handles.clear() + + # clear the current LutControls as well + for c in self._lut_ctrls.values(): + cast("QVBoxLayout", self.layout()).removeWidget(c) + c.deleteLater() + self._lut_ctrls.clear() + + def _on_set_range_clicked(self) -> None: + self._canvas.set_range() + + def _image_key(self, index: Indices) -> ImgKey: + """Return the key for image handle(s) corresponding to `index`.""" + if self._channel_mode == ChannelMode.COMPOSITE: + val = index.get(self._channel_axis, 0) + if isinstance(val, slice): + return (val.start, val.stop) + return val + return 0 + + def _isel(self, index: Indices) -> np.ndarray: + """Select data from the datastore using the given index.""" + idx = {k: v for k, v in index.items() if k not in self._visualized_dims} + try: + return isel(self._data, idx) + except Exception as e: + raise type(e)(f"Failed to index data with {idx}: {e}") from e + + def _on_dims_sliders_changed(self, index: Indices) -> None: + """Update the displayed image when the sliders are changed.""" + c = index.get(self._channel_axis, 0) + indices: list[Indices] = [index] + if self._channel_mode == ChannelMode.COMPOSITE: + for i, handles in self._img_handles.items(): + if isinstance(i, (int, slice)): + if handles and c != i: + indices.append({**index, self._channel_axis: i}) + else: # pragma: no cover + warnings.warn(f"Invalid key for composite image: {i}", stacklevel=2) + + for idx in indices: + self._update_data_for_index(idx) + self._canvas.refresh() + + def _update_data_for_index(self, index: Indices) -> None: + """Update the displayed image for the given index. + + This will pull the data from the datastore using the given index, and update + the image handle(s) with the new data. + """ + imkey = self._image_key(index) + data = self._isel(index).squeeze() + data = self._reduce_dims_for_display(data) + if handles := self._img_handles[imkey]: + for handle in handles: + handle.data = data + if ctrl := self._lut_ctrls.get(imkey, None): + ctrl.update_autoscale() else: - clim = ( - slider.minimum(), - slider.maximum(), + cm = ( + next(self._cmaps) + if self._channel_mode == ChannelMode.COMPOSITE + else GRAYS ) - self._handle_channel_clim(clim, channel, set_autoscale=False) - - def _handle_channel_choice(self, channel: int) -> None: - self.current_channel = channel - - def on_mouse_move(self, event: SceneMouseEvent) -> None: - """Mouse moved on the canvas, display the pixel value and position.""" - # https://groups.google.com/g/vispy/c/sUNKoDL1Gc0/m/E5AG7lgPFQAJ - self.view.interactive = False - images = [] - all_images = [] - # Get the images the mouse is over - while image := self._canvas.visual_at(event.pos): - if image in self.images.values(): - images.append(image) - image.interactive = False - all_images.append(image) - for image in all_images: - image.interactive = True - - self.view.interactive = True - if images == []: - transform = self.view.get_transform("canvas", "visual") - p = [int(x) for x in transform.map(event.pos)] - info = f"[{p[0]}, {p[1]}]" - self.info_bar.setText(info) - return - # Adjust channel index is channel(s) are not visible - real_channel = self.current_channel - for i in range(self.current_channel): - i_visible = self.channel_row.boxes[i].show_channel.isChecked() - real_channel = real_channel - 1 if not i_visible else real_channel - images.reverse() - transform = images[real_channel].get_transform("canvas", "visual") - p = [int(x) for x in transform.map(event.pos)] - try: - pos = f"[{p[0]}, {p[1]}]" - value = f"{images[real_channel]._data[p[1], p[0]]}" - info = f"{pos}: {value}" - self.info_bar.setText(info) - except IndexError: - info = f"[{p[0]}, {p[1]}]" - self.info_bar.setText(info) - - def on_display_timer(self) -> None: - """Update display, usually triggered by QTimer started by slider click.""" - old_index = self.display_index.copy() - for slider in self.sliders.values(): - self.display_index[slider.name] = slider.value() - if old_index == self.display_index: - return - if (sequence := self.sequence) is None: - return - for g in range(self.ng): - for c in range(sequence.sizes.get("c", 1)): - frame = self.datastore.get_frame( - MDAEvent( - index={ - "t": self.display_index["t"], - "z": self.display_index["z"], - "c": c, - "g": g, - "p": 0, - } - ) - ) - self.display_image(frame, c, g) - self._canvas.update() - - def _set_sliders(self, indices: dict) -> dict: - """New indices from outside the sliders, update.""" - display_indices = copy.deepcopy(indices) - for index in display_indices: - if index not in ["t", "z"] or display_indices.get(index, 0) == 0: - continue - if self.sliders[index].lock_btn.isChecked(): - display_indices[index] = self.sliders[index].value() - continue - # This blocking doesn't seem to work - # blocked = slider.blockSignals(True) - self.sliders[index].setValue(display_indices.get(index, 0)) - if display_indices.get(index, 0) > self.sliders[index].maximum(): - self.sliders[index].setMaximum(display_indices.get(index, 0)) - # slider.setValue(indices[slider.name]) - # slider.blockSignals(blocked) - return display_indices - - def display_image(self, img: np.ndarray, channel: int = 0, grid: int = 0) -> None: - self.images[(("c", channel), ("g", grid))].set_data(img) - # Should we do this? Might it slow down acquisition while in the same thread? - self._canvas.update() - - def on_clim_timer(self, channel: int | None = None) -> None: - channel_list = ( - list(range(len(self.channel_row.boxes))) if channel is None else [channel] - ) - for grid in range(self.ng): - for channel in channel_list: - if ( - self.channel_row.boxes[channel].autoscale_chbx.isChecked() - and (img := self.images[(("c", channel), ("g", grid))]).visible - ): - # TODO: percentile here, could be in gui - img.clim = np.percentile(img._data, [0, 100]) - self._canvas.update() - - def _get_image_position( - self, - trans: MatrixTransform, - event: MDAEvent, - ) -> MatrixTransform: - translate = [round(x) for x in ((1, 1) - trans.matrix[:2, :2].dot((1, 1))) / 2] - - x_pos = event.x_pos or 0 - y_pos = event.y_pos or 0 - w, h = self.img_size - if x_pos == 0 and y_pos == 0: - trans.translate(((1 - translate[1]) * w, translate[0] * h, 0)) - self.view_rect = ((0 - w / 2, 0 - h / 2), (w, h)) + handles.append(self._canvas.add_image(data, cmap=cm)) + if imkey not in self._lut_ctrls: + channel_name = f"Ch {imkey}" # TODO: get name from user + self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) + c.update_autoscale() + self._lut_drop.addWidget(c) + + def _reduce_dims_for_display( + self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max + ) -> np.ndarray: + """Reduce the number of dimensions in the data for display. + + This function takes a data array and reduces the number of dimensions to + the max allowed for display. The default behavior is to reduce the smallest + dimensions, using np.max. This can be improved in the future. + """ + # TODO + # - allow for 3d data + # - allow dimensions to control how they are reduced + # - for better way to determine which dims need to be reduced + visualized_dims = 2 + if extra_dims := data.ndim - visualized_dims: + shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) + smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) + return reductor(data, axis=smallest_dims) + + if data.dtype == np.float64: + data = data.astype(np.float32) + return data + + +def _to_sizes(sizes: SizesLike | None) -> Sizes: + """Coerce `sizes` to a {dimKey -> int} mapping.""" + if sizes is None: + return {} + if isinstance(sizes, Mapping): + return {k: int(v) for k, v in sizes.items()} + if not isinstance(sizes, Iterable): + raise TypeError(f"SizeLike must be an iterable or mapping, not: {type(sizes)}") + _sizes: dict[Hashable, int] = {} + for i, val in enumerate(sizes): + if isinstance(val, int): + _sizes[i] = val + elif isinstance(val, Sequence) and len(val) == 2: + _sizes[val[0]] = int(val[1]) else: - trans.translate(((translate[1] - 1) * w, translate[0] * h, 0)) - trans.translate((x_pos / self.pixel_size, y_pos / self.pixel_size, 0)) - self._expand_canvas_view(event) - return trans - - def _expand_canvas_view(self, event: MDAEvent) -> None: - """Expand the canvas view to include the new image.""" - x_pos = event.x_pos or 0 - y_pos = event.y_pos or 0 - img_position = ( - x_pos / self.pixel_size - self.img_size[0] / 2, - x_pos / self.pixel_size + self.img_size[0] / 2, - y_pos / self.pixel_size - self.img_size[1] / 2, - y_pos / self.pixel_size + self.img_size[1] / 2, - ) - camera_rect = [ - self.view_rect[0][0], - self.view_rect[0][0] + self.view_rect[1][0], - self.view_rect[0][1], - self.view_rect[0][1] + self.view_rect[1][1], - ] - if camera_rect[0] > img_position[0]: - camera_rect[0] = img_position[0] - if camera_rect[1] < img_position[1]: - camera_rect[1] = img_position[1] - if camera_rect[2] > img_position[2]: - camera_rect[2] = img_position[2] - if camera_rect[3] < img_position[3]: - camera_rect[3] = img_position[3] - self.view_rect = ( - (camera_rect[0], camera_rect[2]), - (camera_rect[1] - camera_rect[0], camera_rect[3] - camera_rect[2]), - ) - - def _disconnect(self) -> None: - if self._mmc: - self._mmc.mda.events.sequenceStarted.disconnect(self.sequenceStarted) - self.datastore.frame_ready.disconnect(self.frameReady) - - def _reload_position(self) -> None: - self.qt_settings = QtCore.QSettings("pymmcore_plus", self.__class__.__name__) - self.resize(self.qt_settings.value("size", QtCore.QSize(270, 225))) - self.move(self.qt_settings.value("pos", QtCore.QPoint(50, 50))) - self.cmap_names = self.qt_settings.value("cmaps", ["gray", "cyan", "magenta"]) - - def _collapse_view(self) -> None: - w, h = self.img_size - view_rect = ( - (self.view_rect[0][0] - w / 2, self.view_rect[0][1] + h / 2), - self.view_rect[1], - ) - self.view.camera.rect = view_rect - - def _reemit_missed_events(self) -> None: - while self.missed_events: - self.frameReady(self.missed_events.pop(0)) - - @superqt.ensure_main_thread # type: ignore - def _redisplay(self, event: MDAEvent) -> None: - self.missed_events.append(event) - QTimer.singleShot(0, self._reemit_missed_events) - - def closeEvent(self, e: QCloseEvent) -> None: - """Write window size and position to config file.""" - self.qt_settings.setValue("size", self.size()) - self.qt_settings.setValue("pos", self.pos()) - self.qt_settings.setValue("cmaps", self.cmap_names) - self._canvas.close() - super().closeEvent(e) + raise ValueError(f"Invalid size: {val}. Must be an int or a 2-tuple.") + return _sizes diff --git a/src/pymmcore_widgets/_stack_viewer2/__init__.py b/src/pymmcore_widgets/_stack_viewer2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py deleted file mode 100644 index 7999dea7b..000000000 --- a/src/pymmcore_widgets/_stack_viewer2/_mda_viewer.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import superqt -import useq -from psygnal import Signal as psygnalSignal -from pymmcore_plus.mda.handlers import OMEZarrWriter - -from ._save_button import SaveButton -from ._stack_viewer import StackViewer - -if TYPE_CHECKING: - import numpy as np - from qtpy.QtWidgets import QWidget - - -# FIXME: get rid of this thin subclass -class DataStore(OMEZarrWriter): - frame_ready = psygnalSignal(object, useq.MDAEvent) - - def frameReady(self, frame: np.ndarray, event: useq.MDAEvent, meta: dict) -> None: - super().frameReady(frame, event, meta) - self.frame_ready.emit(frame, event) - - -class MDAViewer(StackViewer): - """StackViewer specialized for pymmcore-plus MDA acquisitions.""" - - def __init__(self, *, parent: QWidget | None = None): - super().__init__(DataStore(), parent=parent, channel_axis="c") - self._save_btn = SaveButton(self.data) - self._btns.addWidget(self._save_btn) - self._data.frame_ready.connect(self.on_frame_ready) - self.dims_sliders.set_locks_visible(True) - - @superqt.ensure_main_thread # type: ignore - def on_frame_ready(self, frame: np.ndarray, event: useq.MDAEvent) -> None: - self.setIndex(event.index) # type: ignore diff --git a/src/pymmcore_widgets/_stack_viewer2/_save_button.py b/src/pymmcore_widgets/_stack_viewer2/_save_button.py deleted file mode 100644 index c526ab258..000000000 --- a/src/pymmcore_widgets/_stack_viewer2/_save_button.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Any - -import numpy as np -from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget -from superqt.iconify import QIconifyIcon - -from ._indexing import is_xarray_dataarray - - -class SaveButton(QPushButton): - def __init__( - self, - datastore: Any, - parent: QWidget | None = None, - ): - super().__init__(parent=parent) - self.setIcon(QIconifyIcon("mdi:content-save")) - self.clicked.connect(self._on_click) - - self._data = datastore - self._last_loc = str(Path.home()) - - def _on_click(self) -> None: - self._last_loc, _ = QFileDialog.getSaveFileName( - self, "Choose destination", str(self._last_loc), "" - ) - suffix = Path(self._last_loc).suffix - if suffix in (".zarr", ".ome.zarr", ""): - _save_as_zarr(self._last_loc, self._data) - else: - raise ValueError(f"Unsupported file format: {self._last_loc}") - - -def _save_as_zarr(save_loc: str | Path, data: Any) -> None: - import zarr - from pymmcore_plus.mda.handlers import OMEZarrWriter - - if isinstance(data, OMEZarrWriter): - zarr.copy_store(data.group.store, zarr.DirectoryStore(save_loc)) - elif isinstance(data, zarr.Array): - data.store = zarr.DirectoryStore(save_loc) - elif isinstance(data, np.ndarray): - zarr.save(str(save_loc), data) - elif is_xarray_dataarray(data): - data.to_zarr(save_loc) - else: - raise ValueError(f"Cannot save data of type {type(data)} to Zarr format.") diff --git a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py deleted file mode 100644 index 9410fde70..000000000 --- a/src/pymmcore_widgets/_stack_viewer2/_stack_viewer.py +++ /dev/null @@ -1,399 +0,0 @@ -from __future__ import annotations - -import warnings -from collections import defaultdict -from enum import Enum -from itertools import cycle -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast - -import cmap -import numpy as np -from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget -from superqt import QCollapsible, QElidingLabel, QIconifyIcon - -from ._backends import get_canvas -from ._dims_slider import DimsSliders -from ._indexing import is_xarray_dataarray, isel -from ._lut_control import LutControl - -if TYPE_CHECKING: - from typing import Any, Callable, Hashable, TypeAlias - - from ._dims_slider import DimKey, Indices, Sizes - from ._protocols import PCanvas, PImageHandle - - ImgKey: TypeAlias = Hashable - # any mapping of dimensions to sizes - SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] - - -GRAYS = cmap.Colormap("gray") -COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] - - -class ChannelMode(str, Enum): - COMPOSITE = "composite" - MONO = "mono" - - def __str__(self) -> str: - return self.value - - -class ChannelModeButton(QPushButton): - def __init__(self, parent: QWidget | None = None): - super().__init__(parent) - self.setCheckable(True) - self.toggled.connect(self.next_mode) - - def next_mode(self) -> None: - if self.isChecked(): - self.setMode(ChannelMode.MONO) - else: - self.setMode(ChannelMode.COMPOSITE) - - def mode(self) -> ChannelMode: - return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE - - def setMode(self, mode: ChannelMode) -> None: - # we show the name of the next mode, not the current one - other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO - self.setText(str(other)) - self.setChecked(mode == ChannelMode.MONO) - - -class StackViewer(QWidget): - """A viewer for AND arrays.""" - - def __init__( - self, - data: Any, - *, - parent: QWidget | None = None, - channel_axis: DimKey | None = None, - channel_mode: ChannelMode = ChannelMode.MONO, - ): - super().__init__(parent=parent) - - # ATTRIBUTES ---------------------------------------------------- - - # dimensions of the data in the datastore - self._sizes: Sizes = {} - # mapping of key to a list of objects that control image nodes in the canvas - self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) - # mapping of same keys to the LutControl objects control image display props - self._lut_ctrls: dict[ImgKey, LutControl] = {} - # the set of dimensions we are currently visualizing (e.g. XY) - # this is used to control which dimensions have sliders and the behavior - # of isel when selecting data from the datastore - self._visualized_dims: set[DimKey] = set() - # the axis that represents the channels in the data - self._channel_axis = channel_axis - self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode - # colormaps that will be cycled through when displaying composite images - # TODO: allow user to set this - self._cmaps = cycle(COLORMAPS) - - # WIDGETS ---------------------------------------------------- - - # the button that controls the display mode of the channels - self._channel_mode_btn = ChannelModeButton() - self._channel_mode_btn.clicked.connect(self.set_channel_mode) - # button to reset the zoom of the canvas - self._set_range_btn = QPushButton( - QIconifyIcon("fluent:full-screen-maximize-24-filled"), "" - ) - self._set_range_btn.clicked.connect(self._on_set_range_clicked) - - # place to display dataset summary - self._data_info = QElidingLabel("") - # place to display arbitrary text - self._hover_info = QLabel("Info") - # the canvas that displays the images - self._canvas: PCanvas = get_canvas()(self._hover_info.setText) - # the sliders that control the index of the displayed image - self._dims_sliders = DimsSliders() - self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) - - self._lut_drop = QCollapsible("LUTs") - self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down")) - self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up")) - lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) - lut_layout.setContentsMargins(0, 1, 0, 1) - lut_layout.setSpacing(0) - if ( - hasattr(self._lut_drop, "_content") - and (layout := self._lut_drop._content.layout()) is not None - ): - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(0) - - # LAYOUT ----------------------------------------------------- - - self._btns = btns = QHBoxLayout() - btns.setContentsMargins(0, 0, 0, 0) - btns.setSpacing(0) - btns.addStretch() - btns.addWidget(self._channel_mode_btn) - btns.addWidget(self._set_range_btn) - - layout = QVBoxLayout(self) - layout.setSpacing(2) - layout.setContentsMargins(6, 6, 6, 6) - layout.addWidget(self._data_info) - layout.addWidget(self._canvas.qwidget(), 1) - layout.addWidget(self._hover_info) - layout.addWidget(self._dims_sliders) - layout.addWidget(self._lut_drop) - layout.addLayout(btns) - - # SETUP ------------------------------------------------------ - - self.set_data(data) - self.set_channel_mode(channel_mode) - - # ------------------- PUBLIC API ---------------------------- - - @property - def data(self) -> Any: - """Return the data backing the view.""" - return self._data - - def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: - """Set the datastore, and, optionally, the sizes of the data.""" - if sizes is None: - if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): - sizes = sz - elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): - sizes = shp - self._sizes = _to_sizes(sizes) - self._data = data - if self._channel_axis is None: - self._channel_axis = self._guess_channel_axis(data) - self.set_visualized_dims(list(self._sizes)[-2:]) - self.update_slider_maxima() - self.setIndex({}) - - info = f"{getattr(type(data), '__qualname__', '')}" - - if self._sizes: - if all(isinstance(x, int) for x in self._sizes): - size_str = repr(tuple(self._sizes.values())) - else: - size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) - size_str = f"({size_str})" - info += f" {size_str}" - if dtype := getattr(data, "dtype", ""): - info += f", {dtype}" - if nbytes := getattr(data, "nbytes", 0) / 1e6: - info += f", {nbytes:.2f}MB" - self._data_info.setText(info) - - def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: - """Set the dimensions that will be visualized. - - This dims will NOT have sliders associated with them. - """ - self._visualized_dims = set(dims) - for d in self._dims_sliders._sliders: - self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) - for d in self._visualized_dims: - self._dims_sliders.set_dimension_visible(d, False) - - @property - def dims_sliders(self) -> DimsSliders: - """Return the DimsSliders widget.""" - return self._dims_sliders - - @property - def sizes(self) -> Sizes: - """Return sizes {dimkey: int} of the dimensions in the datastore.""" - return self._sizes - - def update_slider_maxima(self, sizes: SizesLike | None = None) -> None: - """Set the maximum values of the sliders. - - If `sizes` is not provided, sizes will be inferred from the datastore. - """ - if sizes is None: - sizes = self.sizes - sizes = _to_sizes(sizes) - self._dims_sliders.setMaximum({k: v - 1 for k, v in sizes.items()}) - - # FIXME: this needs to be moved and made user-controlled - for dim in list(sizes.values())[-2:]: - self._dims_sliders.set_dimension_visible(dim, False) - - def set_channel_mode(self, mode: ChannelMode | None = None) -> None: - """Set the mode for displaying the channels. - - In "composite" mode, the channels are displayed as a composite image, using - self._channel_axis as the channel axis. In "grayscale" mode, each channel is - displayed separately. (If mode is None, the current value of the - channel_mode_picker button is used) - """ - if mode is None or isinstance(mode, bool): - mode = self._channel_mode_btn.mode() - else: - self._channel_mode_btn.setMode(mode) - if mode == getattr(self, "_channel_mode", None): - return - - self._channel_mode = mode - # reset the colormap cycle - self._cmaps = cycle(COLORMAPS) - # set the visibility of the channel slider - c_visible = mode != ChannelMode.COMPOSITE - self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) - - if not self._img_handles: - return - - # determine what needs to be updated - n_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 - value = self._dims_sliders.value() # get before clearing - self._clear_images() - indices = ( - [value] - if c_visible - else [{**value, self._channel_axis: i} for i in range(n_channels)] - ) - - # update the displayed images - for idx in indices: - self._update_data_for_index(idx) - self._canvas.refresh() - - def setIndex(self, index: Indices) -> None: - """Set the index of the displayed image.""" - self._dims_sliders.setValue(index) - - # ------------------- PRIVATE METHODS ---------------------------- - - def _guess_channel_axis(self, data: Any) -> DimKey: - """Guess the channel axis from the data.""" - if isinstance(data, np.ndarray): - # for numpy arrays, use the smallest dimension as the channel axis - return data.shape.index(min(data.shape)) - if is_xarray_dataarray(data): - for d in data.dims: - if str(d).lower() in ("channel", "ch", "c"): - return cast("DimKey", d) - return 0 - - def _clear_images(self) -> None: - """Remove all images from the canvas.""" - for handles in self._img_handles.values(): - for handle in handles: - handle.remove() - self._img_handles.clear() - - # clear the current LutControls as well - for c in self._lut_ctrls.values(): - cast("QVBoxLayout", self.layout()).removeWidget(c) - c.deleteLater() - self._lut_ctrls.clear() - - def _on_set_range_clicked(self) -> None: - self._canvas.set_range() - - def _image_key(self, index: Indices) -> ImgKey: - """Return the key for image handle(s) corresponding to `index`.""" - if self._channel_mode == ChannelMode.COMPOSITE: - val = index.get(self._channel_axis, 0) - if isinstance(val, slice): - return (val.start, val.stop) - return val - return 0 - - def _isel(self, index: Indices) -> np.ndarray: - """Select data from the datastore using the given index.""" - idx = {k: v for k, v in index.items() if k not in self._visualized_dims} - try: - return isel(self._data, idx) - except Exception as e: - raise type(e)(f"Failed to index data with {idx}: {e}") from e - - def _on_dims_sliders_changed(self, index: Indices) -> None: - """Update the displayed image when the sliders are changed.""" - c = index.get(self._channel_axis, 0) - indices: list[Indices] = [index] - if self._channel_mode == ChannelMode.COMPOSITE: - for i, handles in self._img_handles.items(): - if isinstance(i, (int, slice)): - if handles and c != i: - indices.append({**index, self._channel_axis: i}) - else: # pragma: no cover - warnings.warn(f"Invalid key for composite image: {i}", stacklevel=2) - - for idx in indices: - self._update_data_for_index(idx) - self._canvas.refresh() - - def _update_data_for_index(self, index: Indices) -> None: - """Update the displayed image for the given index. - - This will pull the data from the datastore using the given index, and update - the image handle(s) with the new data. - """ - imkey = self._image_key(index) - data = self._isel(index).squeeze() - data = self._reduce_dims_for_display(data) - if handles := self._img_handles[imkey]: - for handle in handles: - handle.data = data - if ctrl := self._lut_ctrls.get(imkey, None): - ctrl.update_autoscale() - else: - cm = ( - next(self._cmaps) - if self._channel_mode == ChannelMode.COMPOSITE - else GRAYS - ) - handles.append(self._canvas.add_image(data, cmap=cm)) - if imkey not in self._lut_ctrls: - channel_name = f"Ch {imkey}" # TODO: get name from user - self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) - c.update_autoscale() - self._lut_drop.addWidget(c) - - def _reduce_dims_for_display( - self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max - ) -> np.ndarray: - """Reduce the number of dimensions in the data for display. - - This function takes a data array and reduces the number of dimensions to - the max allowed for display. The default behavior is to reduce the smallest - dimensions, using np.max. This can be improved in the future. - """ - # TODO - # - allow for 3d data - # - allow dimensions to control how they are reduced - # - for better way to determine which dims need to be reduced - visualized_dims = 2 - if extra_dims := data.ndim - visualized_dims: - shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) - smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) - return reductor(data, axis=smallest_dims) - - if data.dtype == np.float64: - data = data.astype(np.float32) - return data - - -def _to_sizes(sizes: SizesLike | None) -> Sizes: - """Coerce `sizes` to a {dimKey -> int} mapping.""" - if sizes is None: - return {} - if isinstance(sizes, Mapping): - return {k: int(v) for k, v in sizes.items()} - if not isinstance(sizes, Iterable): - raise TypeError(f"SizeLike must be an iterable or mapping, not: {type(sizes)}") - _sizes: dict[Hashable, int] = {} - for i, val in enumerate(sizes): - if isinstance(val, int): - _sizes[i] = val - elif isinstance(val, Sequence) and len(val) == 2: - _sizes[val[0]] = int(val[1]) - else: - raise ValueError(f"Invalid size: {val}. Must be an int or a 2-tuple.") - return _sizes diff --git a/src/pymmcore_widgets/_stack_viewer_v1/__init__.py b/src/pymmcore_widgets/_stack_viewer_v1/__init__.py new file mode 100644 index 000000000..2c4beb6a5 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer_v1/__init__.py @@ -0,0 +1,5 @@ +from ._channel_row import CMAPS +from ._datastore import QOMEZarrDatastore +from ._stack_viewer import StackViewer + +__all__ = ["StackViewer", "CMAPS", "QOMEZarrDatastore"] diff --git a/src/pymmcore_widgets/_stack_viewer/_channel_row.py b/src/pymmcore_widgets/_stack_viewer_v1/_channel_row.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_channel_row.py rename to src/pymmcore_widgets/_stack_viewer_v1/_channel_row.py diff --git a/src/pymmcore_widgets/_stack_viewer/_datastore.py b/src/pymmcore_widgets/_stack_viewer_v1/_datastore.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_datastore.py rename to src/pymmcore_widgets/_stack_viewer_v1/_datastore.py diff --git a/src/pymmcore_widgets/_stack_viewer/_labeled_slider.py b/src/pymmcore_widgets/_stack_viewer_v1/_labeled_slider.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_labeled_slider.py rename to src/pymmcore_widgets/_stack_viewer_v1/_labeled_slider.py diff --git a/src/pymmcore_widgets/_stack_viewer_v1/_save_button.py b/src/pymmcore_widgets/_stack_viewer_v1/_save_button.py new file mode 100644 index 000000000..ce686d3bf --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer_v1/_save_button.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import zarr +from fonticon_mdi6 import MDI6 +from qtpy.QtCore import QSize +from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget +from superqt import fonticon + +from ._datastore import QOMEZarrDatastore + +if TYPE_CHECKING: + from qtpy.QtGui import QCloseEvent + + +class SaveButton(QPushButton): + def __init__( + self, + datastore: QOMEZarrDatastore, + parent: QWidget | None = None, + ): + super().__init__(parent=parent) + # self.setFont(QFont('Arial', 50)) + # self.setMinimumHeight(30) + self.setIcon(fonticon.icon(MDI6.content_save_outline, color="gray")) + self.setIconSize(QSize(25, 25)) + self.setFixedSize(30, 30) + self.clicked.connect(self._on_click) + + self.datastore = datastore + self.save_loc = Path.home() + + def _on_click(self) -> None: + self.save_loc, _ = QFileDialog.getSaveFileName(directory=str(self.save_loc)) + if self.save_loc: + self._save_as_zarr(self.save_loc) + + def _save_as_zarr(self, save_loc: str | Path) -> None: + dir_store = zarr.DirectoryStore(save_loc) + zarr.copy_store(self.datastore._group.attrs.store, dir_store) + + def closeEvent(self, a0: QCloseEvent | None) -> None: + super().closeEvent(a0) + + +if __name__ == "__main__": + from pymmcore_plus import CMMCorePlus + from qtpy.QtWidgets import QApplication + from useq import MDASequence + + mmc = CMMCorePlus() + mmc.loadSystemConfiguration() + + app = QApplication([]) + seq = MDASequence( + time_plan={"interval": 0.01, "loops": 10}, + z_plan={"range": 5, "step": 1}, + channels=[{"config": "DAPI", "exposure": 1}, {"config": "FITC", "exposure": 1}], + ) + datastore = QOMEZarrDatastore() + mmc.mda.events.sequenceStarted.connect(datastore.sequenceStarted) + mmc.mda.events.frameReady.connect(datastore.frameReady) + + widget = SaveButton(datastore) + mmc.run_mda(seq) + widget.show() + app.exec_() diff --git a/src/pymmcore_widgets/_stack_viewer_v1/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v1/_stack_viewer.py new file mode 100644 index 000000000..9b9211736 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer_v1/_stack_viewer.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +import copy +import warnings +from typing import TYPE_CHECKING, cast + +import numpy as np +import superqt +from fonticon_mdi6 import MDI6 +from qtpy import QtCore, QtWidgets +from qtpy.QtCore import QTimer +from superqt import fonticon +from useq import MDAEvent, MDASequence, _channel + +from ._channel_row import ChannelRow, try_cast_colormap +from ._datastore import QOMEZarrDatastore +from ._labeled_slider import LabeledVisibilitySlider +from ._save_button import SaveButton + +DIMENSIONS = ["t", "z", "c", "p", "g"] +AUTOCLIM_RATE = 1 # Hz 0 = inf + +try: + from vispy import scene + from vispy.visuals.transforms import MatrixTransform +except ImportError as e: + raise ImportError( + "vispy is required for StackViewer. " + "Please run `pip install pymmcore-widgets[image]`" + ) from e + +if TYPE_CHECKING: + import cmap + from pymmcore_plus import CMMCorePlus + from qtpy.QtCore import QCloseEvent + from qtpy.QtWidgets import QWidget + from vispy.scene.events import SceneMouseEvent + + +class StackViewer(QtWidgets.QWidget): + """A viewer for MDA acquisitions started by MDASequence in pymmcore-plus events. + + Parameters + ---------- + transform: (int, bool, bool) rotation mirror_x mirror_y. + """ + + def __init__( + self, + datastore: QOMEZarrDatastore | None = None, + sequence: MDASequence | None = None, + mmcore: CMMCorePlus | None = None, + parent: QWidget | None = None, + size: tuple[int, int] | None = None, + transform: tuple[int, bool, bool] = (0, True, False), + save_button: bool = True, + ): + super().__init__(parent=parent) + self._reload_position() + self.sequence = sequence + self.canvas_size = size + self.transform = transform + self._mmc = mmcore + self._clim = "auto" + self.cmaps = [ + cm for x in self.cmap_names if (cm := try_cast_colormap(x)) is not None + ] + self.display_index = {dim: 0 for dim in DIMENSIONS} + + self.main_layout = QtWidgets.QVBoxLayout() + self.setLayout(self.main_layout) + self.construct_canvas() + self.main_layout.addWidget(self._canvas.native) + + self.info_bar = QtWidgets.QLabel() + self.info_bar.setSizePolicy( + QtWidgets.QSizePolicy.Policy.Fixed, QtWidgets.QSizePolicy.Policy.Fixed + ) + self.main_layout.addWidget(self.info_bar) + + self._create_sliders(sequence) + + self.datastore = datastore or QOMEZarrDatastore() + self.datastore.frame_ready.connect(self.frameReady) + if not datastore: + if self._mmc: + self._mmc.mda.events.frameReady.connect(self.datastore.frameReady) + self._mmc.mda.events.sequenceFinished.connect( + self.datastore.sequenceFinished + ) + self._mmc.mda.events.sequenceStarted.connect( + self.datastore.sequenceStarted + ) + else: + warnings.warn( + "No datastore or mmcore provided, connect manually.", stacklevel=2 + ) + + if self._mmc: + # Otherwise connect via listeners_connected or manually + self._mmc.mda.events.sequenceStarted.connect(self.sequenceStarted) + + self.images: dict[tuple, scene.visuals.Image] = {} + self.frame = 0 + self.ready = False + self.current_channel = 0 + self.pixel_size = 1.0 + self.missed_events: list[MDAEvent] = [] + + self.destroyed.connect(self._disconnect) + + self.collapse_btn = QtWidgets.QPushButton() + self.collapse_btn.setIcon(fonticon.icon(MDI6.arrow_collapse_all)) + self.collapse_btn.clicked.connect(self._collapse_view) + + self.bottom_buttons = QtWidgets.QHBoxLayout() + self.bottom_buttons.addWidget(self.collapse_btn) + if save_button: + self.save_btn = SaveButton(self.datastore) + self.bottom_buttons.addWidget(self.save_btn) + self.main_layout.addLayout(self.bottom_buttons) + + if sequence: + self.sequenceStarted(sequence) + + def construct_canvas(self) -> None: + if self.canvas_size: + self.img_size = self.canvas_size + elif ( + self._mmc + and (h := self._mmc.getImageHeight()) + and (w := self._mmc.getImageWidth()) + ): + self.img_size = (h, w) + else: + self.img_size = (512, 512) + if any(x < 1 for x in self.img_size): + raise ValueError("Image size must be greater than 0.") + self._canvas = scene.SceneCanvas( + size=self.img_size, parent=self, keys="interactive" + ) + self._canvas._send_hover_events = True + self._canvas.events.mouse_move.connect(self.on_mouse_move) + self.view = self._canvas.central_widget.add_view() + self.view.camera = scene.PanZoomCamera(aspect=1) + self.view.camera.flip = (self.transform[1], self.transform[2], False) + self.view.camera.set_range( + (0, self.img_size[0]), (0, self.img_size[1]), margin=0 + ) + rect = self.view.camera.rect + self.view_rect = (rect.pos, rect.size) + self.view.camera.aspect = 1 + + def _create_sliders(self, sequence: MDASequence | None = None) -> None: + self.channel_row: ChannelRow = ChannelRow(parent=self) + self.channel_row.visible.connect(self._handle_channel_visibility) + self.channel_row.autoscale.connect(self._handle_channel_autoscale) + self.channel_row.new_clims.connect(self._handle_channel_clim) + self.channel_row.new_cmap.connect(self._handle_channel_cmap) + self.channel_row.selected.connect(self._handle_channel_choice) + self.layout().addWidget(self.channel_row) + + self.slider_layout = QtWidgets.QVBoxLayout() + self.layout().addLayout(self.slider_layout) + self.sliders: dict[str, LabeledVisibilitySlider] = {} + + @superqt.ensure_main_thread # type: ignore + def add_slider(self, dim: str) -> None: + slider = LabeledVisibilitySlider( + dim, orientation=QtCore.Qt.Orientation.Horizontal + ) + slider.sliderMoved.connect(self.on_display_timer) + slider.setRange(0, 1) + self.slider_layout.addWidget(slider) + self.sliders[dim] = slider + + @superqt.ensure_main_thread # type: ignore + def add_image(self, event: MDAEvent) -> None: + image = scene.visuals.Image( + np.zeros(self._canvas.size).astype(np.uint16), + parent=self.view.scene, + cmap=self.cmaps[event.index.get("c", 0)].to_vispy(), + clim=(0, 1), + ) + trans = MatrixTransform() + trans.rotate(self.transform[0], (0, 0, 1)) + image.transform = self._get_image_position(trans, event) + image.interactive = True + if event.index.get("c", 0) > 0: + image.set_gl_state("additive", depth_test=False) + else: + image.set_gl_state(depth_test=False) + c = event.index.get("c", 0) + g = event.index.get("g", 0) + self.images[(("c", c), ("g", g))] = image + + def sequenceStarted(self, sequence: MDASequence) -> None: + """Sequence started by the mmcore. Adjust our settings, make layers etc.""" + self.ready = False + self.sequence = sequence + self.pixel_size = self._mmc.getPixelSizeUm() if self._mmc else self.pixel_size + + self.ng = max(sequence.sizes.get("g", 1), 1) + self.current_channel = 0 + + self._collapse_view() + self.ready = True + + def frameReady(self, event: MDAEvent) -> None: + """Frame received from acquisition, display the image, update sliders etc.""" + if not self.ready: + self._redisplay(event) + return + indices = dict(event.index) + img = self.datastore.get_frame(event) + # Update display + try: + display_indices = self._set_sliders(indices) + except KeyError as e: + self.add_slider(e.args[0]) + self._redisplay(event) + return + if display_indices == indices: + # Get controls + try: + clim_slider = self.channel_row.boxes[indices.get("c", 0)].slider + except KeyError: + this_channel = cast(_channel.Channel, event.channel) + self.channel_row.add_channel(this_channel, indices.get("c", 0)) + self._redisplay(event) + return + try: + self.display_image(img, indices.get("c", 0), indices.get("g", 0)) + except KeyError: + self.add_image(event) + self._redisplay(event) + return + + # Handle autoscaling + clim_slider.setRange( + min(clim_slider.minimum(), img.min()), + max(clim_slider.maximum(), img.max()), + ) + if self.channel_row.boxes[indices.get("c", 0)].autoscale_chbx.isChecked(): + clim_slider.setValue( + [ + min(clim_slider.minimum(), img.min()), + max(clim_slider.maximum(), img.max()), + ] + ) + try: + self.on_clim_timer(indices.get("c", 0)) + except KeyError: + return + if sum([event.index.get("t", 0), event.index.get("z", 0)]) == 0: + self._collapse_view() + + def _handle_channel_clim( + self, values: tuple[int, int], channel: int, set_autoscale: bool = True + ) -> None: + for g in range(self.ng): + self.images[(("c", channel), ("g", g))].clim = values + if self.channel_row.boxes[channel].autoscale_chbx.isChecked() and set_autoscale: + self.channel_row.boxes[channel].autoscale_chbx.setCheckState( + QtCore.Qt.CheckState.Unchecked + ) + self._canvas.update() + + def _handle_channel_cmap(self, colormap: cmap.Colormap, channel: int) -> None: + for g in range(self.ng): + try: + self.images[(("c", channel), ("g", g))].cmap = colormap.to_vispy() + except KeyError: + return + if colormap.name not in self.cmap_names: + self.cmap_names.append(self.cmap_names[channel]) + self.cmap_names[channel] = colormap.name + self._canvas.update() + + def _handle_channel_visibility(self, state: bool, channel: int) -> None: + for g in range(self.ng): + checked = self.channel_row.boxes[channel].show_channel.isChecked() + self.images[(("c", channel), ("g", g))].visible = checked + if self.current_channel == channel: + channel_to_set = channel - 1 if channel > 0 else channel + 1 + channel_to_set = 0 if len(self.channel_row.boxes) == 1 else channel_to_set + self.channel_row._handle_channel_choice( + self.channel_row.boxes[channel_to_set].channel + ) + self._canvas.update() + + def _handle_channel_autoscale(self, state: bool, channel: int) -> None: + slider = self.channel_row.boxes[channel].slider + if state == 0: + self._handle_channel_clim(slider.value(), channel, set_autoscale=False) + else: + clim = ( + slider.minimum(), + slider.maximum(), + ) + self._handle_channel_clim(clim, channel, set_autoscale=False) + + def _handle_channel_choice(self, channel: int) -> None: + self.current_channel = channel + + def on_mouse_move(self, event: SceneMouseEvent) -> None: + """Mouse moved on the canvas, display the pixel value and position.""" + # https://groups.google.com/g/vispy/c/sUNKoDL1Gc0/m/E5AG7lgPFQAJ + self.view.interactive = False + images = [] + all_images = [] + # Get the images the mouse is over + while image := self._canvas.visual_at(event.pos): + if image in self.images.values(): + images.append(image) + image.interactive = False + all_images.append(image) + for image in all_images: + image.interactive = True + + self.view.interactive = True + if images == []: + transform = self.view.get_transform("canvas", "visual") + p = [int(x) for x in transform.map(event.pos)] + info = f"[{p[0]}, {p[1]}]" + self.info_bar.setText(info) + return + # Adjust channel index is channel(s) are not visible + real_channel = self.current_channel + for i in range(self.current_channel): + i_visible = self.channel_row.boxes[i].show_channel.isChecked() + real_channel = real_channel - 1 if not i_visible else real_channel + images.reverse() + transform = images[real_channel].get_transform("canvas", "visual") + p = [int(x) for x in transform.map(event.pos)] + try: + pos = f"[{p[0]}, {p[1]}]" + value = f"{images[real_channel]._data[p[1], p[0]]}" + info = f"{pos}: {value}" + self.info_bar.setText(info) + except IndexError: + info = f"[{p[0]}, {p[1]}]" + self.info_bar.setText(info) + + def on_display_timer(self) -> None: + """Update display, usually triggered by QTimer started by slider click.""" + old_index = self.display_index.copy() + for slider in self.sliders.values(): + self.display_index[slider.name] = slider.value() + if old_index == self.display_index: + return + if (sequence := self.sequence) is None: + return + for g in range(self.ng): + for c in range(sequence.sizes.get("c", 1)): + frame = self.datastore.get_frame( + MDAEvent( + index={ + "t": self.display_index["t"], + "z": self.display_index["z"], + "c": c, + "g": g, + "p": 0, + } + ) + ) + self.display_image(frame, c, g) + self._canvas.update() + + def _set_sliders(self, indices: dict) -> dict: + """New indices from outside the sliders, update.""" + display_indices = copy.deepcopy(indices) + for index in display_indices: + if index not in ["t", "z"] or display_indices.get(index, 0) == 0: + continue + if self.sliders[index].lock_btn.isChecked(): + display_indices[index] = self.sliders[index].value() + continue + # This blocking doesn't seem to work + # blocked = slider.blockSignals(True) + self.sliders[index].setValue(display_indices.get(index, 0)) + if display_indices.get(index, 0) > self.sliders[index].maximum(): + self.sliders[index].setMaximum(display_indices.get(index, 0)) + # slider.setValue(indices[slider.name]) + # slider.blockSignals(blocked) + return display_indices + + def display_image(self, img: np.ndarray, channel: int = 0, grid: int = 0) -> None: + self.images[(("c", channel), ("g", grid))].set_data(img) + # Should we do this? Might it slow down acquisition while in the same thread? + self._canvas.update() + + def on_clim_timer(self, channel: int | None = None) -> None: + channel_list = ( + list(range(len(self.channel_row.boxes))) if channel is None else [channel] + ) + for grid in range(self.ng): + for channel in channel_list: + if ( + self.channel_row.boxes[channel].autoscale_chbx.isChecked() + and (img := self.images[(("c", channel), ("g", grid))]).visible + ): + # TODO: percentile here, could be in gui + img.clim = np.percentile(img._data, [0, 100]) + self._canvas.update() + + def _get_image_position( + self, + trans: MatrixTransform, + event: MDAEvent, + ) -> MatrixTransform: + translate = [round(x) for x in ((1, 1) - trans.matrix[:2, :2].dot((1, 1))) / 2] + + x_pos = event.x_pos or 0 + y_pos = event.y_pos or 0 + w, h = self.img_size + if x_pos == 0 and y_pos == 0: + trans.translate(((1 - translate[1]) * w, translate[0] * h, 0)) + self.view_rect = ((0 - w / 2, 0 - h / 2), (w, h)) + else: + trans.translate(((translate[1] - 1) * w, translate[0] * h, 0)) + trans.translate((x_pos / self.pixel_size, y_pos / self.pixel_size, 0)) + self._expand_canvas_view(event) + return trans + + def _expand_canvas_view(self, event: MDAEvent) -> None: + """Expand the canvas view to include the new image.""" + x_pos = event.x_pos or 0 + y_pos = event.y_pos or 0 + img_position = ( + x_pos / self.pixel_size - self.img_size[0] / 2, + x_pos / self.pixel_size + self.img_size[0] / 2, + y_pos / self.pixel_size - self.img_size[1] / 2, + y_pos / self.pixel_size + self.img_size[1] / 2, + ) + camera_rect = [ + self.view_rect[0][0], + self.view_rect[0][0] + self.view_rect[1][0], + self.view_rect[0][1], + self.view_rect[0][1] + self.view_rect[1][1], + ] + if camera_rect[0] > img_position[0]: + camera_rect[0] = img_position[0] + if camera_rect[1] < img_position[1]: + camera_rect[1] = img_position[1] + if camera_rect[2] > img_position[2]: + camera_rect[2] = img_position[2] + if camera_rect[3] < img_position[3]: + camera_rect[3] = img_position[3] + self.view_rect = ( + (camera_rect[0], camera_rect[2]), + (camera_rect[1] - camera_rect[0], camera_rect[3] - camera_rect[2]), + ) + + def _disconnect(self) -> None: + if self._mmc: + self._mmc.mda.events.sequenceStarted.disconnect(self.sequenceStarted) + self.datastore.frame_ready.disconnect(self.frameReady) + + def _reload_position(self) -> None: + self.qt_settings = QtCore.QSettings("pymmcore_plus", self.__class__.__name__) + self.resize(self.qt_settings.value("size", QtCore.QSize(270, 225))) + self.move(self.qt_settings.value("pos", QtCore.QPoint(50, 50))) + self.cmap_names = self.qt_settings.value("cmaps", ["gray", "cyan", "magenta"]) + + def _collapse_view(self) -> None: + w, h = self.img_size + view_rect = ( + (self.view_rect[0][0] - w / 2, self.view_rect[0][1] + h / 2), + self.view_rect[1], + ) + self.view.camera.rect = view_rect + + def _reemit_missed_events(self) -> None: + while self.missed_events: + self.frameReady(self.missed_events.pop(0)) + + @superqt.ensure_main_thread # type: ignore + def _redisplay(self, event: MDAEvent) -> None: + self.missed_events.append(event) + QTimer.singleShot(0, self._reemit_missed_events) + + def closeEvent(self, e: QCloseEvent) -> None: + """Write window size and position to config file.""" + self.qt_settings.setValue("size", self.size()) + self.qt_settings.setValue("pos", self.pos()) + self.qt_settings.setValue("cmaps", self.cmap_names) + self._canvas.close() + super().closeEvent(e) diff --git a/src/pymmcore_widgets/experimental.py b/src/pymmcore_widgets/experimental.py index 3a704297b..c17acb003 100644 --- a/src/pymmcore_widgets/experimental.py +++ b/src/pymmcore_widgets/experimental.py @@ -1,3 +1,4 @@ -from ._stack_viewer import StackViewer +from ._stack_viewer import MDAViewer +from ._stack_viewer_v1 import StackViewer -__all__ = ["StackViewer"] +__all__ = ["StackViewer", "MDAViewer"] diff --git a/tests/test_datastore.py b/tests/test_datastore.py deleted file mode 100644 index a31df07b7..000000000 --- a/tests/test_datastore.py +++ /dev/null @@ -1,26 +0,0 @@ -from pymmcore_plus import CMMCorePlus -from useq import MDAEvent, MDASequence - -from pymmcore_widgets._stack_viewer._datastore import QOMEZarrDatastore - -sequence = MDASequence( - channels=[{"config": "DAPI", "exposure": 10}], - time_plan={"interval": 0.3, "loops": 3}, -) - - -def test_reception(qtbot): - mmcore = CMMCorePlus.instance() - - datastore = QOMEZarrDatastore() - mmcore.mda.events.frameReady.connect(datastore.frameReady) - mmcore.mda.events.sequenceFinished.connect(datastore.sequenceFinished) - mmcore.mda.events.sequenceStarted.connect(datastore.sequenceStarted) - - with qtbot.waitSignal(datastore.frame_ready, timeout=5000): - mmcore.run_mda(sequence, block=False) - with qtbot.waitSignal(datastore.frame_ready, timeout=5000): - pass - - assert datastore.get_frame(MDAEvent(index={"c": 0, "t": 0})).flatten()[0] != 0 - qtbot.wait(1000) diff --git a/tests/test_stack_viewer.py b/tests/test_stack_viewer1.py similarity index 86% rename from tests/test_stack_viewer.py rename to tests/test_stack_viewer1.py index ce4c688ab..bb64fddb1 100644 --- a/tests/test_stack_viewer.py +++ b/tests/test_stack_viewer1.py @@ -10,7 +10,8 @@ from vispy.app.canvas import MouseEvent from vispy.scene.events import SceneMouseEvent -from pymmcore_widgets._stack_viewer import CMAPS +from pymmcore_widgets._stack_viewer_v1 import CMAPS +from pymmcore_widgets._stack_viewer_v1._datastore import QOMEZarrDatastore from pymmcore_widgets.experimental import StackViewer if TYPE_CHECKING: @@ -177,3 +178,25 @@ def test_not_ready(qtbot: QtBot) -> None: # TODO: we should do something here that checks if the loop finishes canvas.frameReady(MDAEvent()) mmcore.mda.run(sequence) + + +def test_reception(qtbot: QtBot) -> None: + sequence = MDASequence( + channels=[{"config": "DAPI", "exposure": 10}], + time_plan={"interval": 0.3, "loops": 3}, + ) + + mmcore = CMMCorePlus.instance() + + datastore = QOMEZarrDatastore() + mmcore.mda.events.frameReady.connect(datastore.frameReady) + mmcore.mda.events.sequenceFinished.connect(datastore.sequenceFinished) + mmcore.mda.events.sequenceStarted.connect(datastore.sequenceStarted) + + with qtbot.waitSignal(datastore.frame_ready, timeout=5000): + mmcore.run_mda(sequence, block=False) + with qtbot.waitSignal(datastore.frame_ready, timeout=5000): + pass + + assert datastore.get_frame(MDAEvent(index={"c": 0, "t": 0})).flatten()[0] != 0 + qtbot.wait(1000) From 91ecd0c8228241debcaa24ca151fee964649c9fd Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 4 May 2024 16:39:13 -0400 Subject: [PATCH 21/73] more renames --- examples/{stack_viewer2.py => mda_viewer.py} | 9 ++------- examples/{stack_viewer.py => stack_viewer_v1.py} | 2 +- examples/stack_viewer_xr.py | 13 ------------- src/pymmcore_widgets/_stack_viewer/_stack_viewer.py | 3 ++- 4 files changed, 5 insertions(+), 22 deletions(-) rename examples/{stack_viewer2.py => mda_viewer.py} (75%) rename examples/{stack_viewer.py => stack_viewer_v1.py} (93%) delete mode 100644 examples/stack_viewer_xr.py diff --git a/examples/stack_viewer2.py b/examples/mda_viewer.py similarity index 75% rename from examples/stack_viewer2.py rename to examples/mda_viewer.py index 27d9020eb..8ff12df56 100644 --- a/examples/stack_viewer2.py +++ b/examples/mda_viewer.py @@ -14,11 +14,7 @@ mmcore.defineConfig("Channel", "FITC", "Camera", "Mode", "Noise") sequence = MDASequence( - channels=( - {"config": "DAPI", "exposure": 5}, - {"config": "FITC", "exposure": 20}, - # {"config": "Cy5", "exposure": 20}, - ), + channels=({"config": "DAPI", "exposure": 5}, {"config": "FITC", "exposure": 20}), stage_positions=[(0, 0), (1, 1)], z_plan={"range": 9, "step": 0.4}, time_plan={"interval": 0.2, "loops": 4}, @@ -28,8 +24,7 @@ qapp = QtWidgets.QApplication([]) v = MDAViewer() -v.dims_sliders.set_locks_visible(False) v.show() -mmcore.run_mda(sequence, output=v._data) +mmcore.run_mda(sequence, output=v.data) qapp.exec() diff --git a/examples/stack_viewer.py b/examples/stack_viewer_v1.py similarity index 93% rename from examples/stack_viewer.py rename to examples/stack_viewer_v1.py index 03dc1a78a..6b85c2da8 100644 --- a/examples/stack_viewer.py +++ b/examples/stack_viewer_v1.py @@ -6,7 +6,7 @@ from qtpy import QtWidgets from useq import MDASequence -from pymmcore_widgets.experimental import StackViewer +from pymmcore_widgets._stack_viewer_v1 import StackViewer size = 1028 diff --git a/examples/stack_viewer_xr.py b/examples/stack_viewer_xr.py deleted file mode 100644 index 899a2d6c7..000000000 --- a/examples/stack_viewer_xr.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -# from stack_viewer_numpy import generate_5d_sine_wave -import nd2 -from qtpy import QtWidgets - -from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer - -data = nd2.imread("/Users/talley/Downloads/6D_test.nd2", xarray=True, dask=True) -qapp = QtWidgets.QApplication([]) -v = StackViewer(data, channel_axis="C") -v.show() -qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py index 9410fde70..1172fe409 100644 --- a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py @@ -173,7 +173,8 @@ def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: self.update_slider_maxima() self.setIndex({}) - info = f"{getattr(type(data), '__qualname__', '')}" + package = getattr(data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(data), '__qualname__', '')}" if self._sizes: if all(isinstance(x, int) for x in self._sizes): From 81758ecdd60eb674082c5472df575a280d6887d1 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 6 May 2024 17:08:03 -0400 Subject: [PATCH 22/73] remove bar color --- .../_stack_viewer/_dims_slider.py | 18 ++++++++---------- .../_stack_viewer/_lut_control.py | 4 +--- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer/_dims_slider.py index c195fd0d4..6b0ddd722 100644 --- a/src/pymmcore_widgets/_stack_viewer/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer/_dims_slider.py @@ -36,7 +36,6 @@ # mapping of dimension keys to the maximum value for that dimension Sizes: TypeAlias = Mapping[DimKey, int] -BAR_COLOR = "#2258575B" SS = """ QSlider::groove:horizontal { @@ -51,17 +50,17 @@ QSlider::handle:horizontal { width: 38px; - background: qlineargradient( - x1:0, y1:0, x2:0, y2:1, - stop:0 rgba(148, 148, 148, 1), - stop:1 rgba(148, 148, 148, 1) - ); + background: #999999; border-radius: 3px; } -QLabel { - font-size: 12px; -} +QLabel { font-size: 12px; } + +QRangeSlider { qproperty-barColor: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(100, 80, 120, 0.2), + stop:1 rgba(100, 80, 120, 0.4) + )} SliderLabel { font-size: 12px; @@ -178,7 +177,6 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None # self._int_slider.layout().addWidget(self._max_label) self._slice_slider = slc = QLabeledRangeSlider(Qt.Orientation.Horizontal) - slc._slider.barColor = BAR_COLOR slc.setHandleLabelPosition(QLabeledRangeSlider.LabelPosition.LabelsOnHandle) slc.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) slc.setVisible(False) diff --git a/src/pymmcore_widgets/_stack_viewer/_lut_control.py b/src/pymmcore_widgets/_stack_viewer/_lut_control.py index fc1a78535..2a1aa27c1 100644 --- a/src/pymmcore_widgets/_stack_viewer/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer/_lut_control.py @@ -9,7 +9,7 @@ from superqt.cmap import QColormapComboBox from superqt.utils import signals_blocked -from ._dims_slider import BAR_COLOR, SS +from ._dims_slider import SS if TYPE_CHECKING: import cmap @@ -53,8 +53,6 @@ def __init__( self._cmap.addColormap(color) self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) - if hasattr(self._clims, "_slider"): - self._clims._slider.barColor = BAR_COLOR self._clims.setStyleSheet(SS) self._clims.setHandleLabelPosition( QLabeledRangeSlider.LabelPosition.LabelsOnHandle From b34b4dfadddcab491be1caba97d8d2ea5e950d1d Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 6 May 2024 17:47:25 -0400 Subject: [PATCH 23/73] bump superqt and add popup fps --- pyproject.toml | 2 +- .../_stack_viewer/_dims_slider.py | 89 +++++++++++-------- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ce5d6e07..796e15755 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ 'fonticon-materialdesignicons6', 'pymmcore-plus[cli] >=0.9.5', 'qtpy >=2.0', - 'superqt[quantity] >=0.5.3', + 'superqt[quantity] >=0.6.5', 'useq-schema >=0.4.7', ] diff --git a/src/pymmcore_widgets/_stack_viewer/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer/_dims_slider.py index 6b0ddd722..d3817fa14 100644 --- a/src/pymmcore_widgets/_stack_viewer/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer/_dims_slider.py @@ -3,10 +3,13 @@ from typing import TYPE_CHECKING, Any, cast from warnings import warn -from qtpy.QtCore import QPointF, QSize, Qt, Signal -from qtpy.QtGui import QResizeEvent +from qtpy.QtCore import QPoint, QPointF, QSize, Qt, Signal +from qtpy.QtGui import QCursor, QResizeEvent from qtpy.QtWidgets import ( QDialog, + QDoubleSpinBox, + QFormLayout, + QFrame, QHBoxLayout, QLabel, QPushButton, @@ -24,7 +27,6 @@ from typing import Hashable, Mapping, TypeAlias from PyQt6.QtGui import QResizeEvent - from qtpy.QtGui import QKeyEvent # any hashable represent a single dimension in a AND array DimKey: TypeAlias = Hashable @@ -69,56 +71,63 @@ """ -class _DissmissableDialog(QDialog): +class QtPopup(QDialog): + """A generic popup window.""" + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self.setWindowFlags( - self.windowFlags() | Qt.WindowType.FramelessWindowHint | Qt.WindowType.Popup - ) + self.setModal(False) # if False, then clicking anywhere else closes it + self.setWindowFlags(Qt.WindowType.Popup | Qt.WindowType.FramelessWindowHint) + + self.frame = QFrame() + layout = QVBoxLayout(self) + layout.addWidget(self.frame) + layout.setContentsMargins(0, 0, 0, 0) - def keyPressEvent(self, e: QKeyEvent | None) -> None: - if e and e.key() in (Qt.Key.Key_Enter, Qt.Key.Key_Return, Qt.Key.Key_Escape): - self.accept() - print("accept") + def show_above_mouse(self, *args: Any) -> None: + """Show popup dialog above the mouse cursor position.""" + pos = QCursor().pos() # mouse position + szhint = self.sizeHint() + pos -= QPoint(szhint.width() // 2, szhint.height() + 14) + self.move(pos) + self.resize(self.sizeHint()) + self.show() class PlayButton(QPushButton): """Just a styled QPushButton that toggles between play and pause icons.""" - fpsChanged = Signal(int) + fpsChanged = Signal(float) PLAY_ICON = "bi:play-fill" PAUSE_ICON = "bi:pause-fill" - def __init__(self, text: str = "", parent: QWidget | None = None) -> None: + def __init__(self, fps: float = 30, parent: QWidget | None = None) -> None: icn = QIconifyIcon(self.PLAY_ICON) icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On) - super().__init__(icn, text, parent) + super().__init__(icn, "", parent) + self.spin = QDoubleSpinBox() + self.spin.setRange(0.5, 100) + self.spin.setValue(fps) + self.spin.valueChanged.connect(self.fpsChanged) self.setCheckable(True) self.setFixedSize(14, 18) self.setIconSize(QSize(16, 16)) self.setStyleSheet("border: none; padding: 0; margin: 0;") - # def mousePressEvent(self, e: QMouseEvent | None) -> None: - # if e and e.button() == Qt.MouseButton.RightButton: - # self._show_fps_dialog(e.globalPosition()) - # else: - # super().mousePressEvent(e) + def mousePressEvent(self, e: Any) -> None: + if e and e.button() == Qt.MouseButton.RightButton: + self._show_fps_dialog(e.globalPosition()) + else: + super().mousePressEvent(e) def _show_fps_dialog(self, pos: QPointF) -> None: - dialog = _DissmissableDialog() - - sb = QSpinBox() - sb.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) - sb.valueChanged.connect(self.fpsChanged) - - layout = QHBoxLayout(dialog) - layout.setContentsMargins(4, 0, 4, 0) - layout.addWidget(QLabel("FPS")) - layout.addWidget(sb) - - dialog.setGeometry(int(pos.x()) - 20, int(pos.y()) - 50, 40, 40) - dialog.exec() + if not hasattr(self, "popup"): + self.popup = QtPopup(self) + form = QFormLayout(self.popup.frame) + form.setContentsMargins(6, 6, 6, 6) + form.addRow("FPS", self.spin) + self.popup.show_above_mouse() class LockButton(QPushButton): @@ -149,10 +158,10 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None super().__init__(parent) self.setStyleSheet(SS) self._slice_mode = False - self._animation_fps = 30 self._dim_key = dimension_key - self._play_btn = PlayButton() + self._timer_id: int | None = None # timer for play button + self._play_btn = PlayButton(fps=30) self._play_btn.fpsChanged.connect(self.set_fps) self._play_btn.toggled.connect(self._toggle_animation) @@ -251,14 +260,18 @@ def _set_slice_mode(self, mode: bool = True) -> None: self._slice_slider.setVisible(False) self.valueChanged.emit(self._dim_key, self.value()) - def set_fps(self, fps: int) -> None: - self._animation_fps = fps + def set_fps(self, fps: float) -> None: + self._play_btn.spin.setValue(fps) + self._toggle_animation(self._play_btn.isChecked()) def _toggle_animation(self, checked: bool) -> None: if checked: - self._timer_id = self.startTimer(1000 // self._animation_fps) - else: + if self._timer_id is not None: + self.killTimer(self._timer_id) + self._timer_id = self.startTimer(int(1000 / self._play_btn.spin.value())) + elif self._timer_id is not None: self.killTimer(self._timer_id) + self._timer_id = None def timerEvent(self, event: Any) -> None: if self._slice_mode: From d2ff45b688ab019180b201937dc7cdd23ff49929 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 6 May 2024 19:43:28 -0400 Subject: [PATCH 24/73] use async, add colors, start transform --- .../_stack_viewer/_backends/_vispy.py | 24 +++++-- .../_stack_viewer/_dims_slider.py | 8 +-- .../_stack_viewer/_indexing.py | 12 ++++ .../_stack_viewer/_stack_viewer.py | 66 +++++++++++-------- 4 files changed, 73 insertions(+), 37 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py b/src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py index d7eee7533..d17e49eab 100644 --- a/src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, cast import numpy as np +from superqt.utils import qthrottled from vispy import scene if TYPE_CHECKING: @@ -50,6 +51,14 @@ def cmap(self, cmap: cmap.Colormap) -> None: self._cmap = cmap self._image.cmap = cmap.to_vispy() + @property + def transform(self) -> np.ndarray: + raise NotImplementedError + + @transform.setter + def transform(self, transform: np.ndarray) -> None: + raise NotImplementedError + def remove(self) -> None: self._image.parent = None @@ -64,7 +73,7 @@ class VispyViewerCanvas: def __init__(self, set_info: Callable[[str], None]) -> None: self._set_info = set_info self._canvas = scene.SceneCanvas() - self._canvas.events.mouse_move.connect(self._on_mouse_move) + self._canvas.events.mouse_move.connect(qthrottled(self._on_mouse_move, 60)) self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) self._has_set_range = False @@ -110,11 +119,14 @@ def _on_mouse_move(self, event: SceneMouseEvent) -> None: # Get the images the mouse is over # FIXME: must be a better way to do this seen = set() - while visual := self._canvas.visual_at(event.pos): - if isinstance(visual, scene.visuals.Image): - images.append(visual) - visual.interactive = False - seen.add(visual) + try: + while visual := self._canvas.visual_at(event.pos): + if isinstance(visual, scene.visuals.Image): + images.append(visual) + visual.interactive = False + seen.add(visual) + except Exception: + return for visual in seen: visual.interactive = True if not images: diff --git a/src/pymmcore_widgets/_stack_viewer/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer/_dims_slider.py index d3817fa14..49d445b9b 100644 --- a/src/pymmcore_widgets/_stack_viewer/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer/_dims_slider.py @@ -103,8 +103,8 @@ class PlayButton(QPushButton): PAUSE_ICON = "bi:pause-fill" def __init__(self, fps: float = 30, parent: QWidget | None = None) -> None: - icn = QIconifyIcon(self.PLAY_ICON) - icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On) + icn = QIconifyIcon(self.PLAY_ICON, color="#888888") + icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") super().__init__(icn, "", parent) self.spin = QDoubleSpinBox() self.spin.setRange(0.5, 100) @@ -135,8 +135,8 @@ class LockButton(QPushButton): UNLOCK_ICON = "uis:lock" def __init__(self, text: str = "", parent: QWidget | None = None) -> None: - icn = QIconifyIcon(self.LOCK_ICON) - icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On) + icn = QIconifyIcon(self.LOCK_ICON, color="#888888") + icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On, color="red") super().__init__(icn, text, parent) self.setCheckable(True) self.setFixedSize(20, 20) diff --git a/src/pymmcore_widgets/_stack_viewer/_indexing.py b/src/pymmcore_widgets/_stack_viewer/_indexing.py index f3df119ea..f28f830d5 100644 --- a/src/pymmcore_widgets/_stack_viewer/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer/_indexing.py @@ -7,6 +7,7 @@ import numpy as np if TYPE_CHECKING: + from concurrent.futures import Future from typing import Any, Protocol, TypeGuard import dask.array as da @@ -72,6 +73,17 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: raise NotImplementedError(f"Don't know how to index into type {type(store)}") +def isel_async(store: Any, indexers: Indices) -> Future[np.ndarray]: + """Asynchronous version of isel.""" + from concurrent.futures import Future + from threading import Thread + + fut: Future[np.ndarray] = Future() + thread = Thread(target=lambda: fut.set_result(isel(store, indexers))) + thread.start() + return fut + + def isel_np_array(data: SupportsIndexing, indexers: Indices) -> np.ndarray: idx = tuple(indexers.get(k, slice(None)) for k in range(len(data.shape))) return np.asarray(data[idx]) diff --git a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py index 1172fe409..1b5b64017 100644 --- a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py @@ -13,10 +13,11 @@ from ._backends import get_canvas from ._dims_slider import DimsSliders -from ._indexing import is_xarray_dataarray, isel +from ._indexing import is_xarray_dataarray, isel_async from ._lut_control import LutControl if TYPE_CHECKING: + from concurrent.futures import Future from typing import Any, Callable, Hashable, TypeAlias from ._dims_slider import DimKey, Indices, Sizes @@ -26,9 +27,10 @@ # any mapping of dimensions to sizes SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] - +MID_GRAY = "#888888" GRAYS = cmap.Colormap("gray") COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] +MAX_CHANNELS = 16 class ChannelMode(str, Enum): @@ -107,7 +109,7 @@ def __init__( # place to display dataset summary self._data_info = QElidingLabel("") # place to display arbitrary text - self._hover_info = QLabel("Info") + self._hover_info = QLabel("") # the canvas that displays the images self._canvas: PCanvas = get_canvas()(self._hover_info.setText) # the sliders that control the index of the displayed image @@ -115,8 +117,8 @@ def __init__( self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) self._lut_drop = QCollapsible("LUTs") - self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down")) - self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up")) + self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) + self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up", color=MID_GRAY)) lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) lut_layout.setContentsMargins(0, 1, 0, 1) lut_layout.setSpacing(0) @@ -148,8 +150,8 @@ def __init__( # SETUP ------------------------------------------------------ - self.set_data(data) self.set_channel_mode(channel_mode) + self.set_data(data) # ------------------- PUBLIC API ---------------------------- @@ -158,7 +160,19 @@ def data(self) -> Any: """Return the data backing the view.""" return self._data - def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: + @property + def dims_sliders(self) -> DimsSliders: + """Return the DimsSliders widget.""" + return self._dims_sliders + + @property + def sizes(self) -> Sizes: + """Return sizes {dimkey: int} of the dimensions in the datastore.""" + return self._sizes + + def set_data( + self, data: Any, sizes: SizesLike | None = None, channel_axis: int | None = None + ) -> None: """Set the datastore, and, optionally, the sizes of the data.""" if sizes is None: if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): @@ -167,7 +181,9 @@ def set_data(self, data: Any, sizes: SizesLike | None = None) -> None: sizes = shp self._sizes = _to_sizes(sizes) self._data = data - if self._channel_axis is None: + if channel_axis is not None: + self._channel_axis = channel_axis + elif self._channel_axis is None: self._channel_axis = self._guess_channel_axis(data) self.set_visualized_dims(list(self._sizes)[-2:]) self.update_slider_maxima() @@ -200,16 +216,6 @@ def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: for d in self._visualized_dims: self._dims_sliders.set_dimension_visible(d, False) - @property - def dims_sliders(self) -> DimsSliders: - """Return the DimsSliders widget.""" - return self._dims_sliders - - @property - def sizes(self) -> Sizes: - """Return sizes {dimkey: int} of the dimensions in the datastore.""" - return self._sizes - def update_slider_maxima(self, sizes: SizesLike | None = None) -> None: """Set the maximum values of the sliders. @@ -235,6 +241,7 @@ def set_channel_mode(self, mode: ChannelMode | None = None) -> None: if mode is None or isinstance(mode, bool): mode = self._channel_mode_btn.mode() else: + mode = ChannelMode(mode) self._channel_mode_btn.setMode(mode) if mode == getattr(self, "_channel_mode", None): return @@ -244,15 +251,19 @@ def set_channel_mode(self, mode: ChannelMode | None = None) -> None: self._cmaps = cycle(COLORMAPS) # set the visibility of the channel slider c_visible = mode != ChannelMode.COMPOSITE - self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) + if self._channel_axis is not None: + self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) if not self._img_handles: return + self._clear_images() + if self._channel_axis is None: + return + # determine what needs to be updated n_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 value = self._dims_sliders.value() # get before clearing - self._clear_images() indices = ( [value] if c_visible @@ -272,14 +283,15 @@ def setIndex(self, index: Indices) -> None: def _guess_channel_axis(self, data: Any) -> DimKey: """Guess the channel axis from the data.""" - if isinstance(data, np.ndarray): - # for numpy arrays, use the smallest dimension as the channel axis - return data.shape.index(min(data.shape)) if is_xarray_dataarray(data): for d in data.dims: if str(d).lower() in ("channel", "ch", "c"): return cast("DimKey", d) - return 0 + if isinstance(shp := getattr(data, "shape", None), Sequence): + # for numpy arrays, use the smallest dimension as the channel axis + if min(shp) <= MAX_CHANNELS: + return shp.index(min(shp)) + return None def _clear_images(self) -> None: """Remove all images from the canvas.""" @@ -306,11 +318,11 @@ def _image_key(self, index: Indices) -> ImgKey: return val return 0 - def _isel(self, index: Indices) -> np.ndarray: + def _isel(self, index: Indices) -> Future[np.ndarray]: """Select data from the datastore using the given index.""" idx = {k: v for k, v in index.items() if k not in self._visualized_dims} try: - return isel(self._data, idx) + return isel_async(self._data, idx) except Exception as e: raise type(e)(f"Failed to index data with {idx}: {e}") from e @@ -337,7 +349,7 @@ def _update_data_for_index(self, index: Indices) -> None: the image handle(s) with the new data. """ imkey = self._image_key(index) - data = self._isel(index).squeeze() + data = self._isel(index).result().squeeze() data = self._reduce_dims_for_display(data) if handles := self._img_handles[imkey]: for handle in handles: From c3ae898ea86cfff8dde51e6a75faf1bd813628a8 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 10:30:29 -0400 Subject: [PATCH 25/73] fix typo --- src/pymmcore_widgets/_stack_viewer/_stack_viewer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py index 1b5b64017..783b50e7f 100644 --- a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py @@ -64,7 +64,7 @@ def setMode(self, mode: ChannelMode) -> None: class StackViewer(QWidget): - """A viewer for AND arrays.""" + """A viewer for ND arrays.""" def __init__( self, From 5ebdcc3b7650806c1667c0bcc0122e33e9d1285b Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 11:09:19 -0400 Subject: [PATCH 26/73] rename --- examples/mda_viewer.py | 2 +- examples/stack_viewer_numpy.py | 2 +- .../{_stack_viewer => _stack_viewer_v2}/__init__.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_backends/__init__.py | 2 +- .../{_stack_viewer => _stack_viewer_v2}/_backends/_pygfx.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_backends/_vispy.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_dims_slider.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_indexing.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_lut_control.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_mda_viewer.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_protocols.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_save_button.py | 0 .../{_stack_viewer => _stack_viewer_v2}/_stack_viewer.py | 0 src/pymmcore_widgets/experimental.py | 2 +- 14 files changed, 4 insertions(+), 4 deletions(-) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/__init__.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_backends/__init__.py (93%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_backends/_pygfx.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_backends/_vispy.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_dims_slider.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_indexing.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_lut_control.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_mda_viewer.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_protocols.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_save_button.py (100%) rename src/pymmcore_widgets/{_stack_viewer => _stack_viewer_v2}/_stack_viewer.py (100%) diff --git a/examples/mda_viewer.py b/examples/mda_viewer.py index 8ff12df56..26a32a2ac 100644 --- a/examples/mda_viewer.py +++ b/examples/mda_viewer.py @@ -4,7 +4,7 @@ from qtpy import QtWidgets from useq import MDASequence -from pymmcore_widgets._stack_viewer._mda_viewer import MDAViewer +from pymmcore_widgets._stack_viewer_v2._mda_viewer import MDAViewer configure_logging(stderr_level="WARNING") diff --git a/examples/stack_viewer_numpy.py b/examples/stack_viewer_numpy.py index e49ed2163..3a9a6ce6e 100644 --- a/examples/stack_viewer_numpy.py +++ b/examples/stack_viewer_numpy.py @@ -3,7 +3,7 @@ import numpy as np from qtpy import QtWidgets -from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer +from pymmcore_widgets._stack_viewer_v2._stack_viewer import StackViewer def generate_5d_sine_wave( diff --git a/src/pymmcore_widgets/_stack_viewer/__init__.py b/src/pymmcore_widgets/_stack_viewer_v2/__init__.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/__init__.py rename to src/pymmcore_widgets/_stack_viewer_v2/__init__.py diff --git a/src/pymmcore_widgets/_stack_viewer/_backends/__init__.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/__init__.py similarity index 93% rename from src/pymmcore_widgets/_stack_viewer/_backends/__init__.py rename to src/pymmcore_widgets/_stack_viewer_v2/_backends/__init__.py index 045d85fd1..9650021f9 100644 --- a/src/pymmcore_widgets/_stack_viewer/_backends/__init__.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_backends/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from pymmcore_widgets._stack_viewer._protocols import PCanvas + from pymmcore_widgets._stack_viewer_v2._protocols import PCanvas def get_canvas(backend: str | None = None) -> type[PCanvas]: diff --git a/src/pymmcore_widgets/_stack_viewer/_backends/_pygfx.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_backends/_pygfx.py rename to src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py diff --git a/src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_backends/_vispy.py rename to src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py diff --git a/src/pymmcore_widgets/_stack_viewer/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_dims_slider.py rename to src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py diff --git a/src/pymmcore_widgets/_stack_viewer/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_indexing.py rename to src/pymmcore_widgets/_stack_viewer_v2/_indexing.py diff --git a/src/pymmcore_widgets/_stack_viewer/_lut_control.py b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_lut_control.py rename to src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py diff --git a/src/pymmcore_widgets/_stack_viewer/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_mda_viewer.py rename to src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py diff --git a/src/pymmcore_widgets/_stack_viewer/_protocols.py b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_protocols.py rename to src/pymmcore_widgets/_stack_viewer_v2/_protocols.py diff --git a/src/pymmcore_widgets/_stack_viewer/_save_button.py b/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_save_button.py rename to src/pymmcore_widgets/_stack_viewer_v2/_save_button.py diff --git a/src/pymmcore_widgets/_stack_viewer/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py similarity index 100% rename from src/pymmcore_widgets/_stack_viewer/_stack_viewer.py rename to src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py diff --git a/src/pymmcore_widgets/experimental.py b/src/pymmcore_widgets/experimental.py index c17acb003..f09378e0c 100644 --- a/src/pymmcore_widgets/experimental.py +++ b/src/pymmcore_widgets/experimental.py @@ -1,4 +1,4 @@ -from ._stack_viewer import MDAViewer from ._stack_viewer_v1 import StackViewer +from ._stack_viewer_v2 import MDAViewer __all__ = ["StackViewer", "MDAViewer"] From ca4347d215dc89394454a1c68e04688eabad4834 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 15:41:23 -0400 Subject: [PATCH 27/73] hide sliders of size=1 --- .../_stack_viewer_v2/_dims_slider.py | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index 49d445b9b..e91564256 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -102,7 +102,7 @@ class PlayButton(QPushButton): PLAY_ICON = "bi:play-fill" PAUSE_ICON = "bi:pause-fill" - def __init__(self, fps: float = 30, parent: QWidget | None = None) -> None: + def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: icn = QIconifyIcon(self.PLAY_ICON, color="#888888") icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") super().__init__(icn, "", parent) @@ -161,10 +161,11 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None self._dim_key = dimension_key self._timer_id: int | None = None # timer for play button - self._play_btn = PlayButton(fps=30) + self._play_btn = PlayButton() self._play_btn.fpsChanged.connect(self.set_fps) self._play_btn.toggled.connect(self._toggle_animation) + self._dim_key = dimension_key self._dim_label = QElidingLabel(str(dimension_key).upper()) # note, this lock button only prevents the slider from updating programmatically @@ -180,10 +181,9 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None ) self._out_of_label = QLabel() - self._int_slider = QSlider(Qt.Orientation.Horizontal, parent=self) + self._int_slider = QSlider(Qt.Orientation.Horizontal) self._int_slider.rangeChanged.connect(self._on_range_changed) self._int_slider.valueChanged.connect(self._on_int_value_changed) - # self._int_slider.layout().addWidget(self._max_label) self._slice_slider = slc = QLabeledRangeSlider(Qt.Orientation.Horizontal) slc.setHandleLabelPosition(QLabeledRangeSlider.LabelPosition.LabelsOnHandle) @@ -251,13 +251,10 @@ def forceValue(self, val: Index) -> None: def _set_slice_mode(self, mode: bool = True) -> None: if mode == self._slice_mode: return - self._slice_mode = mode - if mode: - self._slice_slider.setVisible(True) - self._int_slider.setVisible(False) - else: - self._int_slider.setVisible(True) - self._slice_slider.setVisible(False) + self._slice_mode = bool(mode) + self._slice_slider.setVisible(self._slice_mode) + self._int_slider.setVisible(not self._slice_mode) + # self._pos_label.setVisible(not self._slice_mode) self.valueChanged.emit(self._dim_key, self.value()) def set_fps(self, fps: float) -> None: @@ -288,7 +285,7 @@ def timerEvent(self, event: Any) -> None: def _on_pos_label_edited(self) -> None: if self._slice_mode: self._slice_slider.setValue( - (self._pos_label.value(), self._pos_label.value() + 1) + (self._slice_slider.value()[0], self._pos_label.value()) ) else: self._int_slider.setValue(self._pos_label.value()) @@ -297,6 +294,17 @@ def _on_range_changed(self, min: int, max: int) -> None: self._out_of_label.setText(f"| {max}") self._pos_label.setRange(min, max) self.resizeEvent(None) + self.setVisible(min != max) + + def setVisible(self, visible: bool) -> None: + if self._has_no_range(): + visible = False + super().setVisible(visible) + + def _has_no_range(self) -> bool: + if self._slice_mode: + return bool(self._slice_slider.minimum() == self._slice_slider.maximum()) + return bool(self._int_slider.minimum() == self._int_slider.maximum()) def _on_int_value_changed(self, value: int) -> None: self._pos_label.setValue(value) @@ -304,6 +312,7 @@ def _on_int_value_changed(self, value: int) -> None: self.valueChanged.emit(self._dim_key, value) def _on_slice_value_changed(self, value: tuple[int, int]) -> None: + self._pos_label.setValue(int(value[1])) if self._slice_mode: self.valueChanged.emit(self._dim_key, slice(*value)) @@ -365,12 +374,14 @@ def add_dimension(self, name: DimKey, val: Index | None = None) -> None: else: slider._lock_btn.setVisible(bool(self._locks_visible)) - slider.setRange(0, 1) + val_int = val.start if isinstance(val, slice) else val + slider.setVisible(name not in self._invisible_dims) + slider.setRange(0, val_int if isinstance(val_int, int) else 0) + val = val if val is not None else 0 self._current_index[name] = val slider.forceValue(val) slider.valueChanged.connect(self._on_dim_slider_value_changed) - slider.setVisible(name not in self._invisible_dims) cast("QVBoxLayout", self.layout()).addWidget(slider) def set_dimension_visible(self, key: DimKey, visible: bool) -> None: From f7cddf591b47ed4a188416367119ea7739f24335 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 16:30:49 -0400 Subject: [PATCH 28/73] minor typing --- .../_stack_viewer_v2/_mda_viewer.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 5d73e84be..7ebeff641 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -10,23 +10,33 @@ from ._stack_viewer import StackViewer if TYPE_CHECKING: + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase from qtpy.QtWidgets import QWidget class MDAViewer(StackViewer): """StackViewer specialized for pymmcore-plus MDA acquisitions.""" - def __init__(self, datastore: Any = None, *, parent: QWidget | None = None): - if datastore is None: - from pymmcore_plus.mda.handlers import OMEZarrWriter + _data: _5DWriterBase + + def __init__( + self, datastore: _5DWriterBase | None = None, *, parent: QWidget | None = None + ): + from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter + if datastore is None: datastore = OMEZarrWriter() + elif not isinstance(datastore, (OMEZarrWriter, OMETiffWriter)): + raise TypeError( + "MDAViewer currently only supports _5DWriterBase datastores." + ) # patch the frameReady method to call the superframeReady method # AFTER handling the event self._superframeReady = getattr(datastore, "frameReady", None) if callable(self._superframeReady): - datastore.frameReady = self._patched_frame_ready + datastore.frameReady = self._patched_frame_ready # type: ignore + else: # pragma: no cover warnings.warn( "MDAViewer: datastore does not have a frameReady method to patch, " @@ -39,6 +49,10 @@ def __init__(self, datastore: Any = None, *, parent: QWidget | None = None): self._btns.addWidget(self._save_btn) self.dims_sliders.set_locks_visible(True) + @property + def data(self) -> _5DWriterBase: + return self._data + def _patched_frame_ready(self, *args: Any) -> None: self._superframeReady(*args) # type: ignore if len(args) >= 2 and isinstance(e := args[1], useq.MDAEvent): From 5aae2af0ca547f9c16903f6d3210ee5ad4609100 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 17:35:18 -0400 Subject: [PATCH 29/73] additive mode --- src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py index bbd1eedab..37fe110b4 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py @@ -76,7 +76,8 @@ def __init__(self, set_info: Callable[[str], None]) -> None: self._canvas = _QWgpuCanvas(size=(512, 512)) self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) - self._renderer.blend_mode = "weighted" + # requires https://github.com/pygfx/pygfx/pull/752 + self._renderer.blend_mode = "additive" self._scene = pygfx.Scene() self._camera = cam = pygfx.OrthographicCamera(512, 512) cam.local.scale_y = -1 From 31cbf6408b0a6b72d6cbde62806b364fe04cb92b Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 18:14:25 -0400 Subject: [PATCH 30/73] futures --- examples/mda_viewer.py | 2 +- .../_stack_viewer_v2/_indexing.py | 17 +++++---- .../_stack_viewer_v2/_stack_viewer.py | 35 ++++++++++++++----- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/examples/mda_viewer.py b/examples/mda_viewer.py index 26a32a2ac..183bcec41 100644 --- a/examples/mda_viewer.py +++ b/examples/mda_viewer.py @@ -14,7 +14,7 @@ mmcore.defineConfig("Channel", "FITC", "Camera", "Mode", "Noise") sequence = MDASequence( - channels=({"config": "DAPI", "exposure": 5}, {"config": "FITC", "exposure": 20}), + channels=({"config": "DAPI", "exposure": 1}, {"config": "FITC", "exposure": 1}), stage_positions=[(0, 0), (1, 1)], z_plan={"range": 9, "step": 0.4}, time_plan={"interval": 0.2, "loops": 4}, diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index f28f830d5..599644273 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -2,12 +2,14 @@ import sys import warnings +from concurrent.futures import Future, InvalidStateError +from contextlib import suppress +from threading import Thread from typing import TYPE_CHECKING, cast import numpy as np if TYPE_CHECKING: - from concurrent.futures import Future from typing import Any, Protocol, TypeGuard import dask.array as da @@ -73,13 +75,16 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: raise NotImplementedError(f"Don't know how to index into type {type(store)}") -def isel_async(store: Any, indexers: Indices) -> Future[np.ndarray]: +def isel_async(store: Any, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: """Asynchronous version of isel.""" - from concurrent.futures import Future - from threading import Thread + fut: Future[tuple[Indices, np.ndarray]] = Future() - fut: Future[np.ndarray] = Future() - thread = Thread(target=lambda: fut.set_result(isel(store, indexers))) + def _thread_target() -> None: + data = isel(store, indexers) + with suppress(InvalidStateError): + fut.set_result((indexers, data)) + + thread = Thread(target=_thread_target) thread.start() return fut diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 783b50e7f..d84bda8d8 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -9,7 +9,7 @@ import cmap import numpy as np from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget -from superqt import QCollapsible, QElidingLabel, QIconifyIcon +from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread from ._backends import get_canvas from ._dims_slider import DimsSliders @@ -95,6 +95,8 @@ def __init__( # TODO: allow user to set this self._cmaps = cycle(COLORMAPS) + self._data_future: Future[tuple[Indices, np.ndarray]] | None = None + # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels @@ -318,7 +320,7 @@ def _image_key(self, index: Indices) -> ImgKey: return val return 0 - def _isel(self, index: Indices) -> Future[np.ndarray]: + def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: """Select data from the datastore using the given index.""" idx = {k: v for k, v in index.items() if k not in self._visualized_dims} try: @@ -342,14 +344,16 @@ def _on_dims_sliders_changed(self, index: Indices) -> None: self._update_data_for_index(idx) self._canvas.refresh() - def _update_data_for_index(self, index: Indices) -> None: - """Update the displayed image for the given index. + @ensure_main_thread + def _on_data_future_done(self, future: Future[tuple[Indices, np.ndarray]]) -> None: + """Update the displayed image for the given index.""" + if future.cancelled(): + print(">>> was cancelled, do nothing") + return - This will pull the data from the datastore using the given index, and update - the image handle(s) with the new data. - """ + print(">>> yay, plotting") + index, data = future.result() imkey = self._image_key(index) - data = self._isel(index).result().squeeze() data = self._reduce_dims_for_display(data) if handles := self._img_handles[imkey]: for handle in handles: @@ -369,6 +373,20 @@ def _update_data_for_index(self, index: Indices) -> None: c.update_autoscale() self._lut_drop.addWidget(c) + def _update_data_for_index(self, index: Indices) -> None: + """Update the displayed image for the given index. + + This will pull the data from the datastore using the given index, and update + the image handle(s) with the new data. + """ + # if we're still processing a previous request, cancel it + if self._data_future is not None: + print("CANCEL") + self._data_future.cancel() + + self._data_future = df = self._isel(index) + df.add_done_callback(self._on_data_future_done) + def _reduce_dims_for_display( self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max ) -> np.ndarray: @@ -382,6 +400,7 @@ def _reduce_dims_for_display( # - allow for 3d data # - allow dimensions to control how they are reduced # - for better way to determine which dims need to be reduced + data = data.squeeze() visualized_dims = 2 if extra_dims := data.ndim - visualized_dims: shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) From fb66e29bd0f9414ce4235b706c1b8f2ddcb6d754 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 18:45:31 -0400 Subject: [PATCH 31/73] more futurestuff --- .../_stack_viewer_v2/_stack_viewer.py | 96 +++++++++---------- 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index d84bda8d8..f3e60d54c 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from enum import Enum from itertools import cycle @@ -31,6 +30,7 @@ GRAYS = cmap.Colormap("gray") COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] MAX_CHANNELS = 16 +ALL_CHANNELS = slice(None) class ChannelMode(str, Enum): @@ -263,19 +263,20 @@ def set_channel_mode(self, mode: ChannelMode | None = None) -> None: if self._channel_axis is None: return - # determine what needs to be updated - n_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 - value = self._dims_sliders.value() # get before clearing - indices = ( - [value] - if c_visible - else [{**value, self._channel_axis: i} for i in range(n_channels)] - ) - - # update the displayed images - for idx in indices: - self._update_data_for_index(idx) - self._canvas.refresh() + self.setIndex({}) + # # determine what needs to be updated + # n_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 + # value = self._dims_sliders.value() # get before clearing + # indices = ( + # [value] + # if c_visible + # else [{**value, self._channel_axis: i} for i in range(n_channels)] + # ) + + # # update the displayed images + # for idx in indices: + # self._update_data_for_index(idx) + # self._canvas.refresh() def setIndex(self, index: Indices) -> None: """Set the index of the displayed image.""" @@ -330,48 +331,44 @@ def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: def _on_dims_sliders_changed(self, index: Indices) -> None: """Update the displayed image when the sliders are changed.""" - c = index.get(self._channel_axis, 0) - indices: list[Indices] = [index] - if self._channel_mode == ChannelMode.COMPOSITE: - for i, handles in self._img_handles.items(): - if isinstance(i, (int, slice)): - if handles and c != i: - indices.append({**index, self._channel_axis: i}) - else: # pragma: no cover - warnings.warn(f"Invalid key for composite image: {i}", stacklevel=2) - - for idx in indices: - self._update_data_for_index(idx) - self._canvas.refresh() + if ( + self._channel_mode == ChannelMode.COMPOSITE + and self._channel_axis is not None + ): + index = {**index, self._channel_axis: ALL_CHANNELS} + self._update_data_for_index(index) - @ensure_main_thread + @ensure_main_thread # type: ignore def _on_data_future_done(self, future: Future[tuple[Indices, np.ndarray]]) -> None: """Update the displayed image for the given index.""" if future.cancelled(): - print(">>> was cancelled, do nothing") return - print(">>> yay, plotting") index, data = future.result() - imkey = self._image_key(index) - data = self._reduce_dims_for_display(data) - if handles := self._img_handles[imkey]: - for handle in handles: - handle.data = data - if ctrl := self._lut_ctrls.get(imkey, None): - ctrl.update_autoscale() - else: - cm = ( - next(self._cmaps) - if self._channel_mode == ChannelMode.COMPOSITE - else GRAYS - ) - handles.append(self._canvas.add_image(data, cmap=cm)) - if imkey not in self._lut_ctrls: - channel_name = f"Ch {imkey}" # TODO: get name from user - self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) - c.update_autoscale() - self._lut_drop.addWidget(c) + # assume that if we have channels remaining, that they are the first axis + # FIXME: this is a bad assumption + data = iter(data) if index.get(self._channel_axis) is ALL_CHANNELS else [data] + for i, datum in enumerate(data): + imkey = self._image_key({**index, self._channel_axis: i}) + datum = self._reduce_dims_for_display(datum) + if handles := self._img_handles[imkey]: + for handle in handles: + handle.data = datum + if ctrl := self._lut_ctrls.get(imkey, None): + ctrl.update_autoscale() + else: + cm = ( + next(self._cmaps) + if self._channel_mode == ChannelMode.COMPOSITE + else GRAYS + ) + handles.append(self._canvas.add_image(datum, cmap=cm)) + if imkey not in self._lut_ctrls: + channel_name = f"Ch {imkey}" # TODO: get name from user + self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) + c.update_autoscale() + self._lut_drop.addWidget(c) + self._canvas.refresh() def _update_data_for_index(self, index: Indices) -> None: """Update the displayed image for the given index. @@ -381,7 +378,6 @@ def _update_data_for_index(self, index: Indices) -> None: """ # if we're still processing a previous request, cancel it if self._data_future is not None: - print("CANCEL") self._data_future.cancel() self._data_future = df = self._isel(index) From 5eb54627230facd07bf5f0854f5075c8dff25aae Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 7 May 2024 20:40:33 -0400 Subject: [PATCH 32/73] more cleanup --- examples/mda_viewer.py | 2 +- examples/stack_viewer_dask.py | 26 +++ .../_stack_viewer_v2/_dims_slider.py | 106 +++++++--- .../_stack_viewer_v2/_lut_control.py | 2 + .../_stack_viewer_v2/_stack_viewer.py | 190 +++++++++--------- tests/test_stack_viewer2.py | 44 ++++ 6 files changed, 252 insertions(+), 118 deletions(-) create mode 100644 examples/stack_viewer_dask.py create mode 100644 tests/test_stack_viewer2.py diff --git a/examples/mda_viewer.py b/examples/mda_viewer.py index 183bcec41..664dcd625 100644 --- a/examples/mda_viewer.py +++ b/examples/mda_viewer.py @@ -11,7 +11,7 @@ mmcore = CMMCorePlus.instance() mmcore.loadSystemConfiguration() mmcore.defineConfig("Channel", "DAPI", "Camera", "Mode", "Artificial Waves") -mmcore.defineConfig("Channel", "FITC", "Camera", "Mode", "Noise") +mmcore.defineConfig("Channel", "FITC", "Camera", "Mode", "Color Test Pattern") sequence = MDASequence( channels=({"config": "DAPI", "exposure": 1}, {"config": "FITC", "exposure": 1}), diff --git a/examples/stack_viewer_dask.py b/examples/stack_viewer_dask.py new file mode 100644 index 000000000..a7b908ae6 --- /dev/null +++ b/examples/stack_viewer_dask.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import numpy as np +from dask.array.core import map_blocks +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer + + +def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: + if isinstance(block_id, np.ndarray): + return None + data = np.random.randint(0, 255, size=(1000, 1000), dtype=np.uint8) + return data[(None,) * 3] + + +shape = (1000, 64, 3, 512, 512) +chunks = [(1,) * x for x in shape[:-2]] +chunks += [(x,) for x in shape[-2:]] +dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = StackViewer(dask_arr) + v.show() + qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index e91564256..c26dbb1ec 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from typing import TYPE_CHECKING, Any, cast from warnings import warn @@ -167,6 +168,7 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None self._dim_key = dimension_key self._dim_label = QElidingLabel(str(dimension_key).upper()) + self._dim_label.setToolTip("Double-click to toggle slice mode") # note, this lock button only prevents the slider from updating programmatically # using self.setValue, it doesn't prevent the user from changing the value. @@ -219,6 +221,12 @@ def setMaximum(self, max_val: int) -> None: if max_val > self._slice_slider.maximum(): self._slice_slider.setMaximum(max_val) + def setMinimum(self, min_val: int) -> None: + if min_val < self._int_slider.minimum(): + self._int_slider.setMinimum(min_val) + if min_val < self._slice_slider.minimum(): + self._slice_slider.setMinimum(min_val) + def setRange(self, min_val: int, max_val: int) -> None: self._int_slider.setRange(min_val, max_val) self._slice_slider.setRange(min_val, max_val) @@ -237,8 +245,11 @@ def setValue(self, val: Index) -> None: if self._lock_btn.isChecked(): return if isinstance(val, slice): - self._slice_slider.setValue((val.start, val.stop)) - # self._int_slider.setValue(int((val.stop + val.start) / 2)) + start = int(val.start) if val.start is not None else 0 + stop = ( + int(val.stop) if val.stop is not None else self._slice_slider.maximum() + ) + self._slice_slider.setValue((start, stop)) else: self._int_slider.setValue(val) # self._slice_slider.setValue((val, val + 1)) @@ -265,21 +276,29 @@ def _toggle_animation(self, checked: bool) -> None: if checked: if self._timer_id is not None: self.killTimer(self._timer_id) - self._timer_id = self.startTimer(int(1000 / self._play_btn.spin.value())) + interval = int(1000 / self._play_btn.spin.value()) + self._timer_id = self.startTimer(interval) elif self._timer_id is not None: self.killTimer(self._timer_id) self._timer_id = None def timerEvent(self, event: Any) -> None: + """Handle timer event for play button, move to the next frame.""" + # TODO + # for now just increment the value by 1, but we should be able to + # take FPS into account better and skip additional frames if the timerEvent + # is delayed for some reason. + inc = 1 if self._slice_mode: val = cast(tuple[int, int], self._slice_slider.value()) - next_val = [v + 1 for v in val] + next_val = [v + inc for v in val] if next_val[1] > self._slice_slider.maximum(): + # wrap around, without going below the min handle next_val = [v - val[0] for v in val] self._slice_slider.setValue(next_val) else: ival = self._int_slider.value() - ival = (ival + 1) % (self._int_slider.maximum() + 1) + ival = (ival + inc) % (self._int_slider.maximum() + 1) self._int_slider.setValue(ival) def _on_pos_label_edited(self) -> None: @@ -313,6 +332,8 @@ def _on_int_value_changed(self, value: int) -> None: def _on_slice_value_changed(self, value: tuple[int, int]) -> None: self._pos_label.setValue(int(value[1])) + with signals_blocked(self._int_slider): + self._int_slider.setValue(int(value[0])) if self._slice_mode: self.valueChanged.emit(self._dim_key, slice(*value)) @@ -338,13 +359,20 @@ def __init__(self, parent: QWidget | None = None) -> None: layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) - def sizeHint(self) -> QSize: - return super().sizeHint().boundedTo(QSize(9999, 0)) - def value(self) -> Indices: + """Return mapping of {dim_key -> current index} for each dimension.""" return self._current_index.copy() def setValue(self, values: Indices) -> None: + """Set the current index for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int | slice] + Mapping of {dim_key -> index} for each dimension. If value is a slice, + the slider will be in slice mode. If the dimension is not present in the + DimsSliders, it will be added. + """ if self._current_index == values: return with signals_blocked(self): @@ -352,22 +380,58 @@ def setValue(self, values: Indices) -> None: self.add_or_update_dimension(dim, index) self.valueChanged.emit(self.value()) - def maximum(self) -> Sizes: + def minima(self) -> Sizes: + """Return mapping of {dim_key -> minimum value} for each dimension.""" + return {k: v._int_slider.minimum() for k, v in self._sliders.items()} + + def setMinima(self, values: Sizes) -> None: + """Set the minimum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> minimum value} for each dimension. + """ + for name, min_val in values.items(): + if name not in self._sliders: + self.add_dimension(name) + self._sliders[name].setMinimum(min_val) + + def maxima(self) -> Sizes: + """Return mapping of {dim_key -> maximum value} for each dimension.""" return {k: v._int_slider.maximum() for k, v in self._sliders.items()} - def setMaximum(self, values: Sizes) -> None: + def setMaxima(self, values: Sizes) -> None: + """Set the maximum value for each dimension. + + Parameters + ---------- + values : Mapping[Hashable, int] + Mapping of {dim_key -> maximum value} for each dimension. + """ for name, max_val in values.items(): if name not in self._sliders: self.add_dimension(name) self._sliders[name].setMaximum(max_val) def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: + """Set the visibility of the lock buttons for all dimensions.""" self._locks_visible = visible for dim, slider in self._sliders.items(): viz = visible if isinstance(visible, bool) else visible.get(dim, False) slider._lock_btn.setVisible(viz) def add_dimension(self, name: DimKey, val: Index | None = None) -> None: + """Add a new dimension to the DimsSliders widget. + + Parameters + ---------- + name : Hashable + The name of the dimension. + val : int | slice, optional + The initial value for the dimension. If a slice, the slider will be in + slice mode. + """ self._sliders[name] = slider = DimsSlider(dimension_key=name, parent=self) if isinstance(self._locks_visible, dict) and name in self._locks_visible: slider._lock_btn.setVisible(self._locks_visible[name]) @@ -385,6 +449,11 @@ def add_dimension(self, name: DimKey, val: Index | None = None) -> None: cast("QVBoxLayout", self.layout()).addWidget(slider) def set_dimension_visible(self, key: DimKey, visible: bool) -> None: + """Set the visibility of a dimension in the DimsSliders widget. + + Once a dimension is hidden, it will not be shown again until it is explicitly + made visible again with this method. + """ if visible: self._invisible_dims.discard(key) else: @@ -393,6 +462,7 @@ def set_dimension_visible(self, key: DimKey, visible: bool) -> None: self._sliders[key].setVisible(visible) def remove_dimension(self, key: DimKey) -> None: + """Remove a dimension from the DimsSliders widget.""" try: slider = self._sliders.pop(key) except KeyError: @@ -406,6 +476,7 @@ def _on_dim_slider_value_changed(self, key: DimKey, value: Index) -> None: self.valueChanged.emit(self.value()) def add_or_update_dimension(self, key: DimKey, value: Index) -> None: + """Add a dimension if it doesn't exist, otherwise update the value.""" if key in self._sliders: self._sliders[key].forceValue(value) else: @@ -421,16 +492,5 @@ def resizeEvent(self, a0: QResizeEvent | None) -> None: super().resizeEvent(a0) - -if __name__ == "__main__": - from qtpy.QtWidgets import QApplication - - app = QApplication([]) - w = DimsSliders() - w.add_dimension("x", 5) - w.add_dimension("ysadfdasas", 20) - w.add_dimension("z", slice(10, 20)) - w.add_dimension("w", 10) - w.valueChanged.connect(print) - w.show() - app.exec() + def sizeHint(self) -> QSize: + return super().sizeHint().boundedTo(QSize(9999, 0)) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py index 2a1aa27c1..91483b8e3 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py @@ -74,6 +74,8 @@ def __init__( layout.addWidget(self._clims) layout.addWidget(self._auto_clim) + self.update_autoscale() + def autoscaleChecked(self) -> bool: return cast("bool", self._auto_clim.isChecked()) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index f3e60d54c..f860d4bd8 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -116,7 +116,7 @@ def __init__( self._canvas: PCanvas = get_canvas()(self._hover_info.setText) # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders() - self._dims_sliders.valueChanged.connect(self._on_dims_sliders_changed) + self._dims_sliders.valueChanged.connect(self._update_data_for_index) self._lut_drop = QCollapsible("LUTs") self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) @@ -188,7 +188,7 @@ def set_data( elif self._channel_axis is None: self._channel_axis = self._guess_channel_axis(data) self.set_visualized_dims(list(self._sizes)[-2:]) - self.update_slider_maxima() + self.update_slider_ranges() self.setIndex({}) package = getattr(data, "__module__", "").split(".")[0] @@ -218,18 +218,23 @@ def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: for d in self._visualized_dims: self._dims_sliders.set_dimension_visible(d, False) - def update_slider_maxima(self, sizes: SizesLike | None = None) -> None: + def update_slider_ranges( + self, mins: SizesLike | None = None, maxes: SizesLike | None = None + ) -> None: """Set the maximum values of the sliders. If `sizes` is not provided, sizes will be inferred from the datastore. + This is mostly here as a public way to reset the """ - if sizes is None: - sizes = self.sizes - sizes = _to_sizes(sizes) - self._dims_sliders.setMaximum({k: v - 1 for k, v in sizes.items()}) + if maxes is None: + maxes = self.sizes + maxes = _to_sizes(maxes) + self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) + if mins is not None: + self._dims_sliders.setMinima(_to_sizes(mins)) # FIXME: this needs to be moved and made user-controlled - for dim in list(sizes.values())[-2:]: + for dim in list(maxes.values())[-2:]: self._dims_sliders.set_dimension_visible(dim, False) def set_channel_mode(self, mode: ChannelMode | None = None) -> None: @@ -249,34 +254,16 @@ def set_channel_mode(self, mode: ChannelMode | None = None) -> None: return self._channel_mode = mode - # reset the colormap cycle - self._cmaps = cycle(COLORMAPS) - # set the visibility of the channel slider - c_visible = mode != ChannelMode.COMPOSITE + self._cmaps = cycle(COLORMAPS) # reset the colormap cycle if self._channel_axis is not None: - self._dims_sliders.set_dimension_visible(self._channel_axis, c_visible) + # set the visibility of the channel slider + self._dims_sliders.set_dimension_visible( + self._channel_axis, mode != ChannelMode.COMPOSITE + ) - if not self._img_handles: - return - - self._clear_images() - if self._channel_axis is None: - return - - self.setIndex({}) - # # determine what needs to be updated - # n_channels = self._dims_sliders.maximum().get(self._channel_axis, -1) + 1 - # value = self._dims_sliders.value() # get before clearing - # indices = ( - # [value] - # if c_visible - # else [{**value, self._channel_axis: i} for i in range(n_channels)] - # ) - - # # update the displayed images - # for idx in indices: - # self._update_data_for_index(idx) - # self._canvas.refresh() + if self._img_handles: + self._clear_images() + self._update_data_for_index(self._dims_sliders.value()) def setIndex(self, index: Indices) -> None: """Set the index of the displayed image.""" @@ -284,7 +271,7 @@ def setIndex(self, index: Indices) -> None: # ------------------- PRIVATE METHODS ---------------------------- - def _guess_channel_axis(self, data: Any) -> DimKey: + def _guess_channel_axis(self, data: Any) -> DimKey | None: """Guess the channel axis from the data.""" if is_xarray_dataarray(data): for d in data.dims: @@ -296,20 +283,8 @@ def _guess_channel_axis(self, data: Any) -> DimKey: return shp.index(min(shp)) return None - def _clear_images(self) -> None: - """Remove all images from the canvas.""" - for handles in self._img_handles.values(): - for handle in handles: - handle.remove() - self._img_handles.clear() - - # clear the current LutControls as well - for c in self._lut_ctrls.values(): - cast("QVBoxLayout", self.layout()).removeWidget(c) - c.deleteLater() - self._lut_ctrls.clear() - def _on_set_range_clicked(self) -> None: + # using method to swallow the parameter passed by _set_range_btn.clicked self._canvas.set_range() def _image_key(self, index: Indices) -> ImgKey: @@ -321,6 +296,24 @@ def _image_key(self, index: Indices) -> ImgKey: return val return 0 + def _update_data_for_index(self, index: Indices) -> None: + """Retrieve data for `index` from datastore and update canvas image(s). + + This will pull the data from the datastore using the given index, and update + the image handle(s) with the new data. This method is *asynchronous*. It + makes a request for the new data slice and queues _on_data_future_done to be + called when the data is ready. + """ + # if we're still processing a previous request, cancel it + if self._data_future is not None: + self._data_future.cancel() + + if self._channel_axis and self._channel_mode == ChannelMode.COMPOSITE: + index = {**index, self._channel_axis: ALL_CHANNELS} + + self._data_future = self._isel(index) + self._data_future.add_done_callback(self._on_data_slice_ready) + def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: """Select data from the datastore using the given index.""" idx = {k: v for k, v in index.items() if k not in self._visualized_dims} @@ -329,18 +322,13 @@ def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: except Exception as e: raise type(e)(f"Failed to index data with {idx}: {e}") from e - def _on_dims_sliders_changed(self, index: Indices) -> None: - """Update the displayed image when the sliders are changed.""" - if ( - self._channel_mode == ChannelMode.COMPOSITE - and self._channel_axis is not None - ): - index = {**index, self._channel_axis: ALL_CHANNELS} - self._update_data_for_index(index) - @ensure_main_thread # type: ignore - def _on_data_future_done(self, future: Future[tuple[Indices, np.ndarray]]) -> None: - """Update the displayed image for the given index.""" + def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> None: + """Update the displayed image for the given index. + + Connected to the future returned by _isel. + """ + self._data_future = None if future.cancelled(): return @@ -349,41 +337,36 @@ def _on_data_future_done(self, future: Future[tuple[Indices, np.ndarray]]) -> No # FIXME: this is a bad assumption data = iter(data) if index.get(self._channel_axis) is ALL_CHANNELS else [data] for i, datum in enumerate(data): - imkey = self._image_key({**index, self._channel_axis: i}) - datum = self._reduce_dims_for_display(datum) - if handles := self._img_handles[imkey]: - for handle in handles: - handle.data = datum - if ctrl := self._lut_ctrls.get(imkey, None): - ctrl.update_autoscale() - else: - cm = ( - next(self._cmaps) - if self._channel_mode == ChannelMode.COMPOSITE - else GRAYS - ) - handles.append(self._canvas.add_image(datum, cmap=cm)) - if imkey not in self._lut_ctrls: - channel_name = f"Ch {imkey}" # TODO: get name from user - self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) - c.update_autoscale() - self._lut_drop.addWidget(c) + self._update_canvas_data(datum, {**index, self._channel_axis: i}) self._canvas.refresh() - def _update_data_for_index(self, index: Indices) -> None: - """Update the displayed image for the given index. + def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: + """Actually update the image handle(s) with the (sliced) data. - This will pull the data from the datastore using the given index, and update - the image handle(s) with the new data. + By this point, data should be sliced from the underlying datastore. Any + dimensions remaining that are more than the number of visualized dimensions + (currently just 2D) will be reduced using max intensity projection (currently). """ - # if we're still processing a previous request, cancel it - if self._data_future is not None: - self._data_future.cancel() - - self._data_future = df = self._isel(index) - df.add_done_callback(self._on_data_future_done) - - def _reduce_dims_for_display( + imkey = self._image_key(index) + datum = self._reduce_data_for_display(data) + if handles := self._img_handles[imkey]: + for handle in handles: + handle.data = datum + if ctrl := self._lut_ctrls.get(imkey, None): + ctrl.update_autoscale() + else: + cm = ( + next(self._cmaps) + if self._channel_mode == ChannelMode.COMPOSITE + else GRAYS + ) + handles.append(self._canvas.add_image(datum, cmap=cm)) + if imkey not in self._lut_ctrls: + channel_name = f"Ch {imkey}" # TODO: get name from user + self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) + self._lut_drop.addWidget(c) + + def _reduce_data_for_display( self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max ) -> np.ndarray: """Reduce the number of dimensions in the data for display. @@ -391,11 +374,14 @@ def _reduce_dims_for_display( This function takes a data array and reduces the number of dimensions to the max allowed for display. The default behavior is to reduce the smallest dimensions, using np.max. This can be improved in the future. + + This also coerces 64-bit data to 32-bit data. """ # TODO # - allow for 3d data - # - allow dimensions to control how they are reduced - # - for better way to determine which dims need to be reduced + # - allow dimensions to control how they are reduced (as opposed to just max) + # - for better way to determine which dims need to be reduced (currently just + # the smallest dims) data = data.squeeze() visualized_dims = 2 if extra_dims := data.ndim - visualized_dims: @@ -403,10 +389,26 @@ def _reduce_dims_for_display( smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) return reductor(data, axis=smallest_dims) - if data.dtype == np.float64: - data = data.astype(np.float32) + if data.dtype.itemsize > 4: # More than 32 bits + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.int32) + else: + data = data.astype(np.float32) return data + def _clear_images(self) -> None: + """Remove all images from the canvas.""" + for handles in self._img_handles.values(): + for handle in handles: + handle.remove() + self._img_handles.clear() + + # clear the current LutControls as well + for c in self._lut_ctrls.values(): + cast("QVBoxLayout", self.layout()).removeWidget(c) + c.deleteLater() + self._lut_ctrls.clear() + def _to_sizes(sizes: SizesLike | None) -> Sizes: """Coerce `sizes` to a {dimKey -> int} mapping.""" diff --git a/tests/test_stack_viewer2.py b/tests/test_stack_viewer2.py new file mode 100644 index 000000000..76f22e169 --- /dev/null +++ b/tests/test_stack_viewer2.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import numpy as np +from dask.array.core import map_blocks +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer + + +def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: + if isinstance(block_id, np.ndarray): + return None + data = np.random.randint(0, 255, size=(1000, 1000), dtype=np.uint8) + return data[(None,) * 3] + + +shape = (1000, 64, 3, 512, 512) +chunks = [(1,) * x for x in shape[:-2]] +chunks += [(x,) for x in shape[-2:]] +dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) + + +def make_lazy_array( + shape: tuple[int, ...], dtype: np.dtype = np.dtype("uint8") +) -> np.ndarray: + shape[:-2] + frame_shape = shape[-2:] + + def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: + if isinstance(block_id, np.ndarray): + return None + data = np.random.rand(*frame_shape) + return data[(None,) * 3] + + chunks = [(1,) * x for x in shape[:-2]] + chunks += [(x,) for x in shape[-2:]] + return map_blocks(_dask_block, chunks=chunks, dtype=dtype) + + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = StackViewer(dask_arr) + v.show() + qapp.exec() From b0cb5da6ad391698561b8f9d6bac443b26ef671a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 May 2024 00:40:51 +0000 Subject: [PATCH 33/73] style(pre-commit.ci): auto fixes [...] --- src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index c26dbb1ec..49c92462c 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -1,6 +1,5 @@ from __future__ import annotations -import time from typing import TYPE_CHECKING, Any, cast from warnings import warn From f1fc4a2f31762fc4f78672d7c48e59c432410299 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Thu, 9 May 2024 08:28:30 -0400 Subject: [PATCH 34/73] changes to throttling --- examples/stack_viewer_dask.py | 2 +- .../_stack_viewer_v2/_stack_viewer.py | 20 ++++------- tests/test_stack_viewer2.py | 33 +++++-------------- 3 files changed, 17 insertions(+), 38 deletions(-) diff --git a/examples/stack_viewer_dask.py b/examples/stack_viewer_dask.py index a7b908ae6..73b9bd7ee 100644 --- a/examples/stack_viewer_dask.py +++ b/examples/stack_viewer_dask.py @@ -4,7 +4,7 @@ from dask.array.core import map_blocks from qtpy import QtWidgets -from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer +from pymmcore_widgets._stack_viewer_v2 import StackViewer def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index f860d4bd8..fa00bba2f 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -9,6 +9,7 @@ import numpy as np from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread +from superqt.utils import qthrottled from ._backends import get_canvas from ._dims_slider import DimsSliders @@ -95,8 +96,6 @@ def __init__( # TODO: allow user to set this self._cmaps = cycle(COLORMAPS) - self._data_future: Future[tuple[Indices, np.ndarray]] | None = None - # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels @@ -116,7 +115,9 @@ def __init__( self._canvas: PCanvas = get_canvas()(self._hover_info.setText) # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders() - self._dims_sliders.valueChanged.connect(self._update_data_for_index) + self._dims_sliders.valueChanged.connect( + qthrottled(self._update_data_for_index, 20, leading=True) + ) self._lut_drop = QCollapsible("LUTs") self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) @@ -234,7 +235,7 @@ def update_slider_ranges( self._dims_sliders.setMinima(_to_sizes(mins)) # FIXME: this needs to be moved and made user-controlled - for dim in list(maxes.values())[-2:]: + for dim in list(maxes.keys())[-2:]: self._dims_sliders.set_dimension_visible(dim, False) def set_channel_mode(self, mode: ChannelMode | None = None) -> None: @@ -304,15 +305,10 @@ def _update_data_for_index(self, index: Indices) -> None: makes a request for the new data slice and queues _on_data_future_done to be called when the data is ready. """ - # if we're still processing a previous request, cancel it - if self._data_future is not None: - self._data_future.cancel() - if self._channel_axis and self._channel_mode == ChannelMode.COMPOSITE: index = {**index, self._channel_axis: ALL_CHANNELS} - self._data_future = self._isel(index) - self._data_future.add_done_callback(self._on_data_slice_ready) + self._isel(index).add_done_callback(self._on_data_slice_ready) def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: """Select data from the datastore using the given index.""" @@ -328,11 +324,9 @@ def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> No Connected to the future returned by _isel. """ - self._data_future = None - if future.cancelled(): - return index, data = future.result() + print(index) # assume that if we have channels remaining, that they are the first axis # FIXME: this is a bad assumption data = iter(data) if index.get(self._channel_axis) is ALL_CHANNELS else [data] diff --git a/tests/test_stack_viewer2.py b/tests/test_stack_viewer2.py index 76f22e169..5cc0c3d87 100644 --- a/tests/test_stack_viewer2.py +++ b/tests/test_stack_viewer2.py @@ -1,44 +1,29 @@ from __future__ import annotations +import dask.array as da import numpy as np -from dask.array.core import map_blocks from qtpy import QtWidgets -from pymmcore_widgets._stack_viewer._stack_viewer import StackViewer +from pymmcore_widgets._stack_viewer_v2 import StackViewer -def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: - if isinstance(block_id, np.ndarray): - return None - data = np.random.randint(0, 255, size=(1000, 1000), dtype=np.uint8) - return data[(None,) * 3] - - -shape = (1000, 64, 3, 512, 512) -chunks = [(1,) * x for x in shape[:-2]] -chunks += [(x,) for x in shape[-2:]] -dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) - - -def make_lazy_array( - shape: tuple[int, ...], dtype: np.dtype = np.dtype("uint8") -) -> np.ndarray: - shape[:-2] +def make_lazy_array(shape: tuple[int, ...]) -> da.Array: + rest_shape = shape[:-2] frame_shape = shape[-2:] def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: if isinstance(block_id, np.ndarray): return None - data = np.random.rand(*frame_shape) - return data[(None,) * 3] + size = (1,) * len(rest_shape) + frame_shape + return np.random.randint(0, 255, size=size, dtype=np.uint8) - chunks = [(1,) * x for x in shape[:-2]] - chunks += [(x,) for x in shape[-2:]] - return map_blocks(_dask_block, chunks=chunks, dtype=dtype) + chunks = [(1,) * x for x in rest_shape] + [(x,) for x in frame_shape] + return da.map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) # type: ignore if __name__ == "__main__": qapp = QtWidgets.QApplication([]) + dask_arr = make_lazy_array((1000, 64, 3, 256, 256)) v = StackViewer(dask_arr) v.show() qapp.exec() From b4175ae6bece3282f876940d4924484a27c0a68f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 May 2024 12:29:18 +0000 Subject: [PATCH 35/73] style(pre-commit.ci): auto fixes [...] --- src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index fa00bba2f..b6b25c377 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -324,7 +324,6 @@ def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> No Connected to the future returned by _isel. """ - index, data = future.result() print(index) # assume that if we have channels remaining, that they are the first axis From 90827ff69e7a695a807d175deee40c999d944d31 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 10 May 2024 08:44:08 -0400 Subject: [PATCH 36/73] change demo --- examples/mda_viewer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/mda_viewer.py b/examples/mda_viewer.py index 664dcd625..675d7b8df 100644 --- a/examples/mda_viewer.py +++ b/examples/mda_viewer.py @@ -11,7 +11,9 @@ mmcore = CMMCorePlus.instance() mmcore.loadSystemConfiguration() mmcore.defineConfig("Channel", "DAPI", "Camera", "Mode", "Artificial Waves") -mmcore.defineConfig("Channel", "FITC", "Camera", "Mode", "Color Test Pattern") +mmcore.defineConfig("Channel", "DAPI", "Camera", "StripeWidth", "1") +mmcore.defineConfig("Channel", "FITC", "Camera", "Mode", "Artificial Waves") +mmcore.defineConfig("Channel", "FITC", "Camera", "StripeWidth", "4") sequence = MDASequence( channels=({"config": "DAPI", "exposure": 1}, {"config": "FITC", "exposure": 1}), From cb9d0310345b8143cf2d0668428e5a76891dbd32 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 10 May 2024 08:44:19 -0400 Subject: [PATCH 37/73] remove print --- src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index b6b25c377..1018c0d90 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -325,7 +325,6 @@ def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> No Connected to the future returned by _isel. """ index, data = future.result() - print(index) # assume that if we have channels remaining, that they are the first axis # FIXME: this is a bad assumption data = iter(data) if index.get(self._channel_axis) is ALL_CHANNELS else [data] From c6e7ee174c935c8348f78ffbcbc38d34ab29df9f Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 10 May 2024 08:57:26 -0400 Subject: [PATCH 38/73] add tensor store --- src/pymmcore_widgets/_stack_viewer_v2/_indexing.py | 12 ++++++++++-- src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 599644273..131490dfb 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, cast import numpy as np +from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreWriter if TYPE_CHECKING: from typing import Any, Protocol, TypeGuard @@ -25,7 +26,7 @@ def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... def shape(self) -> tuple[int, ...]: ... -def is_pymmcore_writer(obj: Any) -> TypeGuard[_5DWriterBase]: +def is_pymmcore_5dbase(obj: Any) -> TypeGuard[_5DWriterBase]: try: from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase except ImportError: @@ -66,8 +67,10 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: For any other duck-typed array, use numpy-style indexing, where indexers is a mapping of axis to slice objects or indices. """ - if is_pymmcore_writer(store): + if is_pymmcore_5dbase(store): return isel_mmcore_5dbase(store, indexers) + if isinstance(store, TensorStoreWriter): + return isel_mmcore_tensorstore(store, indexers) if is_xarray_dataarray(store): return cast("np.ndarray", store.isel(indexers).to_numpy()) if is_duck_array(store): @@ -75,6 +78,11 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: raise NotImplementedError(f"Don't know how to index into type {type(store)}") +def isel_mmcore_tensorstore(writer: TensorStoreWriter, indexers: Indices) -> np.ndarray: + index = writer._indices[frozenset(indexers.items())] + return writer._store[index].read().result() + + def isel_async(store: Any, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: """Asynchronous version of isel.""" fut: Future[tuple[Indices, np.ndarray]] = Future() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 7ebeff641..9175eb24b 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -5,6 +5,7 @@ import superqt import useq +from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreWriter from ._save_button import SaveButton from ._stack_viewer import StackViewer @@ -25,7 +26,7 @@ def __init__( from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter if datastore is None: - datastore = OMEZarrWriter() + datastore = TensorStoreWriter() elif not isinstance(datastore, (OMEZarrWriter, OMETiffWriter)): raise TypeError( "MDAViewer currently only supports _5DWriterBase datastores." From 5b6aed63f971e3c3210d3b06c65ba52f4b84a28f Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 10 May 2024 18:27:13 -0400 Subject: [PATCH 39/73] wip --- examples/mda_viewer_queue.py | 24 +++++++++++++++++++ .../_stack_viewer_v2/_indexing.py | 10 +++++--- .../_stack_viewer_v2/_mda_viewer.py | 7 +++--- 3 files changed, 34 insertions(+), 7 deletions(-) create mode 100644 examples/mda_viewer_queue.py diff --git a/examples/mda_viewer_queue.py b/examples/mda_viewer_queue.py new file mode 100644 index 000000000..0e1b95c4e --- /dev/null +++ b/examples/mda_viewer_queue.py @@ -0,0 +1,24 @@ +import sys +from queue import Queue + +from pymmcore_plus import CMMCorePlus +from qtpy import QtWidgets +from useq import MDAEvent + +from pymmcore_widgets._stack_viewer_v2._mda_viewer import MDAViewer + +app = QtWidgets.QApplication(sys.argv) +mmcore = CMMCorePlus.instance() +mmcore.loadSystemConfiguration() + +canvas = MDAViewer() +canvas.show() + +q = Queue() +mmcore.run_mda(iter(q.get, None), output=canvas.data) +for i in range(10): + for c in range(2): + q.put(MDAEvent(index={"t": i, "c": c}, exposure=1)) +q.put(None) + +app.exec() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 131490dfb..c99cf6e29 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast import numpy as np -from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreWriter +from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreHandler if TYPE_CHECKING: from typing import Any, Protocol, TypeGuard @@ -60,6 +60,8 @@ def is_duck_array(obj: Any) -> TypeGuard[SupportsIndexing]: return False +# TODO: Change this factory function on a wrapper class so we +# don't have to check the type of the object every time we call def isel(store: Any, indexers: Indices) -> np.ndarray: """Select a slice from a data store using (possibly) named indices. @@ -69,7 +71,7 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: """ if is_pymmcore_5dbase(store): return isel_mmcore_5dbase(store, indexers) - if isinstance(store, TensorStoreWriter): + if isinstance(store, TensorStoreHandler): return isel_mmcore_tensorstore(store, indexers) if is_xarray_dataarray(store): return cast("np.ndarray", store.isel(indexers).to_numpy()) @@ -78,7 +80,9 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: raise NotImplementedError(f"Don't know how to index into type {type(store)}") -def isel_mmcore_tensorstore(writer: TensorStoreWriter, indexers: Indices) -> np.ndarray: +def isel_mmcore_tensorstore( + writer: TensorStoreHandler, indexers: Indices +) -> np.ndarray: index = writer._indices[frozenset(indexers.items())] return writer._store[index].read().result() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 9175eb24b..df7cbeee7 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -5,7 +5,8 @@ import superqt import useq -from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreWriter +from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter +from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreHandler from ._save_button import SaveButton from ._stack_viewer import StackViewer @@ -23,10 +24,8 @@ class MDAViewer(StackViewer): def __init__( self, datastore: _5DWriterBase | None = None, *, parent: QWidget | None = None ): - from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter - if datastore is None: - datastore = TensorStoreWriter() + datastore = TensorStoreHandler("datastore") elif not isinstance(datastore, (OMEZarrWriter, OMETiffWriter)): raise TypeError( "MDAViewer currently only supports _5DWriterBase datastores." From f100cf9f8f4d881aaa02a33f1c5f229e47c96249 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 11 May 2024 17:52:08 -0400 Subject: [PATCH 40/73] use tensorstore backingh --- examples/stack_viewer_tensorstore.py | 22 +++++++++++++++++++ .../_stack_viewer_v2/_indexing.py | 18 +++++++++++++-- .../_stack_viewer_v2/_mda_viewer.py | 7 ++++-- 3 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 examples/stack_viewer_tensorstore.py diff --git a/examples/stack_viewer_tensorstore.py b/examples/stack_viewer_tensorstore.py new file mode 100644 index 000000000..d15b7f98f --- /dev/null +++ b/examples/stack_viewer_tensorstore.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import numpy as np +import tensorstore as ts +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer_v2 import StackViewer + +shape = (10, 4, 3, 512, 512) +ts_array = ts.open( + {"driver": "zarr", "kvstore": {"driver": "memory"}}, + create=True, + shape=shape, + dtype=ts.uint8, +).result() +ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8) + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = StackViewer(ts_array) + v.show() + qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index c99cf6e29..c9a866095 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -15,6 +15,7 @@ import dask.array as da import numpy.typing as npt + import tensorstore as ts import xarray as xr from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase @@ -50,6 +51,12 @@ def is_dask_array(obj: Any) -> TypeGuard[da.Array]: return False +def is_tensorstore(obj: Any) -> TypeGuard[ts.TensorStore]: + if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): + return True + return False + + def is_duck_array(obj: Any) -> TypeGuard[SupportsIndexing]: if ( isinstance(obj, np.ndarray) @@ -75,16 +82,23 @@ def isel(store: Any, indexers: Indices) -> np.ndarray: return isel_mmcore_tensorstore(store, indexers) if is_xarray_dataarray(store): return cast("np.ndarray", store.isel(indexers).to_numpy()) + if is_tensorstore(store): + return isel_tensorstore(store, indexers) if is_duck_array(store): return isel_np_array(store, indexers) raise NotImplementedError(f"Don't know how to index into type {type(store)}") +def isel_tensorstore(store: ts.TensorStore, indexers: Indices) -> np.ndarray: + import tensorstore + + return store[tensorstore.d[*indexers][*indexers.values()]].read().result() + + def isel_mmcore_tensorstore( writer: TensorStoreHandler, indexers: Indices ) -> np.ndarray: - index = writer._indices[frozenset(indexers.items())] - return writer._store[index].read().result() + return writer.isel(indexers) def isel_async(store: Any, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index df7cbeee7..5b32a7d0e 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -22,10 +22,13 @@ class MDAViewer(StackViewer): _data: _5DWriterBase def __init__( - self, datastore: _5DWriterBase | None = None, *, parent: QWidget | None = None + self, + datastore: _5DWriterBase | TensorStoreHandler | None = None, + *, + parent: QWidget | None = None, ): if datastore is None: - datastore = TensorStoreHandler("datastore") + datastore = TensorStoreHandler() elif not isinstance(datastore, (OMEZarrWriter, OMETiffWriter)): raise TypeError( "MDAViewer currently only supports _5DWriterBase datastores." From 120a47e5babcc4a6740a7a2bc09394a9604698f0 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 11 May 2024 18:36:11 -0400 Subject: [PATCH 41/73] fix imports --- src/pymmcore_widgets/_stack_viewer_v2/_indexing.py | 2 +- src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index c9a866095..bde961a3a 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, cast import numpy as np -from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreHandler +from pymmcore_plus.mda.handlers import TensorStoreHandler if TYPE_CHECKING: from typing import Any, Protocol, TypeGuard diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 5b32a7d0e..840d213f7 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -5,8 +5,7 @@ import superqt import useq -from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter -from pymmcore_plus.mda.handlers._tensorstore_writer import TensorStoreHandler +from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter, TensorStoreHandler from ._save_button import SaveButton from ._stack_viewer import StackViewer From 7c895c4f360248894facbb3130656ba26e18e211 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 12 May 2024 09:43:38 -0400 Subject: [PATCH 42/73] channel names and docs --- .../_stack_viewer_v2/_dims_slider.py | 38 ++++-- .../_stack_viewer_v2/_mda_viewer.py | 12 +- .../_stack_viewer_v2/_stack_viewer.py | 128 +++++++++++++++--- 3 files changed, 145 insertions(+), 33 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index 49c92462c..98c3cf214 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -214,17 +214,17 @@ def mouseDoubleClickEvent(self, a0: Any) -> None: self._set_slice_mode(not self._slice_mode) super().mouseDoubleClickEvent(a0) - def setMaximum(self, max_val: int) -> None: + def containMaximum(self, max_val: int) -> None: if max_val > self._int_slider.maximum(): self._int_slider.setMaximum(max_val) - if max_val > self._slice_slider.maximum(): - self._slice_slider.setMaximum(max_val) + if max_val > self._slice_slider.maximum(): + self._slice_slider.setMaximum(max_val) - def setMinimum(self, min_val: int) -> None: + def containMinimum(self, min_val: int) -> None: if min_val < self._int_slider.minimum(): self._int_slider.setMinimum(min_val) - if min_val < self._slice_slider.minimum(): - self._slice_slider.setMinimum(min_val) + if min_val < self._slice_slider.minimum(): + self._slice_slider.setMinimum(min_val) def setRange(self, min_val: int, max_val: int) -> None: self._int_slider.setRange(min_val, max_val) @@ -255,7 +255,14 @@ def setValue(self, val: Index) -> None: def forceValue(self, val: Index) -> None: """Set value and increase range if necessary.""" - self.setMaximum(val.stop if isinstance(val, slice) else val) + if isinstance(val, slice): + if isinstance(val.start, int): + self.containMinimum(val.start) + if isinstance(val.stop, int): + self.containMaximum(val.stop) + else: + self.containMinimum(val) + self.containMaximum(val) self.setValue(val) def _set_slice_mode(self, mode: bool = True) -> None: @@ -358,6 +365,14 @@ def __init__(self, parent: QWidget | None = None) -> None: layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) + def __contains__(self, key: DimKey) -> bool: + """Return True if the dimension key is present in the DimsSliders.""" + return key in self._sliders + + def slider(self, key: DimKey) -> DimsSlider: + """Return the DimsSlider widget for the given dimension key.""" + return self._sliders[key] + def value(self) -> Indices: """Return mapping of {dim_key -> current index} for each dimension.""" return self._current_index.copy() @@ -394,7 +409,7 @@ def setMinima(self, values: Sizes) -> None: for name, min_val in values.items(): if name not in self._sliders: self.add_dimension(name) - self._sliders[name].setMinimum(min_val) + self._sliders[name].containMinimum(min_val) def maxima(self) -> Sizes: """Return mapping of {dim_key -> maximum value} for each dimension.""" @@ -411,7 +426,7 @@ def setMaxima(self, values: Sizes) -> None: for name, max_val in values.items(): if name not in self._sliders: self.add_dimension(name) - self._sliders[name].setMaximum(max_val) + self._sliders[name].containMaximum(max_val) def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: """Set the visibility of the lock buttons for all dimensions.""" @@ -439,7 +454,10 @@ def add_dimension(self, name: DimKey, val: Index | None = None) -> None: val_int = val.start if isinstance(val, slice) else val slider.setVisible(name not in self._invisible_dims) - slider.setRange(0, val_int if isinstance(val_int, int) else 0) + if isinstance(val_int, int): + slider.setRange(val_int, val_int) + elif isinstance(val_int, slice): + slider.setRange(val_int.start or 0, val_int.stop or 1) val = val if val is not None else 0 self._current_index[name] = val diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 840d213f7..7750517a4 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Mapping import superqt import useq @@ -50,6 +50,7 @@ def __init__( self._save_btn = SaveButton(self.data) self._btns.addWidget(self._save_btn) self.dims_sliders.set_locks_visible(True) + self._channel_names: dict[int, str] = {} @property def data(self) -> _5DWriterBase: @@ -62,4 +63,13 @@ def _patched_frame_ready(self, *args: Any) -> None: @superqt.ensure_main_thread # type: ignore def _on_frame_ready(self, event: useq.MDAEvent) -> None: + c = event.index.get(self._channel_axis) # type: ignore + if c not in self._channel_names and c is not None and event.channel: + self._channel_names[c] = event.channel.config self.setIndex(event.index) # type: ignore + + def _get_channel_name(self, index: Mapping) -> str: + if self._channel_axis in index: + if name := self._channel_names.get(index[self._channel_axis]): + return name + return super()._get_channel_name(index) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 1018c0d90..56b78dbb8 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -3,6 +3,7 @@ from collections import defaultdict from enum import Enum from itertools import cycle +from re import A from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast import cmap @@ -65,7 +66,57 @@ def setMode(self, mode: ChannelMode) -> None: class StackViewer(QWidget): - """A viewer for ND arrays.""" + """A viewer for ND arrays. + + This widget displays a single slice from an ND array (or a composite of slices in + different colormaps). The widget provides sliders to select the slice to display, + and buttons to control the display mode of the channels. + + An important concept in this widget is the "index". The index is a mapping of + dimensions to integers or slices that define the slice of the data to display. For + example, a numpy slice of `[0, 1, 5:10]` would be represented as + `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be named, e.g. + `{'t': 0, 'c': 1, 'z': slice(5, 10)}`. The index is used to select the data from + the datastore, and to determine the position of the sliders. + + The flow of data is as follows: + + - The user sets the data using the `set_data` method. This will set the number + and range of the sliders to the shape of the data, and display the first slice. + - The user can then use the sliders to select the slice to display. The current + slice is defined as a `Mapping` of `{dim -> int|slice}` and can be retrieved + with the `_dims_sliders.value()` method. To programmatically set the current + position, use the `setIndex` method. This will set the values of the sliders, + which in turn will trigger the display of the new slice via the + `_update_data_for_index` method. + - `_update_data_for_index` is an asynchronous method that retrieves the data for + the given index from the datastore (using `_isel`) and queues the + `_on_data_slice_ready` method to be called when the data is ready. The logic + for extracting data from the datastore is defined in `_indexing.py`, which handles + idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). + - `_on_data_slice_ready` is called when the data is ready, and updates the image. + Note that if the slice is multidimensional, the data will be reduced to 2D using + max intensity projection (and double-clicking on any given dimension slider will + turn it into a range slider allowing a projection to be made over that dimension). + - The image is displayed on the canvas, which is an object that implements the + `PCanvas` protocol (mostly, it has an `add_image` method that returns a handle + to the added image that can be used to update the data and display). This + small abstraction allows for various backends to be used (e.g. vispy, pygfx, etc). + + Parameters + ---------- + data : Any + The data to display. This can be an ND array, an xarray DataArray, or any + object that supports numpy-style indexing. + parent : QWidget, optional + The parent widget of this widget. + channel_axis : Hashable, optional + The axis that represents the channels in the data. If not provided, this will + be guessed from the data. + channel_mode : ChannelMode, optional + The initial mode for displaying the channels. If not provided, this will be + set to ChannelMode.MONO. + """ def __init__( self, @@ -163,6 +214,11 @@ def data(self) -> Any: """Return the data backing the view.""" return self._data + @data.setter + def data(self, data: Any) -> None: + """Set the data backing the view.""" + raise AttributeError("Cannot set data directly. Use `set_data` method.") + @property def dims_sliders(self) -> DimsSliders: """Return the DimsSliders widget.""" @@ -174,39 +230,41 @@ def sizes(self) -> Sizes: return self._sizes def set_data( - self, data: Any, sizes: SizesLike | None = None, channel_axis: int | None = None + self, + data: Any, + sizes: SizesLike | None = None, + channel_axis: int | None = None, + visualized_dims: Iterable[DimKey] | None = None, ) -> None: """Set the datastore, and, optionally, the sizes of the data.""" + # store the data + self._data = data + + # determine sizes of the data if sizes is None: if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): sizes = sz elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): sizes = shp self._sizes = _to_sizes(sizes) - self._data = data + + # set channel axis if channel_axis is not None: self._channel_axis = channel_axis elif self._channel_axis is None: self._channel_axis = self._guess_channel_axis(data) - self.set_visualized_dims(list(self._sizes)[-2:]) - self.update_slider_ranges() - self.setIndex({}) - package = getattr(data, "__module__", "").split(".")[0] - info = f"{package}.{getattr(type(data), '__qualname__', '')}" + # update the dimensions we are visualizing + if visualized_dims is None: + visualized_dims = list(self._sizes)[-2:] + self.set_visualized_dims(visualized_dims) - if self._sizes: - if all(isinstance(x, int) for x in self._sizes): - size_str = repr(tuple(self._sizes.values())) - else: - size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) - size_str = f"({size_str})" - info += f" {size_str}" - if dtype := getattr(data, "dtype", ""): - info += f", {dtype}" - if nbytes := getattr(data, "nbytes", 0) / 1e6: - info += f", {nbytes:.2f}MB" - self._data_info.setText(info) + # update the range of all the sliders to match the sizes we set above + self.update_slider_ranges() + # redraw + self._update_data_for_index(self._dims_sliders.value()) + # update the data info label + self._update_data_info() def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: """Set the dimensions that will be visualized. @@ -228,7 +286,7 @@ def update_slider_ranges( This is mostly here as a public way to reset the """ if maxes is None: - maxes = self.sizes + maxes = self._sizes maxes = _to_sizes(maxes) self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) if mins is not None: @@ -272,6 +330,25 @@ def setIndex(self, index: Indices) -> None: # ------------------- PRIVATE METHODS ---------------------------- + def _update_data_info(self) -> None: + """Update the data info label with information about the data.""" + data = self._data + package = getattr(data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(data), '__qualname__', '')}" + + if self._sizes: + if all(isinstance(x, int) for x in self._sizes): + size_str = repr(tuple(self._sizes.values())) + else: + size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + self._data_info.setText(info) + def _guess_channel_axis(self, data: Any) -> DimKey | None: """Guess the channel axis from the data.""" if is_xarray_dataarray(data): @@ -328,6 +405,9 @@ def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> No # assume that if we have channels remaining, that they are the first axis # FIXME: this is a bad assumption data = iter(data) if index.get(self._channel_axis) is ALL_CHANNELS else [data] + # FIXME: + # `self._channel_axis: i` is a bug; we assume channel indices start at 0 + # but the actual values used for indices are up to the user. for i, datum in enumerate(data): self._update_canvas_data(datum, {**index, self._channel_axis: i}) self._canvas.refresh() @@ -354,10 +434,14 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: ) handles.append(self._canvas.add_image(datum, cmap=cm)) if imkey not in self._lut_ctrls: - channel_name = f"Ch {imkey}" # TODO: get name from user + channel_name = self._get_channel_name(index) self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) self._lut_drop.addWidget(c) + def _get_channel_name(self, index: Indices) -> str: + c = index.get(self._channel_axis, 0) + return f"Ch {c}" # TODO: get name from user + def _reduce_data_for_display( self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max ) -> np.ndarray: From 580cb20f9d35234d0d3db5bd982776e8f3a1619f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 May 2024 13:44:00 +0000 Subject: [PATCH 43/73] style(pre-commit.ci): auto fixes [...] --- src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 56b78dbb8..dca34e608 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -3,7 +3,6 @@ from collections import defaultdict from enum import Enum from itertools import cycle -from re import A from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast import cmap From f02d2faca519e72627d7ae65654e2f8a991b4e6f Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 12 May 2024 09:47:29 -0400 Subject: [PATCH 44/73] fix --- src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index dca34e608..206cf07b0 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -261,7 +261,7 @@ def set_data( # update the range of all the sliders to match the sizes we set above self.update_slider_ranges() # redraw - self._update_data_for_index(self._dims_sliders.value()) + self.setIndex({}) # update the data info label self._update_data_info() From 155b45c6637e04cf48d2d58cb9d50f9aec210584 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 12 May 2024 13:39:02 -0400 Subject: [PATCH 45/73] rearrange and fix delayed requests --- .../dask_arr.py} | 0 examples/stack_viewer/jax_arr.py | 17 +++++++ .../numpy_arr.py} | 4 +- .../tensorstore_arr.py} | 0 examples/stack_viewer/zarr_arr.py | 16 +++++++ .../_stack_viewer_v2/_dims_slider.py | 12 ++++- .../_stack_viewer_v2/_indexing.py | 18 +++----- .../_stack_viewer_v2/_stack_viewer.py | 45 ++++++++++++++++--- 8 files changed, 92 insertions(+), 20 deletions(-) rename examples/{stack_viewer_dask.py => stack_viewer/dask_arr.py} (100%) create mode 100644 examples/stack_viewer/jax_arr.py rename examples/{stack_viewer_numpy.py => stack_viewer/numpy_arr.py} (94%) rename examples/{stack_viewer_tensorstore.py => stack_viewer/tensorstore_arr.py} (100%) create mode 100644 examples/stack_viewer/zarr_arr.py diff --git a/examples/stack_viewer_dask.py b/examples/stack_viewer/dask_arr.py similarity index 100% rename from examples/stack_viewer_dask.py rename to examples/stack_viewer/dask_arr.py diff --git a/examples/stack_viewer/jax_arr.py b/examples/stack_viewer/jax_arr.py new file mode 100644 index 000000000..b57a129d6 --- /dev/null +++ b/examples/stack_viewer/jax_arr.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import jax.numpy as jnp +from numpy_arr import generate_5d_sine_wave +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer_v2._stack_viewer import StackViewer + +# Example usage +array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions +sine_wave_5d = jnp.asarray(generate_5d_sine_wave(array_shape)) + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = StackViewer(sine_wave_5d, channel_axis=1) + v.show() + qapp.exec() diff --git a/examples/stack_viewer_numpy.py b/examples/stack_viewer/numpy_arr.py similarity index 94% rename from examples/stack_viewer_numpy.py rename to examples/stack_viewer/numpy_arr.py index 3a9a6ce6e..06711f3d4 100644 --- a/examples/stack_viewer_numpy.py +++ b/examples/stack_viewer/numpy_arr.py @@ -49,11 +49,11 @@ def generate_5d_sine_wave( # Example usage -array_shape = (10, 5, 5, 512, 512) # Specify the desired dimensions +array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions sine_wave_5d = generate_5d_sine_wave(array_shape) if __name__ == "__main__": qapp = QtWidgets.QApplication([]) - v = StackViewer(sine_wave_5d, channel_axis=2) + v = StackViewer(sine_wave_5d, channel_axis=1) v.show() qapp.exec() diff --git a/examples/stack_viewer_tensorstore.py b/examples/stack_viewer/tensorstore_arr.py similarity index 100% rename from examples/stack_viewer_tensorstore.py rename to examples/stack_viewer/tensorstore_arr.py diff --git a/examples/stack_viewer/zarr_arr.py b/examples/stack_viewer/zarr_arr.py new file mode 100644 index 000000000..385cbddd0 --- /dev/null +++ b/examples/stack_viewer/zarr_arr.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import zarr +import zarr.storage +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer_v2 import StackViewer + +URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr" +zarr_arr = zarr.open(URL, mode="r") + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = StackViewer(zarr_arr["s0"]) + v.show() + qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index 98c3cf214..b473f7859 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -220,6 +220,14 @@ def containMaximum(self, max_val: int) -> None: if max_val > self._slice_slider.maximum(): self._slice_slider.setMaximum(max_val) + def setMaximum(self, max_val: int) -> None: + self._int_slider.setMaximum(max_val) + self._slice_slider.setMaximum(max_val) + + def setMinumum(self, min_val: int) -> None: + self._int_slider.setMinimum(min_val) + self._slice_slider.setMinimum(min_val) + def containMinimum(self, min_val: int) -> None: if min_val < self._int_slider.minimum(): self._int_slider.setMinimum(min_val) @@ -409,7 +417,7 @@ def setMinima(self, values: Sizes) -> None: for name, min_val in values.items(): if name not in self._sliders: self.add_dimension(name) - self._sliders[name].containMinimum(min_val) + self._sliders[name].setMinumum(min_val) def maxima(self) -> Sizes: """Return mapping of {dim_key -> maximum value} for each dimension.""" @@ -426,7 +434,7 @@ def setMaxima(self, values: Sizes) -> None: for name, max_val in values.items(): if name not in self._sliders: self.add_dimension(name) - self._sliders[name].containMaximum(max_val) + self._sliders[name].setMaximum(max_val) def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: """Set the visibility of the lock buttons for all dimensions.""" diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index bde961a3a..feca0be66 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -2,7 +2,7 @@ import sys import warnings -from concurrent.futures import Future, InvalidStateError +from concurrent.futures import Future, InvalidStateError, ThreadPoolExecutor from contextlib import suppress from threading import Thread from typing import TYPE_CHECKING, cast @@ -62,6 +62,7 @@ def is_duck_array(obj: Any) -> TypeGuard[SupportsIndexing]: isinstance(obj, np.ndarray) or hasattr(obj, "__array_function__") or hasattr(obj, "__array_namespace__") + or (hasattr(obj, "__getitem__") and hasattr(obj, "__array__")) ): return True return False @@ -101,18 +102,13 @@ def isel_mmcore_tensorstore( return writer.isel(indexers) -def isel_async(store: Any, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: - """Asynchronous version of isel.""" - fut: Future[tuple[Indices, np.ndarray]] = Future() +# Create a global executor +_ISEL_THREAD_EXECUTOR = ThreadPoolExecutor(max_workers=1) - def _thread_target() -> None: - data = isel(store, indexers) - with suppress(InvalidStateError): - fut.set_result((indexers, data)) - thread = Thread(target=_thread_target) - thread.start() - return fut +def isel_async(store: Any, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: + """Asynchronous version of isel.""" + return _ISEL_THREAD_EXECUTOR.submit(lambda: (indexers, isel(store, indexers))) def isel_np_array(data: SupportsIndexing, indexers: Indices) -> np.ndarray: diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 206cf07b0..9db0823ca 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -1,15 +1,16 @@ from __future__ import annotations from collections import defaultdict +from dataclasses import dataclass from enum import Enum from itertools import cycle -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Container, Iterable, Mapping, Sequence, cast import cmap import numpy as np from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread -from superqt.utils import qthrottled +from superqt.utils import qthrottled, signals_blocked from ._backends import get_canvas from ._dims_slider import DimsSliders @@ -64,6 +65,31 @@ def setMode(self, mode: ChannelMode) -> None: self.setChecked(mode == ChannelMode.MONO) +# @dataclass +# class LutModel: +# name: str = "" +# autoscale: bool = True +# min: float = 0.0 +# max: float = 1.0 +# colormap: cmap.Colormap = GRAYS +# visible: bool = True + + +# @dataclass +# class ViewerModel: +# data: Any = None +# # dimensions of the data that will *not* be sliced. +# visualized_dims: Container[DimKey] = (-2, -1) +# # the axis that represents the channels in the data +# channel_axis: DimKey | None = None +# # the mode for displaying the channels +# # if MONO, only the current selection of channel_axis is displayed +# # if COMPOSITE, the full channel_axis is sliced, and luts determine display +# channel_mode: ChannelMode = ChannelMode.MONO +# # map of index in the channel_axis to LutModel +# luts: Mapping[int, LutModel] = {} + + class StackViewer(QWidget): """A viewer for ND arrays. @@ -145,7 +171,8 @@ def __init__( # colormaps that will be cycled through when displaying composite images # TODO: allow user to set this self._cmaps = cycle(COLORMAPS) - + # the last future that was created by _update_data_for_index + self._last_future: Future | None = None # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels @@ -259,7 +286,8 @@ def set_data( self.set_visualized_dims(visualized_dims) # update the range of all the sliders to match the sizes we set above - self.update_slider_ranges() + with signals_blocked(self._dims_sliders): + self.update_slider_ranges() # redraw self.setIndex({}) # update the data info label @@ -384,7 +412,11 @@ def _update_data_for_index(self, index: Indices) -> None: if self._channel_axis and self._channel_mode == ChannelMode.COMPOSITE: index = {**index, self._channel_axis: ALL_CHANNELS} - self._isel(index).add_done_callback(self._on_data_slice_ready) + if self._last_future: + self._last_future.cancel() + + self._last_future = f = self._isel(index) + f.add_done_callback(self._on_data_slice_ready) def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: """Select data from the datastore using the given index.""" @@ -400,6 +432,9 @@ def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> No Connected to the future returned by _isel. """ + if future.cancelled(): + return + index, data = future.result() # assume that if we have channels remaining, that they are the first axis # FIXME: this is a bad assumption From 5f84dde4030d621e98d763751744464632bfbdaf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 May 2024 17:47:19 +0000 Subject: [PATCH 46/73] style(pre-commit.ci): auto fixes [...] --- src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py | 4 ++-- src/pymmcore_widgets/_stack_viewer_v2/_indexing.py | 4 +--- src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py | 3 +-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index b473f7859..3ab996959 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -224,7 +224,7 @@ def setMaximum(self, max_val: int) -> None: self._int_slider.setMaximum(max_val) self._slice_slider.setMaximum(max_val) - def setMinumum(self, min_val: int) -> None: + def setMinimum(self, min_val: int) -> None: self._int_slider.setMinimum(min_val) self._slice_slider.setMinimum(min_val) @@ -417,7 +417,7 @@ def setMinima(self, values: Sizes) -> None: for name, min_val in values.items(): if name not in self._sliders: self.add_dimension(name) - self._sliders[name].setMinumum(min_val) + self._sliders[name].setMinimum(min_val) def maxima(self) -> Sizes: """Return mapping of {dim_key -> maximum value} for each dimension.""" diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index feca0be66..964ee6f5e 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -2,9 +2,7 @@ import sys import warnings -from concurrent.futures import Future, InvalidStateError, ThreadPoolExecutor -from contextlib import suppress -from threading import Thread +from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING, cast import numpy as np diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 9db0823ca..af7627f34 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -1,10 +1,9 @@ from __future__ import annotations from collections import defaultdict -from dataclasses import dataclass from enum import Enum from itertools import cycle -from typing import TYPE_CHECKING, Container, Iterable, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast import cmap import numpy as np From a136f72a7d2de13f3a3b86f7171b2f711f8d19ed Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 24 May 2024 13:01:10 -0400 Subject: [PATCH 47/73] add xarray example --- examples/stack_viewer/dask_arr.py | 9 ++++---- examples/stack_viewer/xarray_arr.py | 14 +++++++++++++ .../_stack_viewer_v2/_stack_viewer.py | 21 +++++++++++++------ 3 files changed, 34 insertions(+), 10 deletions(-) create mode 100644 examples/stack_viewer/xarray_arr.py diff --git a/examples/stack_viewer/dask_arr.py b/examples/stack_viewer/dask_arr.py index 73b9bd7ee..eba37eb58 100644 --- a/examples/stack_viewer/dask_arr.py +++ b/examples/stack_viewer/dask_arr.py @@ -6,17 +6,18 @@ from pymmcore_widgets._stack_viewer_v2 import StackViewer +frame_size = (1024, 1024) + def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: if isinstance(block_id, np.ndarray): return None - data = np.random.randint(0, 255, size=(1000, 1000), dtype=np.uint8) + data = np.random.randint(0, 255, size=frame_size, dtype=np.uint8) return data[(None,) * 3] -shape = (1000, 64, 3, 512, 512) -chunks = [(1,) * x for x in shape[:-2]] -chunks += [(x,) for x in shape[-2:]] +chunks = [(1,) * x for x in (1000, 64, 3)] +chunks += [(x,) for x in frame_size] dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) if __name__ == "__main__": diff --git a/examples/stack_viewer/xarray_arr.py b/examples/stack_viewer/xarray_arr.py new file mode 100644 index 000000000..3c0871582 --- /dev/null +++ b/examples/stack_viewer/xarray_arr.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import xarray as xr +from qtpy import QtWidgets + +from pymmcore_widgets._stack_viewer_v2 import StackViewer + +da = xr.tutorial.open_dataset("air_temperature").air + +if __name__ == "__main__": + qapp = QtWidgets.QApplication([]) + v = StackViewer(da, colormaps=["thermal"], channel_mode="composite") + v.show() + qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index af7627f34..e62c724ac 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -29,7 +29,11 @@ MID_GRAY = "#888888" GRAYS = cmap.Colormap("gray") -COLORMAPS = [cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan")] +DEFAULT_COLORMAPS = [ + cmap.Colormap("green"), + cmap.Colormap("magenta"), + cmap.Colormap("cyan"), +] MAX_CHANNELS = 16 ALL_CHANNELS = slice(None) @@ -146,9 +150,10 @@ def __init__( self, data: Any, *, + colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, parent: QWidget | None = None, channel_axis: DimKey | None = None, - channel_mode: ChannelMode = ChannelMode.MONO, + channel_mode: ChannelMode | str = ChannelMode.MONO, ): super().__init__(parent=parent) @@ -169,7 +174,11 @@ def __init__( self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode # colormaps that will be cycled through when displaying composite images # TODO: allow user to set this - self._cmaps = cycle(COLORMAPS) + if colormaps is not None: + self._cmaps = [cmap.Colormap(c) for c in colormaps] + else: + self._cmaps = DEFAULT_COLORMAPS + self._cmap_cycle = cycle(self._cmaps) # the last future that was created by _update_data_for_index self._last_future: Future | None = None # WIDGETS ---------------------------------------------------- @@ -322,7 +331,7 @@ def update_slider_ranges( for dim in list(maxes.keys())[-2:]: self._dims_sliders.set_dimension_visible(dim, False) - def set_channel_mode(self, mode: ChannelMode | None = None) -> None: + def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: """Set the mode for displaying the channels. In "composite" mode, the channels are displayed as a composite image, using @@ -339,7 +348,7 @@ def set_channel_mode(self, mode: ChannelMode | None = None) -> None: return self._channel_mode = mode - self._cmaps = cycle(COLORMAPS) # reset the colormap cycle + self._cmap_cycle = cycle(self._cmaps) # reset the colormap cycle if self._channel_axis is not None: # set the visibility of the channel slider self._dims_sliders.set_dimension_visible( @@ -461,7 +470,7 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: ctrl.update_autoscale() else: cm = ( - next(self._cmaps) + next(self._cmap_cycle) if self._channel_mode == ChannelMode.COMPOSITE else GRAYS ) From 2efce9cfe40efec662bdff0a007fef8b2b983617 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 29 May 2024 16:41:46 -0400 Subject: [PATCH 48/73] two fixes --- .../_stack_viewer_v2/_stack_viewer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index e62c724ac..5d499df19 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -356,6 +356,7 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: ) if self._img_handles: + print("Changing channel mode will clear the current images") self._clear_images() self._update_data_for_index(self._dims_sliders.value()) @@ -417,7 +418,10 @@ def _update_data_for_index(self, index: Indices) -> None: makes a request for the new data slice and queues _on_data_future_done to be called when the data is ready. """ - if self._channel_axis and self._channel_mode == ChannelMode.COMPOSITE: + if ( + self._channel_axis is not None + and self._channel_mode == ChannelMode.COMPOSITE + ): index = {**index, self._channel_axis: ALL_CHANNELS} if self._last_future: @@ -474,6 +478,12 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: if self._channel_mode == ChannelMode.COMPOSITE else GRAYS ) + # FIXME: this is a hack ... + # however, there's a bug in the vispy backend such that if the first + # image is all zeros, it persists even if the data is updated + # it's better just to not add it at all... + if np.max(datum) == 0: + return handles.append(self._canvas.add_image(datum, cmap=cm)) if imkey not in self._lut_ctrls: channel_name = self._get_channel_name(index) From 857991ba8e2dd5de69df15f97835b3d798630cff Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 2 Jun 2024 18:49:08 -0400 Subject: [PATCH 49/73] tried to fix leaks. but failed --- pyproject.toml | 1 + .../_stack_viewer_v2/_dims_slider.py | 12 +++---- .../_stack_viewer_v2/_stack_viewer.py | 25 +++++++++---- tests/conftest.py | 35 +++++++++++++------ tests/test_stack_viewer2.py | 15 +++++--- 5 files changed, 60 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 796e15755..53769051c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,7 @@ docstring-code-format = true # https://docs.pytest.org/en/6.2.x/customize.html [tool.pytest.ini_options] +markers = ["allow_leaks"] minversion = "6.0" testpaths = ["tests"] filterwarnings = [ diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index 3ab996959..332c4c4f8 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -79,7 +79,7 @@ def __init__(self, parent: QWidget | None = None) -> None: self.setModal(False) # if False, then clicking anywhere else closes it self.setWindowFlags(Qt.WindowType.Popup | Qt.WindowType.FramelessWindowHint) - self.frame = QFrame() + self.frame = QFrame(self) layout = QVBoxLayout(self) layout.addWidget(self.frame) layout.setContentsMargins(0, 0, 0, 0) @@ -106,7 +106,7 @@ def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: icn = QIconifyIcon(self.PLAY_ICON, color="#888888") icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") super().__init__(icn, "", parent) - self.spin = QDoubleSpinBox() + self.spin = QDoubleSpinBox(parent) self.spin.setRange(0.5, 100) self.spin.setValue(fps) self.spin.valueChanged.connect(self.fpsChanged) @@ -161,7 +161,7 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None self._dim_key = dimension_key self._timer_id: int | None = None # timer for play button - self._play_btn = PlayButton() + self._play_btn = PlayButton(parent=self) self._play_btn.fpsChanged.connect(self.set_fps) self._play_btn.toggled.connect(self._toggle_animation) @@ -171,16 +171,16 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None # note, this lock button only prevents the slider from updating programmatically # using self.setValue, it doesn't prevent the user from changing the value. - self._lock_btn = LockButton() + self._lock_btn = LockButton(parent=self) - self._pos_label = QSpinBox() + self._pos_label = QSpinBox(self) self._pos_label.valueChanged.connect(self._on_pos_label_edited) self._pos_label.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) self._pos_label.setAlignment(Qt.AlignmentFlag.AlignRight) self._pos_label.setStyleSheet( "border: none; padding: 0; margin: 0; background: transparent" ) - self._out_of_label = QLabel() + self._out_of_label = QLabel(self) self._int_slider = QSlider(Qt.Orientation.Horizontal) self._int_slider.rangeChanged.connect(self._on_range_changed) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 5d499df19..d8a1c7bec 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import defaultdict +from contextlib import suppress from enum import Enum from itertools import cycle from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast @@ -20,6 +21,8 @@ from concurrent.futures import Future from typing import Any, Callable, Hashable, TypeAlias + from qtpy.QtGui import QCloseEvent + from ._dims_slider import DimKey, Indices, Sizes from ._protocols import PCanvas, PImageHandle @@ -184,27 +187,27 @@ def __init__( # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels - self._channel_mode_btn = ChannelModeButton() + self._channel_mode_btn = ChannelModeButton(self) self._channel_mode_btn.clicked.connect(self.set_channel_mode) # button to reset the zoom of the canvas self._set_range_btn = QPushButton( - QIconifyIcon("fluent:full-screen-maximize-24-filled"), "" + QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self ) self._set_range_btn.clicked.connect(self._on_set_range_clicked) # place to display dataset summary - self._data_info = QElidingLabel("") + self._data_info = QElidingLabel("", parent=self) # place to display arbitrary text - self._hover_info = QLabel("") + self._hover_info = QLabel("", self) # the canvas that displays the images self._canvas: PCanvas = get_canvas()(self._hover_info.setText) # the sliders that control the index of the displayed image - self._dims_sliders = DimsSliders() + self._dims_sliders = DimsSliders(self) self._dims_sliders.valueChanged.connect( qthrottled(self._update_data_for_index, 20, leading=True) ) - self._lut_drop = QCollapsible("LUTs") + self._lut_drop = QCollapsible("LUTs", self) self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up", color=MID_GRAY)) lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) @@ -430,6 +433,14 @@ def _update_data_for_index(self, index: Indices) -> None: self._last_future = f = self._isel(index) f.add_done_callback(self._on_data_slice_ready) + def closeEvent(self, a0: QCloseEvent | None) -> None: + if self._last_future: + self._last_future.cancel() + with suppress(AttributeError): + # just in case there is a hard reference to self._on_data_slice_ready + self._last_future._done_callbacks.clear() # type: ignore + super().closeEvent(a0) + def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: """Select data from the datastore using the given index.""" idx = {k: v for k, v in index.items() if k not in self._visualized_dims} @@ -487,7 +498,7 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: handles.append(self._canvas.add_image(datum, cmap=cm)) if imkey not in self._lut_ctrls: channel_name = self._get_channel_name(index) - self._lut_ctrls[imkey] = c = LutControl(channel_name, handles) + self._lut_ctrls[imkey] = c = LutControl(channel_name, handles, self) self._lut_drop.addWidget(c) def _get_channel_name(self, index: Indices) -> str: diff --git a/tests/conftest.py b/tests/conftest.py index a69582364..f3e979e10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import gc from pathlib import Path from typing import TYPE_CHECKING from unittest.mock import patch @@ -23,7 +24,7 @@ def global_mmcore(): @pytest.fixture(autouse=True) -def _run_after_each_test(request: "FixtureRequest", qapp: "QApplication"): +def _run_after_each_test(request: "FixtureRequest", qapp: "QApplication") -> None: """Run after each test to ensure no widgets have been left around. When this test fails, it means that a widget being tested has an issue closing @@ -39,16 +40,28 @@ def _run_after_each_test(request: "FixtureRequest", qapp: "QApplication"): return remaining = qapp.topLevelWidgets() if len(remaining) > nbefore: - if ( - # os.name == "nt" - # and sys.version_info[:2] <= (3, 9) - type(remaining[0]).__name__ in {"ImagePreview", "SnapButton"} - ): - # I have no idea why, but the ImagePreview widget is leaking. - # And it only came with a seemingly unrelated + test_node = request.node + if any(mark.name == "allow_leaks" for mark in test_node.iter_markers()): + return + if type(remaining[0]).__name__ in {"ImagePreview", "SnapButton"}: + # I have no idea why ImagePreview widget is leaking. + # it only came with a seemingly unrelated # https://github.com/pymmcore-plus/pymmcore-widgets/pull/90 - # we're just ignoring it for now. return - test = f"{request.node.path.name}::{request.node.originalname}" - raise AssertionError(f"topLevelWidgets remaining after {test!r}: {remaining}") + test = f"{test_node.path.name}::{test_node.originalname}" + msg = f"{len(remaining)} topLevelWidgets remaining after {test!r}:" + + for widget in remaining: + try: + obj_name = widget.objectName() + except Exception: + obj_name = None + msg += f"\n{widget!r} {obj_name!r}" + # Get the referrers of the widget + referrers = gc.get_referrers(widget) + msg += "\n Referrers:" + for ref in referrers: + msg += f"\n - {ref}, {id(ref):#x}" + + raise AssertionError(msg) diff --git a/tests/test_stack_viewer2.py b/tests/test_stack_viewer2.py index 5cc0c3d87..64fdb8736 100644 --- a/tests/test_stack_viewer2.py +++ b/tests/test_stack_viewer2.py @@ -1,11 +1,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import dask.array as da import numpy as np -from qtpy import QtWidgets +import pytest from pymmcore_widgets._stack_viewer_v2 import StackViewer +if TYPE_CHECKING: + from pytestqt.qtbot import QtBot + def make_lazy_array(shape: tuple[int, ...]) -> da.Array: rest_shape = shape[:-2] @@ -21,9 +26,11 @@ def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: return da.map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) # type: ignore -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) +# this test is still leaking widgets and it's hard to track down... I think +# it might have to do with the cmapComboBox +@pytest.mark.allow_leaks +def test_stack_viewer2(qtbot: QtBot) -> None: dask_arr = make_lazy_array((1000, 64, 3, 256, 256)) v = StackViewer(dask_arr) + qtbot.addWidget(v) v.show() - qapp.exec() From d3e3169f243f7ca4c34cca95e9df89b1fe3737fb Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 2 Jun 2024 20:30:07 -0400 Subject: [PATCH 50/73] use data-wrapper --- .../_stack_viewer_v2/_dims_slider.py | 14 +- .../_stack_viewer_v2/_indexing.py | 298 +++++++++++------- .../_stack_viewer_v2/_mda_viewer.py | 6 +- .../_stack_viewer_v2/_save_button.py | 28 +- .../_stack_viewer_v2/_stack_viewer.py | 27 +- 5 files changed, 210 insertions(+), 163 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index 332c4c4f8..7f0cb0388 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -106,7 +106,7 @@ def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: icn = QIconifyIcon(self.PLAY_ICON, color="#888888") icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") super().__init__(icn, "", parent) - self.spin = QDoubleSpinBox(parent) + self.spin = QDoubleSpinBox(self) self.spin.setRange(0.5, 100) self.spin.setValue(fps) self.spin.valueChanged.connect(self.fpsChanged) @@ -115,6 +115,11 @@ def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: self.setIconSize(QSize(16, 16)) self.setStyleSheet("border: none; padding: 0; margin: 0;") + self._popup = QtPopup(self) + form = QFormLayout(self._popup.frame) + form.setContentsMargins(6, 6, 6, 6) + form.addRow("FPS", self.spin) + def mousePressEvent(self, e: Any) -> None: if e and e.button() == Qt.MouseButton.RightButton: self._show_fps_dialog(e.globalPosition()) @@ -122,12 +127,7 @@ def mousePressEvent(self, e: Any) -> None: super().mousePressEvent(e) def _show_fps_dialog(self, pos: QPointF) -> None: - if not hasattr(self, "popup"): - self.popup = QtPopup(self) - form = QFormLayout(self.popup.frame) - form.setContentsMargins(6, 6, 6, 6) - form.addRow("FPS", self.spin) - self.popup.show_above_mouse() + self._popup.show_above_mouse() class LockButton(QPushButton): diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 964ee6f5e..56c307994 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -1,12 +1,13 @@ from __future__ import annotations +from pathlib import Path import sys import warnings +from abc import abstractmethod from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Generic, Hashable, Sequence, TypeVar, cast import numpy as np -from pymmcore_plus.mda.handlers import TensorStoreHandler if TYPE_CHECKING: from typing import Any, Protocol, TypeGuard @@ -15,6 +16,7 @@ import numpy.typing as npt import tensorstore as ts import xarray as xr + from pymmcore_plus.mda.handlers import TensorStoreHandler from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase from ._dims_slider import Index, Indices @@ -25,110 +27,188 @@ def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... def shape(self) -> tuple[int, ...]: ... -def is_pymmcore_5dbase(obj: Any) -> TypeGuard[_5DWriterBase]: - try: - from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase - except ImportError: - from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter - - _5DWriterBase = (OMETiffWriter, OMEZarrWriter) # type: ignore - if isinstance(obj, _5DWriterBase): - return True - return False - - -def is_xarray_dataarray(obj: Any) -> TypeGuard[xr.DataArray]: - if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): - return True - return False - - -def is_dask_array(obj: Any) -> TypeGuard[da.Array]: - if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): - return True - return False - - -def is_tensorstore(obj: Any) -> TypeGuard[ts.TensorStore]: - if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): - return True - return False - - -def is_duck_array(obj: Any) -> TypeGuard[SupportsIndexing]: - if ( - isinstance(obj, np.ndarray) - or hasattr(obj, "__array_function__") - or hasattr(obj, "__array_namespace__") - or (hasattr(obj, "__getitem__") and hasattr(obj, "__array__")) - ): - return True - return False - - -# TODO: Change this factory function on a wrapper class so we -# don't have to check the type of the object every time we call -def isel(store: Any, indexers: Indices) -> np.ndarray: - """Select a slice from a data store using (possibly) named indices. - - For xarray.DataArray, use the built-in isel method. - For any other duck-typed array, use numpy-style indexing, where indexers - is a mapping of axis to slice objects or indices. - """ - if is_pymmcore_5dbase(store): - return isel_mmcore_5dbase(store, indexers) - if isinstance(store, TensorStoreHandler): - return isel_mmcore_tensorstore(store, indexers) - if is_xarray_dataarray(store): - return cast("np.ndarray", store.isel(indexers).to_numpy()) - if is_tensorstore(store): - return isel_tensorstore(store, indexers) - if is_duck_array(store): - return isel_np_array(store, indexers) - raise NotImplementedError(f"Don't know how to index into type {type(store)}") - - -def isel_tensorstore(store: ts.TensorStore, indexers: Indices) -> np.ndarray: - import tensorstore - - return store[tensorstore.d[*indexers][*indexers.values()]].read().result() - - -def isel_mmcore_tensorstore( - writer: TensorStoreHandler, indexers: Indices -) -> np.ndarray: - return writer.isel(indexers) - - -# Create a global executor -_ISEL_THREAD_EXECUTOR = ThreadPoolExecutor(max_workers=1) - - -def isel_async(store: Any, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: - """Asynchronous version of isel.""" - return _ISEL_THREAD_EXECUTOR.submit(lambda: (indexers, isel(store, indexers))) - - -def isel_np_array(data: SupportsIndexing, indexers: Indices) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(len(data.shape))) - return np.asarray(data[idx]) - - -def isel_mmcore_5dbase(writer: _5DWriterBase, indexers: Indices) -> np.ndarray: - p_index = indexers.get("p", 0) - if isinstance(p_index, slice): - warnings.warn("Cannot slice over position index", stacklevel=2) # TODO - p_index = p_index.start - p_index = cast(int, p_index) - - try: - sizes = [*list(writer.position_sizes[p_index]), "y", "x"] - except IndexError as e: - raise IndexError( - f"Position index {p_index} out of range for {len(writer.position_sizes)}" - ) from e - - data = writer.position_arrays[writer.get_position_key(p_index)] - full = slice(None, None) - index = tuple(indexers.get(k, full) for k in sizes) - return data[index] # type: ignore +ArrayT = TypeVar("ArrayT") +MAX_CHANNELS = 16 + + +class DataWrapper(Generic[ArrayT]): + # Create a global executor + _EXECUTOR = ThreadPoolExecutor(max_workers=1) + + def __init__(self, data: ArrayT) -> None: + self._data = data + + @classmethod + def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: + if MMTensorStoreWrapper.supports(data): + return MMTensorStoreWrapper(data) + if MM5DWriter.supports(data): + return MM5DWriter(data) + if XarrayWrapper.supports(data): + return XarrayWrapper(data) + if DaskWrapper.supports(data): + return DaskWrapper(data) + if TensorstoreWrapper.supports(data): + return TensorstoreWrapper(data) + if ArrayLikeWrapper.supports(data): + return ArrayLikeWrapper(data) + raise NotImplementedError(f"Don't know how to wrap type {type(data)}") + + @abstractmethod + def isel(self, indexers: Indices) -> np.ndarray: + """Select a slice from a data store using (possibly) named indices. + + For xarray.DataArray, use the built-in isel method. + For any other duck-typed array, use numpy-style indexing, where indexers + is a mapping of axis to slice objects or indices. + """ + raise NotImplementedError + + def isel_async(self, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: + """Asynchronous version of isel.""" + return self._EXECUTOR.submit(lambda: (indexers, self.isel(indexers))) + + @classmethod + @abstractmethod + def supports(cls, obj: Any) -> bool: + """Return True if this wrapper can handle the given object.""" + raise NotImplementedError + + def guess_channel_axis(self) -> Hashable | None: + """Return the (best guess) axis name for the channel dimension.""" + if isinstance(shp := getattr(self._data, "shape", None), Sequence): + # for numpy arrays, use the smallest dimension as the channel axis + if min(shp) <= MAX_CHANNELS: + return shp.index(min(shp)) + return None + + def save_as_zarr(self, save_loc: str | Path) -> None: + raise NotImplementedError("save_as_zarr not implemented for this data type.") + + +class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): + @classmethod + def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: + from pymmcore_plus.mda.handlers import TensorStoreHandler + + return isinstance(obj, TensorStoreHandler) + + def isel(self, indexers: Indices) -> np.ndarray: + return self._data.isel(indexers) # type: ignore + + +class MM5DWriter(DataWrapper["_5DWriterBase"]): + def isel(self, indexers: Indices) -> np.ndarray: + p_index = indexers.get("p", 0) + if isinstance(p_index, slice): + warnings.warn("Cannot slice over position index", stacklevel=2) # TODO + p_index = p_index.start + p_index = cast(int, p_index) + + try: + sizes = [*list(self._data.position_sizes[p_index]), "y", "x"] + except IndexError as e: + raise IndexError( + f"Position index {p_index} out of range for " + f"{len(self._data.position_sizes)}" + ) from e + + data = self._data.position_arrays[self._data.get_position_key(p_index)] + full = slice(None, None) + index = tuple(indexers.get(k, full) for k in sizes) + return data[index] # type: ignore + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[_5DWriterBase]: + try: + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase + except ImportError: + from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter + + _5DWriterBase = (OMETiffWriter, OMEZarrWriter) # type: ignore + if isinstance(obj, _5DWriterBase): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + import zarr + from pymmcore_plus.mda.handlers import OMEZarrWriter + + if isinstance(self._data, OMEZarrWriter): + zarr.copy_store(self._data.group.store, zarr.DirectoryStore(save_loc)) + raise NotImplementedError(f"Cannot save {type(self._data)} data to Zarr.") + + +class XarrayWrapper(DataWrapper["xr.DataArray"]): + def isel(self, indexers: Indices) -> np.ndarray: + return np.asarray(self._data.isel(indexers)) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: + if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): + return True + return False + + def guess_channel_axis(self) -> Hashable | None: + for d in self._data.dims: + if str(d).lower() in ("channel", "ch", "c"): + return d + return None + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(save_loc) + + +class DaskWrapper(DataWrapper["da.Array"]): + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) + return np.asarray(self._data[idx].compute()) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[da.Array]: + if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): + return True + return False + + +class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): + def __init__(self, data: Any) -> None: + super().__init__(data) + import tensorstore as ts + + self._ts = ts + + def isel(self, indexers: Indices) -> np.ndarray: + result = self._data[self._ts.d[*indexers][*indexers.values()]].read().result() + return np.asarray(result) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: + if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): + return True + return False + + +class ArrayLikeWrapper(DataWrapper): + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) + return np.asarray(self._data[idx]) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: + if ( + isinstance(obj, np.ndarray) + or hasattr(obj, "__array_function__") + or hasattr(obj, "__array_namespace__") + or (hasattr(obj, "__getitem__") and hasattr(obj, "__array__")) + ): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + import zarr + + if isinstance(self._data, zarr.Array): + self._data.store = zarr.DirectoryStore(save_loc) + else: + zarr.save(str(save_loc), self._data) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 7750517a4..6a3da924e 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -47,15 +47,11 @@ def __init__( ) super().__init__(datastore, parent=parent, channel_axis="c") - self._save_btn = SaveButton(self.data) + self._save_btn = SaveButton(self._data_wrapper) self._btns.addWidget(self._save_btn) self.dims_sliders.set_locks_visible(True) self._channel_names: dict[int, str] = {} - @property - def data(self) -> _5DWriterBase: - return self._data - def _patched_frame_ready(self, *args: Any) -> None: self._superframeReady(*args) # type: ignore if len(args) >= 2 and isinstance(e := args[1], useq.MDAEvent): diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py b/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py index c526ab258..85520641a 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py @@ -1,26 +1,26 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING -import numpy as np from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget from superqt.iconify import QIconifyIcon -from ._indexing import is_xarray_dataarray +if TYPE_CHECKING: + from ._indexing import DataWrapper class SaveButton(QPushButton): def __init__( self, - datastore: Any, + data_wrapper: DataWrapper, parent: QWidget | None = None, ): super().__init__(parent=parent) self.setIcon(QIconifyIcon("mdi:content-save")) self.clicked.connect(self._on_click) - self._data = datastore + self._data_wrapper = data_wrapper self._last_loc = str(Path.home()) def _on_click(self) -> None: @@ -29,22 +29,6 @@ def _on_click(self) -> None: ) suffix = Path(self._last_loc).suffix if suffix in (".zarr", ".ome.zarr", ""): - _save_as_zarr(self._last_loc, self._data) + self._data_wrapper.save_as_zarr(self._last_loc) else: raise ValueError(f"Unsupported file format: {self._last_loc}") - - -def _save_as_zarr(save_loc: str | Path, data: Any) -> None: - import zarr - from pymmcore_plus.mda.handlers import OMEZarrWriter - - if isinstance(data, OMEZarrWriter): - zarr.copy_store(data.group.store, zarr.DirectoryStore(save_loc)) - elif isinstance(data, zarr.Array): - data.store = zarr.DirectoryStore(save_loc) - elif isinstance(data, np.ndarray): - zarr.save(str(save_loc), data) - elif is_xarray_dataarray(data): - data.to_zarr(save_loc) - else: - raise ValueError(f"Cannot save data of type {type(data)} to Zarr format.") diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index d8a1c7bec..f99236fd9 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -14,7 +14,7 @@ from ._backends import get_canvas from ._dims_slider import DimsSliders -from ._indexing import is_xarray_dataarray, isel_async +from ._indexing import DataWrapper from ._lut_control import LutControl if TYPE_CHECKING: @@ -37,7 +37,6 @@ cmap.Colormap("magenta"), cmap.Colormap("cyan"), ] -MAX_CHANNELS = 16 ALL_CHANNELS = slice(None) @@ -249,7 +248,7 @@ def __init__( @property def data(self) -> Any: """Return the data backing the view.""" - return self._data + return self._data_wrapper._data @data.setter def data(self, data: Any) -> None: @@ -275,7 +274,7 @@ def set_data( ) -> None: """Set the datastore, and, optionally, the sizes of the data.""" # store the data - self._data = data + self._data_wrapper = DataWrapper.create(data) # determine sizes of the data if sizes is None: @@ -289,7 +288,7 @@ def set_data( if channel_axis is not None: self._channel_axis = channel_axis elif self._channel_axis is None: - self._channel_axis = self._guess_channel_axis(data) + self._channel_axis = self._data_wrapper.guess_channel_axis() # update the dimensions we are visualizing if visualized_dims is None: @@ -359,7 +358,6 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: ) if self._img_handles: - print("Changing channel mode will clear the current images") self._clear_images() self._update_data_for_index(self._dims_sliders.value()) @@ -371,7 +369,8 @@ def setIndex(self, index: Indices) -> None: def _update_data_info(self) -> None: """Update the data info label with information about the data.""" - data = self._data + data = self._data_wrapper._data + package = getattr(data, "__module__", "").split(".")[0] info = f"{package}.{getattr(type(data), '__qualname__', '')}" @@ -388,18 +387,6 @@ def _update_data_info(self) -> None: info += f", {nbytes:.2f}MB" self._data_info.setText(info) - def _guess_channel_axis(self, data: Any) -> DimKey | None: - """Guess the channel axis from the data.""" - if is_xarray_dataarray(data): - for d in data.dims: - if str(d).lower() in ("channel", "ch", "c"): - return cast("DimKey", d) - if isinstance(shp := getattr(data, "shape", None), Sequence): - # for numpy arrays, use the smallest dimension as the channel axis - if min(shp) <= MAX_CHANNELS: - return shp.index(min(shp)) - return None - def _on_set_range_clicked(self) -> None: # using method to swallow the parameter passed by _set_range_btn.clicked self._canvas.set_range() @@ -445,7 +432,7 @@ def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: """Select data from the datastore using the given index.""" idx = {k: v for k, v in index.items() if k not in self._visualized_dims} try: - return isel_async(self._data, idx) + return self._data_wrapper.isel_async(idx) except Exception as e: raise type(e)(f"Failed to index data with {idx}: {e}") from e From 550d5d6689e0c04f48623c1555a385053bc51643 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 00:30:27 +0000 Subject: [PATCH 51/73] style(pre-commit.ci): auto fixes [...] --- src/pymmcore_widgets/_stack_viewer_v2/_indexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 56c307994..1410878fb 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path import sys import warnings from abc import abstractmethod @@ -10,6 +9,7 @@ import numpy as np if TYPE_CHECKING: + from pathlib import Path from typing import Any, Protocol, TypeGuard import dask.array as da From 4441085906f7c60170923626527c31cc6f0d503f Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 2 Jun 2024 20:31:59 -0400 Subject: [PATCH 52/73] minor --- src/pymmcore_widgets/_stack_viewer_v2/_indexing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 56c307994..cf8fe0fd8 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path import sys import warnings from abc import abstractmethod @@ -10,6 +9,7 @@ import numpy as np if TYPE_CHECKING: + from pathlib import Path from typing import Any, Protocol, TypeGuard import dask.array as da @@ -29,17 +29,18 @@ def shape(self) -> tuple[int, ...]: ... ArrayT = TypeVar("ArrayT") MAX_CHANNELS = 16 +# Create a global executor +_EXECUTOR = ThreadPoolExecutor(max_workers=1) class DataWrapper(Generic[ArrayT]): - # Create a global executor - _EXECUTOR = ThreadPoolExecutor(max_workers=1) - def __init__(self, data: ArrayT) -> None: self._data = data @classmethod def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: + if isinstance(data, DataWrapper): + return data if MMTensorStoreWrapper.supports(data): return MMTensorStoreWrapper(data) if MM5DWriter.supports(data): @@ -66,7 +67,7 @@ def isel(self, indexers: Indices) -> np.ndarray: def isel_async(self, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: """Asynchronous version of isel.""" - return self._EXECUTOR.submit(lambda: (indexers, self.isel(indexers))) + return _EXECUTOR.submit(lambda: (indexers, self.isel(indexers))) @classmethod @abstractmethod From 759956dbcb37e526670630f6c08c99914356fc15 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 2 Jun 2024 20:35:30 -0400 Subject: [PATCH 53/73] misc --- src/pymmcore_widgets/_stack_viewer_v2/_indexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index cf8fe0fd8..16e3a0798 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -153,7 +153,7 @@ def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: def guess_channel_axis(self) -> Hashable | None: for d in self._data.dims: if str(d).lower() in ("channel", "ch", "c"): - return d + return cast("Hashable", d) return None def save_as_zarr(self, save_loc: str | Path) -> None: From 98b2717e6fa811d0d3b39e743447b5951bb001ff Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 3 Jun 2024 11:09:29 -0400 Subject: [PATCH 54/73] fix one leak --- src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py | 4 ++++ tests/test_stack_viewer2.py | 7 +++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index f99236fd9..d30ca05f7 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -442,6 +442,10 @@ def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> No Connected to the future returned by _isel. """ + # NOTE: removing the reference to the last future here is important + # because the future has a reference to this widget in its _done_callbacks + # which will prevent the widget from being garbage collected if the future + self._last_future = None if future.cancelled(): return diff --git a/tests/test_stack_viewer2.py b/tests/test_stack_viewer2.py index 64fdb8736..ece9ae448 100644 --- a/tests/test_stack_viewer2.py +++ b/tests/test_stack_viewer2.py @@ -4,7 +4,6 @@ import dask.array as da import numpy as np -import pytest from pymmcore_widgets._stack_viewer_v2 import StackViewer @@ -28,9 +27,13 @@ def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: # this test is still leaking widgets and it's hard to track down... I think # it might have to do with the cmapComboBox -@pytest.mark.allow_leaks +# @pytest.mark.allow_leaks def test_stack_viewer2(qtbot: QtBot) -> None: dask_arr = make_lazy_array((1000, 64, 3, 256, 256)) v = StackViewer(dask_arr) qtbot.addWidget(v) v.show() + + # wait until there are no running jobs, because the callbacks + # in the futures hold a strong reference to the viewer + qtbot.waitUntil(lambda: v._last_future is None, timeout=10000) From 18a6047eda3768f6810a17afa3a30ab0a6ff1b98 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 3 Jun 2024 13:24:50 -0400 Subject: [PATCH 55/73] small changes --- .../_stack_viewer_v2/_dims_slider.py | 4 +++- .../_stack_viewer_v2/_stack_viewer.py | 7 ++----- tests/test_stack_viewer2.py | 15 ++++++++++++++- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index 7f0cb0388..2987c2801 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -400,7 +400,9 @@ def setValue(self, values: Indices) -> None: with signals_blocked(self): for dim, index in values.items(): self.add_or_update_dimension(dim, index) - self.valueChanged.emit(self.value()) + # FIXME: i don't know why this this is ever empty ... only happens on pyside6 + if val := self.value(): + self.valueChanged.emit(val) def minima(self) -> Sizes: """Return mapping of {dim_key -> minimum value} for each dimension.""" diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index d30ca05f7..7f4c3fe2d 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections import defaultdict -from contextlib import suppress from enum import Enum from itertools import cycle from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast @@ -421,11 +420,9 @@ def _update_data_for_index(self, index: Indices) -> None: f.add_done_callback(self._on_data_slice_ready) def closeEvent(self, a0: QCloseEvent | None) -> None: - if self._last_future: + if self._last_future is not None: self._last_future.cancel() - with suppress(AttributeError): - # just in case there is a hard reference to self._on_data_slice_ready - self._last_future._done_callbacks.clear() # type: ignore + self._last_future = None super().closeEvent(a0) def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: diff --git a/tests/test_stack_viewer2.py b/tests/test_stack_viewer2.py index ece9ae448..3dfa17926 100644 --- a/tests/test_stack_viewer2.py +++ b/tests/test_stack_viewer2.py @@ -36,4 +36,17 @@ def test_stack_viewer2(qtbot: QtBot) -> None: # wait until there are no running jobs, because the callbacks # in the futures hold a strong reference to the viewer - qtbot.waitUntil(lambda: v._last_future is None, timeout=10000) + qtbot.waitUntil(lambda: v._last_future is None, timeout=1000) + + +def test_dims_sliders(qtbot: QtBot) -> None: + from superqt import QLabeledRangeSlider + + from pymmcore_widgets._stack_viewer_v2._dims_slider import DimsSlider + + # temporary debugging + ds = DimsSlider(dimension_key="t") + qtbot.addWidget(ds) + + rs = QLabeledRangeSlider() + qtbot.addWidget(rs) From 325a90ea29ff1c794a15232d983daeb25bca6e55 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 3 Jun 2024 14:51:09 -0400 Subject: [PATCH 56/73] more indexing --- .../_stack_viewer_v2/_indexing.py | 39 +++++++++++++- .../_stack_viewer_v2/_lut_control.py | 5 +- .../_stack_viewer_v2/_stack_viewer.py | 51 +++++++------------ 3 files changed, 59 insertions(+), 36 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 16e3a0798..e2e3ca456 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -4,7 +4,7 @@ import warnings from abc import abstractmethod from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Generic, Hashable, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Sequence, TypeVar, cast import numpy as np @@ -86,6 +86,43 @@ def guess_channel_axis(self) -> Hashable | None: def save_as_zarr(self, save_loc: str | Path) -> None: raise NotImplementedError("save_as_zarr not implemented for this data type.") + def sizes(self) -> Mapping[Hashable, int]: + if (sz := getattr(self._data, "sizes", None)) and isinstance(sz, Mapping): + return {k: int(v) for k, v in sz.items()} + elif (shape := getattr(self._data, "shape", None)) and isinstance(shape, tuple): + _sizes: dict[Hashable, int] = {} + for i, val in enumerate(shape): + if isinstance(val, int): + _sizes[i] = val + elif isinstance(val, Sequence) and len(val) == 2: + _sizes[val[0]] = int(val[1]) + else: + raise ValueError( + f"Invalid size: {val}. Must be an int or a 2-tuple." + ) + return _sizes + raise NotImplementedError(f"Cannot determine sizes for {type(self._data)}") + + def summary_info(self) -> str: + """Return info label with information about the data.""" + package = getattr(self._data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" + + if sizes := self.sizes(): + # if all of the dimension keys are just integers, omit them from size_str + if all(isinstance(x, int) for x in sizes): + size_str = repr(tuple(sizes.values())) + # otherwise, include the keys in the size_str + else: + size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(self._data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(self._data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + return info + class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): @classmethod diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py index 91483b8e3..84c59521e 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, cast +from typing import TYPE_CHECKING, Any, Iterable, cast import numpy as np from qtpy.QtCore import Qt @@ -36,6 +36,7 @@ def __init__( name: str = "", handles: Iterable[PImageHandle] = (), parent: QWidget | None = None, + cmaplist: Iterable[Any] = (), ) -> None: super().__init__(parent) self._handles = handles @@ -49,7 +50,7 @@ def __init__( self._cmap.currentColormapChanged.connect(self._on_cmap_changed) for handle in handles: self._cmap.addColormap(handle.cmap) - for color in ["green", "magenta", "cyan"]: + for color in cmaplist: self._cmap.addColormap(color) self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 7f4c3fe2d..0ad90962d 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -35,6 +35,11 @@ cmap.Colormap("green"), cmap.Colormap("magenta"), cmap.Colormap("cyan"), + cmap.Colormap("yellow"), + cmap.Colormap("red"), + cmap.Colormap("blue"), + cmap.Colormap("cubehelix"), + cmap.Colormap("gray"), ] ALL_CHANNELS = slice(None) @@ -194,11 +199,11 @@ def __init__( self._set_range_btn.clicked.connect(self._on_set_range_clicked) # place to display dataset summary - self._data_info = QElidingLabel("", parent=self) + self._data_info_label = QElidingLabel("", parent=self) # place to display arbitrary text - self._hover_info = QLabel("", self) + self._hover_info_label = QLabel("", self) # the canvas that displays the images - self._canvas: PCanvas = get_canvas()(self._hover_info.setText) + self._canvas: PCanvas = get_canvas()(self._hover_info_label.setText) # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders(self) self._dims_sliders.valueChanged.connect( @@ -230,9 +235,9 @@ def __init__( layout = QVBoxLayout(self) layout.setSpacing(2) layout.setContentsMargins(6, 6, 6, 6) - layout.addWidget(self._data_info) + layout.addWidget(self._data_info_label) layout.addWidget(self._canvas.qwidget(), 1) - layout.addWidget(self._hover_info) + layout.addWidget(self._hover_info_label) layout.addWidget(self._dims_sliders) layout.addWidget(self._lut_drop) layout.addLayout(btns) @@ -276,12 +281,7 @@ def set_data( self._data_wrapper = DataWrapper.create(data) # determine sizes of the data - if sizes is None: - if (sz := getattr(data, "sizes", None)) and isinstance(sz, Mapping): - sizes = sz - elif (shp := getattr(data, "shape", None)) and isinstance(shp, tuple): - sizes = shp - self._sizes = _to_sizes(sizes) + self._sizes = self._data_wrapper.sizes() if sizes is None else _to_sizes(sizes) # set channel axis if channel_axis is not None: @@ -300,7 +300,7 @@ def set_data( # redraw self.setIndex({}) # update the data info label - self._update_data_info() + self._data_info_label.setText(self._data_wrapper.summary_info()) def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: """Set the dimensions that will be visualized. @@ -366,26 +366,6 @@ def setIndex(self, index: Indices) -> None: # ------------------- PRIVATE METHODS ---------------------------- - def _update_data_info(self) -> None: - """Update the data info label with information about the data.""" - data = self._data_wrapper._data - - package = getattr(data, "__module__", "").split(".")[0] - info = f"{package}.{getattr(type(data), '__qualname__', '')}" - - if self._sizes: - if all(isinstance(x, int) for x in self._sizes): - size_str = repr(tuple(self._sizes.values())) - else: - size_str = ", ".join(f"{k}:{v}" for k, v in self._sizes.items()) - size_str = f"({size_str})" - info += f" {size_str}" - if dtype := getattr(data, "dtype", ""): - info += f", {dtype}" - if nbytes := getattr(data, "nbytes", 0) / 1e6: - info += f", {nbytes:.2f}MB" - self._data_info.setText(info) - def _on_set_range_clicked(self) -> None: # using method to swallow the parameter passed by _set_range_btn.clicked self._canvas.set_range() @@ -486,7 +466,12 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: handles.append(self._canvas.add_image(datum, cmap=cm)) if imkey not in self._lut_ctrls: channel_name = self._get_channel_name(index) - self._lut_ctrls[imkey] = c = LutControl(channel_name, handles, self) + self._lut_ctrls[imkey] = c = LutControl( + channel_name, + handles, + self, + cmaplist=self._cmaps + DEFAULT_COLORMAPS, + ) self._lut_drop.addWidget(c) def _get_channel_name(self, index: Indices) -> str: From c7cfe6fbe1ef427991844cdc93de3f9e37a5c7a3 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 3 Jun 2024 18:37:51 -0400 Subject: [PATCH 57/73] remove check --- src/pymmcore_widgets/_stack_viewer_v2/_indexing.py | 13 ++++++++++--- .../_stack_viewer_v2/_mda_viewer.py | 5 ----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index e2e3ca456..680c3cc3f 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Sequence, TypeVar, cast import numpy as np +from zarr import suppress if TYPE_CHECKING: from pathlib import Path @@ -87,9 +88,7 @@ def save_as_zarr(self, save_loc: str | Path) -> None: raise NotImplementedError("save_as_zarr not implemented for this data type.") def sizes(self) -> Mapping[Hashable, int]: - if (sz := getattr(self._data, "sizes", None)) and isinstance(sz, Mapping): - return {k: int(v) for k, v in sz.items()} - elif (shape := getattr(self._data, "shape", None)) and isinstance(shape, tuple): + if (shape := getattr(self._data, "shape", None)) and isinstance(shape, tuple): _sizes: dict[Hashable, int] = {} for i, val in enumerate(shape): if isinstance(val, int): @@ -125,6 +124,11 @@ def summary_info(self) -> str: class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): + def sizes(self) -> Mapping[Hashable, int]: + with suppress(Exception): + return self._data.current_sequence.sizes() + return {} + @classmethod def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: from pymmcore_plus.mda.handlers import TensorStoreHandler @@ -181,6 +185,9 @@ class XarrayWrapper(DataWrapper["xr.DataArray"]): def isel(self, indexers: Indices) -> np.ndarray: return np.asarray(self._data.isel(indexers)) + def sizes(self) -> Mapping[Hashable, int]: + return {k: int(v) for k, v in self._data.sizes.items()} + @classmethod def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 6a3da924e..127020d6d 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -28,17 +28,12 @@ def __init__( ): if datastore is None: datastore = TensorStoreHandler() - elif not isinstance(datastore, (OMEZarrWriter, OMETiffWriter)): - raise TypeError( - "MDAViewer currently only supports _5DWriterBase datastores." - ) # patch the frameReady method to call the superframeReady method # AFTER handling the event self._superframeReady = getattr(datastore, "frameReady", None) if callable(self._superframeReady): datastore.frameReady = self._patched_frame_ready # type: ignore - else: # pragma: no cover warnings.warn( "MDAViewer: datastore does not have a frameReady method to patch, " From 5bf478caaa7fb037a20274a6825b4d2e583c9f0e Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 3 Jun 2024 18:57:59 -0400 Subject: [PATCH 58/73] wip --- .../_stack_viewer_v2/_indexing.py | 59 ++++++++++++------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 680c3cc3f..8f1c39895 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -4,10 +4,10 @@ import warnings from abc import abstractmethod from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import suppress from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Sequence, TypeVar, cast import numpy as np -from zarr import suppress if TYPE_CHECKING: from pathlib import Path @@ -126,9 +126,12 @@ def summary_info(self) -> str: class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): def sizes(self) -> Mapping[Hashable, int]: with suppress(Exception): - return self._data.current_sequence.sizes() + return self._data.current_sequence.sizes return {} + def guess_channel_axis(self) -> Hashable | None: + return "c" + @classmethod def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: from pymmcore_plus.mda.handlers import TensorStoreHandler @@ -138,27 +141,20 @@ def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: def isel(self, indexers: Indices) -> np.ndarray: return self._data.isel(indexers) # type: ignore + def save_as_zarr(self, save_loc: str | Path) -> None: + if (store := self._data.store) is None: + return + import tensorstore as ts -class MM5DWriter(DataWrapper["_5DWriterBase"]): - def isel(self, indexers: Indices) -> np.ndarray: - p_index = indexers.get("p", 0) - if isinstance(p_index, slice): - warnings.warn("Cannot slice over position index", stacklevel=2) # TODO - p_index = p_index.start - p_index = cast(int, p_index) + new_spec = store.spec().to_json() + new_spec["kvstore"] = {"driver": "file", "path": str(save_loc)} + new_ts = ts.open(new_spec, create=True).result() + new_ts[:] = store.read().result() - try: - sizes = [*list(self._data.position_sizes[p_index]), "y", "x"] - except IndexError as e: - raise IndexError( - f"Position index {p_index} out of range for " - f"{len(self._data.position_sizes)}" - ) from e - data = self._data.position_arrays[self._data.get_position_key(p_index)] - full = slice(None, None) - index = tuple(indexers.get(k, full) for k in sizes) - return data[index] # type: ignore +class MM5DWriter(DataWrapper["_5DWriterBase"]): + def guess_channel_axis(self) -> Hashable | None: + return "c" @classmethod def supports(cls, obj: Any) -> TypeGuard[_5DWriterBase]: @@ -180,6 +176,26 @@ def save_as_zarr(self, save_loc: str | Path) -> None: zarr.copy_store(self._data.group.store, zarr.DirectoryStore(save_loc)) raise NotImplementedError(f"Cannot save {type(self._data)} data to Zarr.") + def isel(self, indexers: Indices) -> np.ndarray: + p_index = indexers.get("p", 0) + if isinstance(p_index, slice): + warnings.warn("Cannot slice over position index", stacklevel=2) # TODO + p_index = p_index.start + p_index = cast(int, p_index) + + try: + sizes = [*list(self._data.position_sizes[p_index]), "y", "x"] + except IndexError as e: + raise IndexError( + f"Position index {p_index} out of range for " + f"{len(self._data.position_sizes)}" + ) from e + + data = self._data.position_arrays[self._data.get_position_key(p_index)] + full = slice(None, None) + index = tuple(indexers.get(k, full) for k in sizes) + return data[index] # type: ignore + class XarrayWrapper(DataWrapper["xr.DataArray"]): def isel(self, indexers: Indices) -> np.ndarray: @@ -215,6 +231,9 @@ def supports(cls, obj: Any) -> TypeGuard[da.Array]: return True return False + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(url=str(save_loc)) + class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): def __init__(self, data: Any) -> None: From 88e93a64315b1cee91af891f2ac1b23095091ba7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 22:58:16 +0000 Subject: [PATCH 59/73] style(pre-commit.ci): auto fixes [...] --- src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 127020d6d..cdda4f82f 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -5,7 +5,7 @@ import superqt import useq -from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter, TensorStoreHandler +from pymmcore_plus.mda.handlers import TensorStoreHandler from ._save_button import SaveButton from ._stack_viewer import StackViewer From 2df4ce3162bef3bd75c48d763ad9b37ebbae7f71 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Wed, 5 Jun 2024 07:56:55 -0400 Subject: [PATCH 60/73] don't autoscale invisible --- src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py index 84c59521e..65ba69772 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py @@ -88,13 +88,19 @@ def _on_clims_changed(self, clims: tuple[float, float]) -> None: def _on_visible_changed(self, visible: bool) -> None: for handle in self._handles: handle.visible = visible + if visible: + self.update_autoscale() def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: for handle in self._handles: handle.cmap = cmap def update_autoscale(self) -> None: - if not self._auto_clim.isChecked(): + if ( + not self._auto_clim.isChecked() + or not self._visible.isChecked() + or not self._handles + ): return # find the min and max values for the current channel From 527a737ebbdf9969bf9ed5138c8a7d8c00bc3032 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 7 Jun 2024 09:27:03 -0400 Subject: [PATCH 61/73] start adding 3d --- .../_stack_viewer_v2/_backends/_vispy.py | 50 ++++++++++++++----- .../_stack_viewer_v2/_protocols.py | 3 +- .../_stack_viewer_v2/_stack_viewer.py | 36 +++++++++++-- 3 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py index d17e49eab..11b3437f9 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, Literal, cast import numpy as np from superqt.utils import qthrottled @@ -14,33 +14,36 @@ class VispyImageHandle: - def __init__(self, image: scene.visuals.Image) -> None: - self._image = image + def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: + self._visual = visual @property def data(self) -> np.ndarray: - return self._image._data # type: ignore + try: + return self._visual._data # type: ignore + except AttributeError: + return self._visual._last_data @data.setter def data(self, data: np.ndarray) -> None: - self._image.set_data(data) + self._visual.set_data(data) @property def visible(self) -> bool: - return bool(self._image.visible) + return bool(self._visual.visible) @visible.setter def visible(self, visible: bool) -> None: - self._image.visible = visible + self._visual.visible = visible @property def clim(self) -> Any: - return self._image.clim + return self._visual.clim @clim.setter def clim(self, clims: tuple[float, float]) -> None: with suppress(ZeroDivisionError): - self._image.clim = clims + self._visual.clim = clims @property def cmap(self) -> cmap.Colormap: @@ -49,7 +52,7 @@ def cmap(self) -> cmap.Colormap: @cmap.setter def cmap(self, cmap: cmap.Colormap) -> None: self._cmap = cmap - self._image.cmap = cmap.to_vispy() + self._visual.cmap = cmap.to_vispy() @property def transform(self) -> np.ndarray: @@ -60,7 +63,7 @@ def transform(self, transform: np.ndarray) -> None: raise NotImplementedError def remove(self) -> None: - self._image.parent = None + self._visual.parent = None class VispyViewerCanvas: @@ -74,11 +77,18 @@ def __init__(self, set_info: Callable[[str], None]) -> None: self._set_info = set_info self._canvas = scene.SceneCanvas() self._canvas.events.mouse_move.connect(qthrottled(self._on_mouse_move, 60)) - self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) self._has_set_range = False central_wdg: scene.Widget = self._canvas.central_widget - self._view: scene.ViewBox = central_wdg.add_view(camera=self._camera) + self._view: scene.ViewBox = central_wdg.add_view() + # self.set_ndim(2) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + if ndim == 3: + self._camera = scene.ArcballCamera() + else: + self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + self._view.camera = self._camera def qwidget(self) -> QWidget: return cast("QWidget", self._canvas.native) @@ -101,6 +111,20 @@ def add_image( handle.cmap = cmap return handle + def add_volume( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> VispyImageHandle: + vol = scene.visuals.Volume(data, parent=self._view.scene) + vol.set_gl_state("additive", depth_test=True) + vol.interactive = True + if not self._has_set_range: + self.set_range() + self._has_set_range = True + handle = VispyImageHandle(vol) + if cmap is not None: + handle.cmap = cmap + return handle + def set_range( self, x: tuple[float, float] | None = None, diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py index 8b8d5d67a..ecf1d3827 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Protocol +from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol if TYPE_CHECKING: import cmap @@ -30,6 +30,7 @@ def remove(self) -> None: ... class PCanvas(Protocol): def __init__(self, set_info: Callable[[str], None]) -> None: ... + def set_ndim(self, ndim: Literal[2, 3]) -> None: ... def set_range( self, x: tuple[float, float] | None = ..., diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 0ad90962d..f0352cc96 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -3,7 +3,7 @@ from collections import defaultdict from enum import Enum from itertools import cycle -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Iterable, Literal, Mapping, Sequence, cast import cmap import numpy as np @@ -187,6 +187,10 @@ def __init__( self._cmap_cycle = cycle(self._cmaps) # the last future that was created by _update_data_for_index self._last_future: Future | None = None + + # number of dimensions to display + self._ndims: Literal[2, 3] = 2 + # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels @@ -198,12 +202,18 @@ def __init__( ) self._set_range_btn.clicked.connect(self._on_set_range_clicked) + # button to change number of displayed dimensions + self._ndims_btn = QPushButton("Dims", self) + self._ndims_btn.clicked.connect(self._change_ndims) + # place to display dataset summary self._data_info_label = QElidingLabel("", parent=self) # place to display arbitrary text self._hover_info_label = QLabel("", self) # the canvas that displays the images self._canvas: PCanvas = get_canvas()(self._hover_info_label.setText) + self._canvas.set_ndim(self._ndims) + # the sliders that control the index of the displayed image self._dims_sliders = DimsSliders(self) self._dims_sliders.valueChanged.connect( @@ -231,6 +241,7 @@ def __init__( btns.addStretch() btns.addWidget(self._channel_mode_btn) btns.addWidget(self._set_range_btn) + btns.addWidget(self._ndims_btn) layout = QVBoxLayout(self) layout.setSpacing(2) @@ -248,6 +259,17 @@ def __init__( self.set_data(data) # ------------------- PUBLIC API ---------------------------- + def _change_ndims(self) -> None: + self.set_ndim(3 if self._ndims == 2 else 2) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions to display.""" + self._ndims = ndim + self._canvas.set_ndim(ndim) + print("Setting ndim to", ndim) + if self._img_handles: + self._clear_images() + self._update_data_for_index(self._dims_sliders.value()) @property def data(self) -> Any: @@ -291,7 +313,7 @@ def set_data( # update the dimensions we are visualizing if visualized_dims is None: - visualized_dims = list(self._sizes)[-2:] + visualized_dims = list(self._sizes)[-self._ndims :] self.set_visualized_dims(visualized_dims) # update the range of all the sliders to match the sizes we set above @@ -329,7 +351,7 @@ def update_slider_ranges( self._dims_sliders.setMinima(_to_sizes(mins)) # FIXME: this needs to be moved and made user-controlled - for dim in list(maxes.keys())[-2:]: + for dim in list(maxes.keys())[-self._ndims :]: self._dims_sliders.set_dimension_visible(dim, False) def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: @@ -446,6 +468,7 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: """ imkey = self._image_key(index) datum = self._reduce_data_for_display(data) + print("showing", imkey, datum.shape, datum.dtype) if handles := self._img_handles[imkey]: for handle in handles: handle.data = datum @@ -463,7 +486,10 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: # it's better just to not add it at all... if np.max(datum) == 0: return - handles.append(self._canvas.add_image(datum, cmap=cm)) + if datum.ndim == 2: + handles.append(self._canvas.add_image(datum, cmap=cm)) + elif datum.ndim == 3: + handles.append(self._canvas.add_volume(datum, cmap=cm)) if imkey not in self._lut_ctrls: channel_name = self._get_channel_name(index) self._lut_ctrls[imkey] = c = LutControl( @@ -495,7 +521,7 @@ def _reduce_data_for_display( # - for better way to determine which dims need to be reduced (currently just # the smallest dims) data = data.squeeze() - visualized_dims = 2 + visualized_dims = self._ndims if extra_dims := data.ndim - visualized_dims: shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) From 63c2c479084247be76faeb0b89c771990774649d Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 7 Jun 2024 09:37:38 -0400 Subject: [PATCH 62/73] remove print --- src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index f0352cc96..ff2e003f9 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -266,7 +266,6 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: """Set the number of dimensions to display.""" self._ndims = ndim self._canvas.set_ndim(ndim) - print("Setting ndim to", ndim) if self._img_handles: self._clear_images() self._update_data_for_index(self._dims_sliders.value()) @@ -468,7 +467,6 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: """ imkey = self._image_key(index) datum = self._reduce_data_for_display(data) - print("showing", imkey, datum.shape, datum.dtype) if handles := self._img_handles[imkey]: for handle in handles: handle.data = datum From b26e0bf5d87723782a86e0aa1b777e5a30b2aefb Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 7 Jun 2024 13:58:42 -0400 Subject: [PATCH 63/73] better 3d --- examples/stack_viewer/numpy_arr.py | 12 ++-- examples/stack_viewer/tensorstore_arr.py | 1 + .../_stack_viewer_v2/_backends/_vispy.py | 72 ++++++++++++++----- .../_stack_viewer_v2/_dims_slider.py | 19 +++-- .../_stack_viewer_v2/_indexing.py | 3 + .../_stack_viewer_v2/_protocols.py | 3 + .../_stack_viewer_v2/_stack_viewer.py | 16 ++--- 7 files changed, 89 insertions(+), 37 deletions(-) diff --git a/examples/stack_viewer/numpy_arr.py b/examples/stack_viewer/numpy_arr.py index 06711f3d4..e11c78e5f 100644 --- a/examples/stack_viewer/numpy_arr.py +++ b/examples/stack_viewer/numpy_arr.py @@ -48,12 +48,16 @@ def generate_5d_sine_wave( return output -# Example usage -array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions -sine_wave_5d = generate_5d_sine_wave(array_shape) +try: + from skimage import data + + img = data.cells3d() +except Exception: + img = generate_5d_sine_wave((10, 3, 8, 512, 512)) + if __name__ == "__main__": qapp = QtWidgets.QApplication([]) - v = StackViewer(sine_wave_5d, channel_axis=1) + v = StackViewer(img) v.show() qapp.exec() diff --git a/examples/stack_viewer/tensorstore_arr.py b/examples/stack_viewer/tensorstore_arr.py index d15b7f98f..5eb9a1f64 100644 --- a/examples/stack_viewer/tensorstore_arr.py +++ b/examples/stack_viewer/tensorstore_arr.py @@ -14,6 +14,7 @@ dtype=ts.uint8, ).result() ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8) +ts_array = ts_array[ts.d[:].label["t", "c", "z", "y", "x"]] if __name__ == "__main__": qapp = QtWidgets.QApplication([]) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py index 11b3437f9..3a3f2988a 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py @@ -4,14 +4,21 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, cast import numpy as np +import vispy +import vispy.scene +import vispy.visuals from superqt.utils import qthrottled from vispy import scene +from vispy.util.quaternion import Quaternion if TYPE_CHECKING: import cmap from qtpy.QtWidgets import QWidget from vispy.scene.events import SceneMouseEvent +turn = np.sin(np.pi / 4) +DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) + class VispyImageHandle: def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: @@ -77,18 +84,37 @@ def __init__(self, set_info: Callable[[str], None]) -> None: self._set_info = set_info self._canvas = scene.SceneCanvas() self._canvas.events.mouse_move.connect(qthrottled(self._on_mouse_move, 60)) - self._has_set_range = False + self._current_shape: tuple[int, ...] = () + self._last_state: dict[Literal[2, 3], Any] = {} central_wdg: scene.Widget = self._canvas.central_widget self._view: scene.ViewBox = central_wdg.add_view() - # self.set_ndim(2) + self._ndim: Literal[2, 3] | None = None + + @property + def _camera(self) -> vispy.scene.cameras.BaseCamera: + return self._view.camera def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim == self._ndim: + return + elif self._ndim is not None: + # remember the current state before switching to the new camera + self._last_state[self._ndim] = self._camera.get_state() + + self._ndim = ndim if ndim == 3: - self._camera = scene.ArcballCamera() + cam = scene.ArcballCamera() + # this sets the initial view similar to what the panzoom view would have. + cam._quaternion = DEFAULT_QUATERNION else: - self._camera = scene.PanZoomCamera(aspect=1, flip=(0, 1)) - self._view.camera = self._camera + cam = scene.PanZoomCamera(aspect=1, flip=(0, 1)) + + # restore the previous state if it exists + if state := self._last_state.get(ndim): + cam.set_state(state) + self._view.camera = cam def qwidget(self) -> QWidget: return cast("QWidget", self._canvas.native) @@ -103,9 +129,10 @@ def add_image( img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True - if not self._has_set_range: - self.set_range() - self._has_set_range = True + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if not prev_shape: + self.set_range() handle = VispyImageHandle(img) if cmap is not None: handle.cmap = cmap @@ -114,12 +141,15 @@ def add_image( def add_volume( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: - vol = scene.visuals.Volume(data, parent=self._view.scene) - vol.set_gl_state("additive", depth_test=True) + vol = scene.visuals.Volume( + data, parent=self._view.scene, interpolation="nearest" + ) + # vol.set_gl_state("additive", depth_test=True) vol.interactive = True - if not self._has_set_range: - self.set_range() - self._has_set_range = True + if data is not None: + self._current_shape, prev_shape = data.shape, self._current_shape + if not prev_shape: + self.set_range() handle = VispyImageHandle(vol) if cmap is not None: handle.cmap = cmap @@ -129,19 +159,29 @@ def set_range( self, x: tuple[float, float] | None = None, y: tuple[float, float] | None = None, - margin: float = 0.01, + z: tuple[float, float] | None = None, + margin: float = 0.0, ) -> None: """Update the range of the PanZoomCamera. When called with no arguments, the range is set to the full extent of the data. """ - self._camera.set_range(x=x, y=y, margin=margin) + if len(self._current_shape) >= 2: + if x is None: + x = (0, self._current_shape[-1]) + if y is None: + y = (0, self._current_shape[-2]) + if z is None and len(self._current_shape) == 3: + z = (0, self._current_shape[-3]) + if isinstance(self._camera, scene.ArcballCamera): + self._camera._quaternion = DEFAULT_QUATERNION + self._view.camera.set_range(x=x, y=y, z=z, margin=margin) def _on_mouse_move(self, event: SceneMouseEvent) -> None: """Mouse moved on the canvas, display the pixel value and position.""" images = [] # Get the images the mouse is over - # FIXME: must be a better way to do this + # FIXME: this is narsty ... there must be a better way to do this seen = set() try: while visual := self._canvas.visual_at(event.pos): diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py index 2987c2801..219679437 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py @@ -445,32 +445,32 @@ def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: viz = visible if isinstance(visible, bool) else visible.get(dim, False) slider._lock_btn.setVisible(viz) - def add_dimension(self, name: DimKey, val: Index | None = None) -> None: + def add_dimension(self, key: DimKey, val: Index | None = None) -> None: """Add a new dimension to the DimsSliders widget. Parameters ---------- - name : Hashable + key : Hashable The name of the dimension. val : int | slice, optional The initial value for the dimension. If a slice, the slider will be in slice mode. """ - self._sliders[name] = slider = DimsSlider(dimension_key=name, parent=self) - if isinstance(self._locks_visible, dict) and name in self._locks_visible: - slider._lock_btn.setVisible(self._locks_visible[name]) + self._sliders[key] = slider = DimsSlider(dimension_key=key, parent=self) + if isinstance(self._locks_visible, dict) and key in self._locks_visible: + slider._lock_btn.setVisible(self._locks_visible[key]) else: slider._lock_btn.setVisible(bool(self._locks_visible)) val_int = val.start if isinstance(val, slice) else val - slider.setVisible(name not in self._invisible_dims) + slider.setVisible(key not in self._invisible_dims) if isinstance(val_int, int): slider.setRange(val_int, val_int) elif isinstance(val_int, slice): slider.setRange(val_int.start or 0, val_int.stop or 1) val = val if val is not None else 0 - self._current_index[name] = val + self._current_index[key] = val slider.forceValue(val) slider.valueChanged.connect(self._on_dim_slider_value_changed) cast("QVBoxLayout", self.layout()).addWidget(slider) @@ -483,8 +483,13 @@ def set_dimension_visible(self, key: DimKey, visible: bool) -> None: """ if visible: self._invisible_dims.discard(key) + if key in self._sliders: + self._current_index[key] = self._sliders[key].value() + else: + self.add_dimension(key) else: self._invisible_dims.add(key) + self._current_index.pop(key, None) if key in self._sliders: self._sliders[key].setVisible(visible) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 8f1c39895..084ae65e8 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -242,6 +242,9 @@ def __init__(self, data: Any) -> None: self._ts = ts + def sizes(self) -> Mapping[Hashable, int]: + return {dim.label: dim.size for dim in self._data.domain} + def isel(self, indexers: Indices) -> np.ndarray: result = self._data[self._ts.d[*indexers][*indexers.values()]].read().result() return np.asarray(result) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py index ecf1d3827..9fc00082f 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py @@ -42,3 +42,6 @@ def qwidget(self) -> QWidget: ... def add_image( self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... ) -> PImageHandle: ... + def add_volume( + self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... + ) -> PImageHandle: ... diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index ff2e003f9..38ccd1e6c 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -256,7 +256,8 @@ def __init__( # SETUP ------------------------------------------------------ self.set_channel_mode(channel_mode) - self.set_data(data) + if data is not None: + self.set_data(data) # ------------------- PUBLIC API ---------------------------- def _change_ndims(self) -> None: @@ -266,6 +267,9 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: """Set the number of dimensions to display.""" self._ndims = ndim self._canvas.set_ndim(ndim) + non_channels = [x for x in self._sizes if x != self._channel_axis] + visualized_dims = non_channels[-self._ndims :] + self.set_visualized_dims(visualized_dims) if self._img_handles: self._clear_images() self._update_data_for_index(self._dims_sliders.value()) @@ -416,7 +420,6 @@ def _update_data_for_index(self, index: Indices) -> None: if self._last_future: self._last_future.cancel() - self._last_future = f = self._isel(index) f.add_done_callback(self._on_data_slice_ready) @@ -478,12 +481,6 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: if self._channel_mode == ChannelMode.COMPOSITE else GRAYS ) - # FIXME: this is a hack ... - # however, there's a bug in the vispy backend such that if the first - # image is all zeros, it persists even if the data is updated - # it's better just to not add it at all... - if np.max(datum) == 0: - return if datum.ndim == 2: handles.append(self._canvas.add_image(datum, cmap=cm)) elif datum.ndim == 3: @@ -514,7 +511,6 @@ def _reduce_data_for_display( This also coerces 64-bit data to 32-bit data. """ # TODO - # - allow for 3d data # - allow dimensions to control how they are reduced (as opposed to just max) # - for better way to determine which dims need to be reduced (currently just # the smallest dims) @@ -523,7 +519,7 @@ def _reduce_data_for_display( if extra_dims := data.ndim - visualized_dims: shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) - return reductor(data, axis=smallest_dims) + data = reductor(data, axis=smallest_dims) if data.dtype.itemsize > 4: # More than 32 bits if np.issubdtype(data.dtype, np.integer): From 631d4218c8fad3db7c9814a0d89f358eaab03d18 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 7 Jun 2024 16:08:36 -0400 Subject: [PATCH 64/73] more tweaks --- .../_stack_viewer_v2/_backends/_pygfx.py | 11 +++ .../_stack_viewer_v2/_backends/_vispy.py | 18 ++-- .../_stack_viewer_v2/_indexing.py | 17 +++- .../_stack_viewer_v2/_protocols.py | 5 +- .../_stack_viewer_v2/_stack_viewer.py | 90 ++++++++++++------- 5 files changed, 99 insertions(+), 42 deletions(-) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py index 37fe110b4..e085526fa 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py @@ -97,6 +97,16 @@ def refresh(self) -> None: def _animate(self) -> None: self._renderer.render(self._scene, self._camera) + def set_ndim(self, ndim: int) -> None: + """Set the number of dimensions of the displayed data.""" + if ndim != 2: + raise NotImplementedError("Volume rendering is not supported by pygfx yet.") + + def add_volume( + self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None + ) -> PyGFXImageHandle: + raise NotImplementedError("Volume rendering is not supported by pygfx yet.") + def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> PyGFXImageHandle: @@ -118,6 +128,7 @@ def set_range( self, x: tuple[float, float] | None = None, y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, margin: float = 0.05, ) -> None: """Update the range of the PanZoomCamera. diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py index 3a3f2988a..c74047857 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py @@ -27,9 +27,9 @@ def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: @property def data(self) -> np.ndarray: try: - return self._visual._data # type: ignore + return self._visual._data # type: ignore [no-any-return] except AttributeError: - return self._visual._last_data + return self._visual._last_data # type: ignore [no-any-return] @data.setter def data(self, data: np.ndarray) -> None: @@ -105,7 +105,7 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: self._ndim = ndim if ndim == 3: - cam = scene.ArcballCamera() + cam = scene.ArcballCamera(fov=0) # this sets the initial view similar to what the panzoom view would have. cam._quaternion = DEFAULT_QUATERNION else: @@ -144,11 +144,11 @@ def add_volume( vol = scene.visuals.Volume( data, parent=self._view.scene, interpolation="nearest" ) - # vol.set_gl_state("additive", depth_test=True) + vol.set_gl_state("additive", depth_test=False) vol.interactive = True if data is not None: self._current_shape, prev_shape = data.shape, self._current_shape - if not prev_shape: + if len(prev_shape) != 3: self.set_range() handle = VispyImageHandle(vol) if cmap is not None: @@ -160,7 +160,7 @@ def set_range( x: tuple[float, float] | None = None, y: tuple[float, float] | None = None, z: tuple[float, float] | None = None, - margin: float = 0.0, + margin: float = 0.01, ) -> None: """Update the range of the PanZoomCamera. @@ -173,9 +173,13 @@ def set_range( y = (0, self._current_shape[-2]) if z is None and len(self._current_shape) == 3: z = (0, self._current_shape[-3]) - if isinstance(self._camera, scene.ArcballCamera): + is_3d = isinstance(self._camera, scene.ArcballCamera) + if is_3d: self._camera._quaternion = DEFAULT_QUATERNION self._view.camera.set_range(x=x, y=y, z=z, margin=margin) + if is_3d: + max_size = max(self._current_shape) + self._camera.scale_factor = max_size + 6 def _on_mouse_move(self, event: SceneMouseEvent) -> None: """Mouse moved on the canvas, display the pixel value and position.""" diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py index 084ae65e8..6328708d6 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py @@ -5,7 +5,16 @@ from abc import abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from contextlib import suppress -from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Sequence, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Generic, + Hashable, + Iterable, + Mapping, + Sequence, + TypeVar, + cast, +) import numpy as np @@ -66,9 +75,11 @@ def isel(self, indexers: Indices) -> np.ndarray: """ raise NotImplementedError - def isel_async(self, indexers: Indices) -> Future[tuple[Indices, np.ndarray]]: + def isel_async( + self, indexers: list[Indices] + ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: """Asynchronous version of isel.""" - return _EXECUTOR.submit(lambda: (indexers, self.isel(indexers))) + return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) @classmethod @abstractmethod diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py index 9fc00082f..413038ded 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py @@ -33,8 +33,9 @@ def __init__(self, set_info: Callable[[str], None]) -> None: ... def set_ndim(self, ndim: Literal[2, 3]) -> None: ... def set_range( self, - x: tuple[float, float] | None = ..., - y: tuple[float, float] | None = ..., + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, margin: float = ..., ) -> None: ... def refresh(self) -> None: ... diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py index 38ccd1e6c..6aa6f28d6 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py @@ -58,6 +58,9 @@ def __init__(self, parent: QWidget | None = None): self.setCheckable(True) self.toggled.connect(self.next_mode) + # set minimum width to the width of the larger string 'composite' + self.setMinimumWidth(92) # FIXME: magic number + def next_mode(self) -> None: if self.isChecked(): self.setMode(ChannelMode.MONO) @@ -74,6 +77,15 @@ def setMode(self, mode: ChannelMode) -> None: self.setChecked(mode == ChannelMode.MONO) +class DimToggleButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + icn = QIconifyIcon("f7:view-2d", color="#333333") + icn.addKey("f7:view-3d", state=QIconifyIcon.State.On, color="white") + super().__init__(icn, "", parent) + self.setCheckable(True) + self.setChecked(True) + + # @dataclass # class LutModel: # name: str = "" @@ -203,8 +215,8 @@ def __init__( self._set_range_btn.clicked.connect(self._on_set_range_clicked) # button to change number of displayed dimensions - self._ndims_btn = QPushButton("Dims", self) - self._ndims_btn.clicked.connect(self._change_ndims) + self._ndims_btn = DimToggleButton(self) + self._ndims_btn.clicked.connect(self.toggle_3d) # place to display dataset summary self._data_info_label = QElidingLabel("", parent=self) @@ -240,8 +252,8 @@ def __init__( btns.setSpacing(0) btns.addStretch() btns.addWidget(self._channel_mode_btn) - btns.addWidget(self._set_range_btn) btns.addWidget(self._ndims_btn) + btns.addWidget(self._set_range_btn) layout = QVBoxLayout(self) layout.setSpacing(2) @@ -260,20 +272,6 @@ def __init__( self.set_data(data) # ------------------- PUBLIC API ---------------------------- - def _change_ndims(self) -> None: - self.set_ndim(3 if self._ndims == 2 else 2) - - def set_ndim(self, ndim: Literal[2, 3]) -> None: - """Set the number of dimensions to display.""" - self._ndims = ndim - self._canvas.set_ndim(ndim) - non_channels = [x for x in self._sizes if x != self._channel_axis] - visualized_dims = non_channels[-self._ndims :] - self.set_visualized_dims(visualized_dims) - if self._img_handles: - self._clear_images() - self._update_data_for_index(self._dims_sliders.value()) - @property def data(self) -> Any: """Return the data backing the view.""" @@ -357,6 +355,27 @@ def update_slider_ranges( for dim in list(maxes.keys())[-self._ndims :]: self._dims_sliders.set_dimension_visible(dim, False) + def toggle_3d(self) -> None: + self.set_ndim(3 if self._ndims == 2 else 2) + + def set_ndim(self, ndim: Literal[2, 3]) -> None: + """Set the number of dimensions to display.""" + self._ndims = ndim + self._canvas.set_ndim(ndim) + + # set the visibility of the last non-channel dimension + sizes = list(self._sizes) + if self._channel_axis is not None: + sizes = [x for x in sizes if x != self._channel_axis] + if len(sizes) >= 3: + dim3 = sizes[-3] + self._dims_sliders.set_dimension_visible(dim3, True if ndim == 2 else False) + + # clear image handles and redraw + if self._img_handles: + self._clear_images() + self._update_data_for_index(self._dims_sliders.value()) + def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: """Set the mode for displaying the channels. @@ -416,11 +435,22 @@ def _update_data_for_index(self, index: Indices) -> None: self._channel_axis is not None and self._channel_mode == ChannelMode.COMPOSITE ): - index = {**index, self._channel_axis: ALL_CHANNELS} + indices: list[Indices] = [ + {**index, self._channel_axis: i} + for i in range(self._sizes[self._channel_axis]) + ] + else: + indices = [index] if self._last_future: self._last_future.cancel() - self._last_future = f = self._isel(index) + + # don't request any dimensions that are not visualized + indices = [ + {k: v for k, v in idx.items() if k not in self._visualized_dims} + for idx in indices + ] + self._last_future = f = self._isel(indices) f.add_done_callback(self._on_data_slice_ready) def closeEvent(self, a0: QCloseEvent | None) -> None: @@ -429,16 +459,19 @@ def closeEvent(self, a0: QCloseEvent | None) -> None: self._last_future = None super().closeEvent(a0) - def _isel(self, index: Indices) -> Future[tuple[Indices, np.ndarray]]: + def _isel( + self, indices: list[Indices] + ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: """Select data from the datastore using the given index.""" - idx = {k: v for k, v in index.items() if k not in self._visualized_dims} try: - return self._data_wrapper.isel_async(idx) + return self._data_wrapper.isel_async(indices) except Exception as e: - raise type(e)(f"Failed to index data with {idx}: {e}") from e + raise type(e)(f"Failed to index data with {indices}: {e}") from e @ensure_main_thread # type: ignore - def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> None: + def _on_data_slice_ready( + self, future: Future[Iterable[tuple[Indices, np.ndarray]]] + ) -> None: """Update the displayed image for the given index. Connected to the future returned by _isel. @@ -450,15 +483,12 @@ def _on_data_slice_ready(self, future: Future[tuple[Indices, np.ndarray]]) -> No if future.cancelled(): return - index, data = future.result() - # assume that if we have channels remaining, that they are the first axis - # FIXME: this is a bad assumption - data = iter(data) if index.get(self._channel_axis) is ALL_CHANNELS else [data] + data = future.result() # FIXME: # `self._channel_axis: i` is a bug; we assume channel indices start at 0 # but the actual values used for indices are up to the user. - for i, datum in enumerate(data): - self._update_canvas_data(datum, {**index, self._channel_axis: i}) + for idx, datum in data: + self._update_canvas_data(datum, idx) self._canvas.refresh() def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: From f3d3932de459b7c30015d7f1d82c79b3bf82141f Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 8 Jun 2024 08:09:48 -0400 Subject: [PATCH 65/73] use ndv --- examples/stack_viewer/dask_arr.py | 27 - examples/stack_viewer/jax_arr.py | 17 - examples/stack_viewer/numpy_arr.py | 63 -- examples/stack_viewer/tensorstore_arr.py | 23 - examples/stack_viewer/xarray_arr.py | 14 - examples/stack_viewer/zarr_arr.py | 16 - .../_stack_viewer_v2/__init__.py | 3 +- .../_stack_viewer_v2/_backends/__init__.py | 36 -- .../_stack_viewer_v2/_backends/_pygfx.py | 175 ------ .../_stack_viewer_v2/_backends/_vispy.py | 212 ------- .../_stack_viewer_v2/_dims_slider.py | 528 ---------------- .../_stack_viewer_v2/_indexing.py | 292 --------- .../_stack_viewer_v2/_lut_control.py | 121 ---- .../_stack_viewer_v2/_mda_viewer.py | 33 +- .../_stack_viewer_v2/_protocols.py | 48 -- .../_stack_viewer_v2/_save_button.py | 34 - .../_stack_viewer_v2/_stack_viewer.py | 591 ------------------ 17 files changed, 30 insertions(+), 2203 deletions(-) delete mode 100644 examples/stack_viewer/dask_arr.py delete mode 100644 examples/stack_viewer/jax_arr.py delete mode 100644 examples/stack_viewer/numpy_arr.py delete mode 100644 examples/stack_viewer/tensorstore_arr.py delete mode 100644 examples/stack_viewer/xarray_arr.py delete mode 100644 examples/stack_viewer/zarr_arr.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_backends/__init__.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_indexing.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_protocols.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_save_button.py delete mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py diff --git a/examples/stack_viewer/dask_arr.py b/examples/stack_viewer/dask_arr.py deleted file mode 100644 index eba37eb58..000000000 --- a/examples/stack_viewer/dask_arr.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -import numpy as np -from dask.array.core import map_blocks -from qtpy import QtWidgets - -from pymmcore_widgets._stack_viewer_v2 import StackViewer - -frame_size = (1024, 1024) - - -def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: - if isinstance(block_id, np.ndarray): - return None - data = np.random.randint(0, 255, size=frame_size, dtype=np.uint8) - return data[(None,) * 3] - - -chunks = [(1,) * x for x in (1000, 64, 3)] -chunks += [(x,) for x in frame_size] -dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = StackViewer(dask_arr) - v.show() - qapp.exec() diff --git a/examples/stack_viewer/jax_arr.py b/examples/stack_viewer/jax_arr.py deleted file mode 100644 index b57a129d6..000000000 --- a/examples/stack_viewer/jax_arr.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import jax.numpy as jnp -from numpy_arr import generate_5d_sine_wave -from qtpy import QtWidgets - -from pymmcore_widgets._stack_viewer_v2._stack_viewer import StackViewer - -# Example usage -array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions -sine_wave_5d = jnp.asarray(generate_5d_sine_wave(array_shape)) - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = StackViewer(sine_wave_5d, channel_axis=1) - v.show() - qapp.exec() diff --git a/examples/stack_viewer/numpy_arr.py b/examples/stack_viewer/numpy_arr.py deleted file mode 100644 index e11c78e5f..000000000 --- a/examples/stack_viewer/numpy_arr.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -import numpy as np -from qtpy import QtWidgets - -from pymmcore_widgets._stack_viewer_v2._stack_viewer import StackViewer - - -def generate_5d_sine_wave( - shape: tuple[int, int, int, int, int], - amplitude: float = 240, - base_frequency: float = 5, -) -> np.ndarray: - """5D dataset.""" - # Unpack the dimensions - angle_dim, freq_dim, phase_dim, ny, nx = shape - - # Create an empty array to hold the data - output = np.zeros(shape) - - # Define spatial coordinates for the last two dimensions - half_per = base_frequency * np.pi - x = np.linspace(-half_per, half_per, nx) - y = np.linspace(-half_per, half_per, ny) - y, x = np.meshgrid(y, x) - - # Iterate through each parameter in the higher dimensions - for phase_idx in range(phase_dim): - for freq_idx in range(freq_dim): - for angle_idx in range(angle_dim): - # Calculate phase and frequency - phase = np.pi / phase_dim * phase_idx - frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step - - # Calculate angle - angle = np.pi / angle_dim * angle_idx - # Rotate x and y coordinates - xr = np.cos(angle) * x - np.sin(angle) * y - np.sin(angle) * x + np.cos(angle) * y - - # Compute the sine wave - sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase) - sine_wave += amplitude * 0.5 - - # Assign to the output array - output[angle_idx, freq_idx, phase_idx] = sine_wave - - return output - - -try: - from skimage import data - - img = data.cells3d() -except Exception: - img = generate_5d_sine_wave((10, 3, 8, 512, 512)) - - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = StackViewer(img) - v.show() - qapp.exec() diff --git a/examples/stack_viewer/tensorstore_arr.py b/examples/stack_viewer/tensorstore_arr.py deleted file mode 100644 index 5eb9a1f64..000000000 --- a/examples/stack_viewer/tensorstore_arr.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import numpy as np -import tensorstore as ts -from qtpy import QtWidgets - -from pymmcore_widgets._stack_viewer_v2 import StackViewer - -shape = (10, 4, 3, 512, 512) -ts_array = ts.open( - {"driver": "zarr", "kvstore": {"driver": "memory"}}, - create=True, - shape=shape, - dtype=ts.uint8, -).result() -ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8) -ts_array = ts_array[ts.d[:].label["t", "c", "z", "y", "x"]] - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = StackViewer(ts_array) - v.show() - qapp.exec() diff --git a/examples/stack_viewer/xarray_arr.py b/examples/stack_viewer/xarray_arr.py deleted file mode 100644 index 3c0871582..000000000 --- a/examples/stack_viewer/xarray_arr.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -import xarray as xr -from qtpy import QtWidgets - -from pymmcore_widgets._stack_viewer_v2 import StackViewer - -da = xr.tutorial.open_dataset("air_temperature").air - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = StackViewer(da, colormaps=["thermal"], channel_mode="composite") - v.show() - qapp.exec() diff --git a/examples/stack_viewer/zarr_arr.py b/examples/stack_viewer/zarr_arr.py deleted file mode 100644 index 385cbddd0..000000000 --- a/examples/stack_viewer/zarr_arr.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -import zarr -import zarr.storage -from qtpy import QtWidgets - -from pymmcore_widgets._stack_viewer_v2 import StackViewer - -URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr" -zarr_arr = zarr.open(URL, mode="r") - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = StackViewer(zarr_arr["s0"]) - v.show() - qapp.exec() diff --git a/src/pymmcore_widgets/_stack_viewer_v2/__init__.py b/src/pymmcore_widgets/_stack_viewer_v2/__init__.py index d144dff42..6cb7e2982 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/__init__.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/__init__.py @@ -1,4 +1,3 @@ from ._mda_viewer import MDAViewer -from ._stack_viewer import StackViewer -__all__ = ["StackViewer", "MDAViewer"] +__all__ = ["MDAViewer"] diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/__init__.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/__init__.py deleted file mode 100644 index 9650021f9..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -import importlib -import importlib.util -import os -import sys -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pymmcore_widgets._stack_viewer_v2._protocols import PCanvas - - -def get_canvas(backend: str | None = None) -> type[PCanvas]: - backend = backend or os.getenv("CANVAS_BACKEND", None) - if backend == "vispy" or (backend is None and "vispy" in sys.modules): - from ._vispy import VispyViewerCanvas - - return VispyViewerCanvas - - if backend == "pygfx" or (backend is None and "pygfx" in sys.modules): - from ._pygfx import PyGFXViewerCanvas - - return PyGFXViewerCanvas - - if backend is None: - if importlib.util.find_spec("vispy") is not None: - from ._vispy import VispyViewerCanvas - - return VispyViewerCanvas - - if importlib.util.find_spec("pygfx") is not None: - from ._pygfx import PyGFXViewerCanvas - - return PyGFXViewerCanvas - - raise RuntimeError("No canvas backend found") diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py deleted file mode 100644 index e085526fa..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_pygfx.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, cast - -import numpy as np -import pygfx -from qtpy.QtCore import QSize -from wgpu.gui.qt import QWgpuCanvas - -if TYPE_CHECKING: - import cmap - from pygfx.materials import ImageBasicMaterial - from pygfx.resources import Texture - from qtpy.QtWidgets import QWidget - - -class PyGFXImageHandle: - def __init__(self, image: pygfx.Image, render: Callable) -> None: - self._image = image - self._render = render - self._grid = cast("Texture", image.geometry.grid) - self._material = cast("ImageBasicMaterial", image.material) - - @property - def data(self) -> np.ndarray: - return self._grid.data # type: ignore - - @data.setter - def data(self, data: np.ndarray) -> None: - self._grid.data[:] = data - self._grid.update_range((0, 0, 0), self._grid.size) - - @property - def visible(self) -> bool: - return bool(self._image.visible) - - @visible.setter - def visible(self, visible: bool) -> None: - self._image.visible = visible - self._render() - - @property - def clim(self) -> Any: - return self._material.clim - - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: - self._material.clim = clims - self._render() - - @property - def cmap(self) -> cmap.Colormap: - return self._cmap - - @cmap.setter - def cmap(self, cmap: cmap.Colormap) -> None: - self._cmap = cmap - self._material.map = cmap.to_pygfx() - self._render() - - def remove(self) -> None: - if (par := self._image.parent) is not None: - par.remove(self._image) - - -class _QWgpuCanvas(QWgpuCanvas): - def sizeHint(self) -> QSize: - return QSize(512, 512) - - -class PyGFXViewerCanvas: - """pygfx-based canvas wrapper.""" - - def __init__(self, set_info: Callable[[str], None]) -> None: - self._set_info = set_info - - self._canvas = _QWgpuCanvas(size=(512, 512)) - self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) - # requires https://github.com/pygfx/pygfx/pull/752 - self._renderer.blend_mode = "additive" - self._scene = pygfx.Scene() - self._camera = cam = pygfx.OrthographicCamera(512, 512) - cam.local.scale_y = -1 - - cam.local.position = (256, 256, 0) - self._controller = pygfx.PanZoomController(cam, register_events=self._renderer) - # increase zoom wheel gain - self._controller.controls.update({"wheel": ("zoom_to_point", "push", -0.005)}) - - def qwidget(self) -> QWidget: - return cast("QWidget", self._canvas) - - def refresh(self) -> None: - self._canvas.update() - self._canvas.request_draw(self._animate) - - def _animate(self) -> None: - self._renderer.render(self._scene, self._camera) - - def set_ndim(self, ndim: int) -> None: - """Set the number of dimensions of the displayed data.""" - if ndim != 2: - raise NotImplementedError("Volume rendering is not supported by pygfx yet.") - - def add_volume( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> PyGFXImageHandle: - raise NotImplementedError("Volume rendering is not supported by pygfx yet.") - - def add_image( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> PyGFXImageHandle: - """Add a new Image node to the scene.""" - image = pygfx.Image( - pygfx.Geometry(grid=pygfx.Texture(data, dim=2)), - # depth_test=False for additive-like blending - pygfx.ImageBasicMaterial(depth_test=False), - ) - self._scene.add(image) - # FIXME: I suspect there are more performant ways to refresh the canvas - # look into it. - handle = PyGFXImageHandle(image, self.refresh) - if cmap is not None: - handle.cmap = cmap - return handle - - def set_range( - self, - x: tuple[float, float] | None = None, - y: tuple[float, float] | None = None, - z: tuple[float, float] | None = None, - margin: float = 0.05, - ) -> None: - """Update the range of the PanZoomCamera. - - When called with no arguments, the range is set to the full extent of the data. - """ - if not self._scene.children: - return - - cam = self._camera - cam.show_object(self._scene) - - width, height, depth = np.ptp(self._scene.get_world_bounding_box(), axis=0) - if width < 0.01: - width = 1 - if height < 0.01: - height = 1 - cam.width = width - cam.height = height - cam.zoom = 1 - margin - self.refresh() - - # def _on_mouse_move(self, event: SceneMouseEvent) -> None: - # """Mouse moved on the canvas, display the pixel value and position.""" - # images = [] - # # Get the images the mouse is over - # seen = set() - # while visual := self._canvas.visual_at(event.pos): - # if isinstance(visual, scene.visuals.Image): - # images.append(visual) - # visual.interactive = False - # seen.add(visual) - # for visual in seen: - # visual.interactive = True - # if not images: - # return - - # tform = images[0].get_transform("canvas", "visual") - # px, py, *_ = (int(x) for x in tform.map(event.pos)) - # text = f"[{py}, {px}]" - # for c, img in enumerate(images): - # with suppress(IndexError): - # text += f" c{c}: {img._data[py, px]}" - # self._set_info(text) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py b/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py deleted file mode 100644 index c74047857..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_backends/_vispy.py +++ /dev/null @@ -1,212 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Literal, cast - -import numpy as np -import vispy -import vispy.scene -import vispy.visuals -from superqt.utils import qthrottled -from vispy import scene -from vispy.util.quaternion import Quaternion - -if TYPE_CHECKING: - import cmap - from qtpy.QtWidgets import QWidget - from vispy.scene.events import SceneMouseEvent - -turn = np.sin(np.pi / 4) -DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) - - -class VispyImageHandle: - def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: - self._visual = visual - - @property - def data(self) -> np.ndarray: - try: - return self._visual._data # type: ignore [no-any-return] - except AttributeError: - return self._visual._last_data # type: ignore [no-any-return] - - @data.setter - def data(self, data: np.ndarray) -> None: - self._visual.set_data(data) - - @property - def visible(self) -> bool: - return bool(self._visual.visible) - - @visible.setter - def visible(self, visible: bool) -> None: - self._visual.visible = visible - - @property - def clim(self) -> Any: - return self._visual.clim - - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: - with suppress(ZeroDivisionError): - self._visual.clim = clims - - @property - def cmap(self) -> cmap.Colormap: - return self._cmap - - @cmap.setter - def cmap(self, cmap: cmap.Colormap) -> None: - self._cmap = cmap - self._visual.cmap = cmap.to_vispy() - - @property - def transform(self) -> np.ndarray: - raise NotImplementedError - - @transform.setter - def transform(self, transform: np.ndarray) -> None: - raise NotImplementedError - - def remove(self) -> None: - self._visual.parent = None - - -class VispyViewerCanvas: - """Vispy-based viewer for data. - - All vispy-specific code is encapsulated in this class (and non-vispy canvases - could be swapped in if needed as long as they implement the same interface). - """ - - def __init__(self, set_info: Callable[[str], None]) -> None: - self._set_info = set_info - self._canvas = scene.SceneCanvas() - self._canvas.events.mouse_move.connect(qthrottled(self._on_mouse_move, 60)) - self._current_shape: tuple[int, ...] = () - self._last_state: dict[Literal[2, 3], Any] = {} - - central_wdg: scene.Widget = self._canvas.central_widget - self._view: scene.ViewBox = central_wdg.add_view() - self._ndim: Literal[2, 3] | None = None - - @property - def _camera(self) -> vispy.scene.cameras.BaseCamera: - return self._view.camera - - def set_ndim(self, ndim: Literal[2, 3]) -> None: - """Set the number of dimensions of the displayed data.""" - if ndim == self._ndim: - return - elif self._ndim is not None: - # remember the current state before switching to the new camera - self._last_state[self._ndim] = self._camera.get_state() - - self._ndim = ndim - if ndim == 3: - cam = scene.ArcballCamera(fov=0) - # this sets the initial view similar to what the panzoom view would have. - cam._quaternion = DEFAULT_QUATERNION - else: - cam = scene.PanZoomCamera(aspect=1, flip=(0, 1)) - - # restore the previous state if it exists - if state := self._last_state.get(ndim): - cam.set_state(state) - self._view.camera = cam - - def qwidget(self) -> QWidget: - return cast("QWidget", self._canvas.native) - - def refresh(self) -> None: - self._canvas.update() - - def add_image( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> VispyImageHandle: - """Add a new Image node to the scene.""" - img = scene.visuals.Image(data, parent=self._view.scene) - img.set_gl_state("additive", depth_test=False) - img.interactive = True - if data is not None: - self._current_shape, prev_shape = data.shape, self._current_shape - if not prev_shape: - self.set_range() - handle = VispyImageHandle(img) - if cmap is not None: - handle.cmap = cmap - return handle - - def add_volume( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> VispyImageHandle: - vol = scene.visuals.Volume( - data, parent=self._view.scene, interpolation="nearest" - ) - vol.set_gl_state("additive", depth_test=False) - vol.interactive = True - if data is not None: - self._current_shape, prev_shape = data.shape, self._current_shape - if len(prev_shape) != 3: - self.set_range() - handle = VispyImageHandle(vol) - if cmap is not None: - handle.cmap = cmap - return handle - - def set_range( - self, - x: tuple[float, float] | None = None, - y: tuple[float, float] | None = None, - z: tuple[float, float] | None = None, - margin: float = 0.01, - ) -> None: - """Update the range of the PanZoomCamera. - - When called with no arguments, the range is set to the full extent of the data. - """ - if len(self._current_shape) >= 2: - if x is None: - x = (0, self._current_shape[-1]) - if y is None: - y = (0, self._current_shape[-2]) - if z is None and len(self._current_shape) == 3: - z = (0, self._current_shape[-3]) - is_3d = isinstance(self._camera, scene.ArcballCamera) - if is_3d: - self._camera._quaternion = DEFAULT_QUATERNION - self._view.camera.set_range(x=x, y=y, z=z, margin=margin) - if is_3d: - max_size = max(self._current_shape) - self._camera.scale_factor = max_size + 6 - - def _on_mouse_move(self, event: SceneMouseEvent) -> None: - """Mouse moved on the canvas, display the pixel value and position.""" - images = [] - # Get the images the mouse is over - # FIXME: this is narsty ... there must be a better way to do this - seen = set() - try: - while visual := self._canvas.visual_at(event.pos): - if isinstance(visual, scene.visuals.Image): - images.append(visual) - visual.interactive = False - seen.add(visual) - except Exception: - return - for visual in seen: - visual.interactive = True - if not images: - return - - tform = images[0].get_transform("canvas", "visual") - px, py, *_ = (int(x) for x in tform.map(event.pos)) - text = f"[{py}, {px}]" - for c, img in enumerate(reversed(images)): - with suppress(IndexError): - value = img._data[py, px] - if isinstance(value, (np.floating, float)): - value = f"{value:.2f}" - text += f" {c}: {value}" - self._set_info(text) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py b/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py deleted file mode 100644 index 219679437..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_dims_slider.py +++ /dev/null @@ -1,528 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, cast -from warnings import warn - -from qtpy.QtCore import QPoint, QPointF, QSize, Qt, Signal -from qtpy.QtGui import QCursor, QResizeEvent -from qtpy.QtWidgets import ( - QDialog, - QDoubleSpinBox, - QFormLayout, - QFrame, - QHBoxLayout, - QLabel, - QPushButton, - QSizePolicy, - QSlider, - QSpinBox, - QVBoxLayout, - QWidget, -) -from superqt import QElidingLabel, QLabeledRangeSlider -from superqt.iconify import QIconifyIcon -from superqt.utils import signals_blocked - -if TYPE_CHECKING: - from typing import Hashable, Mapping, TypeAlias - - from PyQt6.QtGui import QResizeEvent - - # any hashable represent a single dimension in a AND array - DimKey: TypeAlias = Hashable - # any object that can be used to index a single dimension in an AND array - Index: TypeAlias = int | slice - # a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) - # this object is used frequently to query or set the currently displayed slice - Indices: TypeAlias = Mapping[DimKey, Index] - # mapping of dimension keys to the maximum value for that dimension - Sizes: TypeAlias = Mapping[DimKey, int] - - -SS = """ -QSlider::groove:horizontal { - height: 15px; - background: qlineargradient( - x1:0, y1:0, x2:0, y2:1, - stop:0 rgba(128, 128, 128, 0.25), - stop:1 rgba(128, 128, 128, 0.1) - ); - border-radius: 3px; -} - -QSlider::handle:horizontal { - width: 38px; - background: #999999; - border-radius: 3px; -} - -QLabel { font-size: 12px; } - -QRangeSlider { qproperty-barColor: qlineargradient( - x1:0, y1:0, x2:0, y2:1, - stop:0 rgba(100, 80, 120, 0.2), - stop:1 rgba(100, 80, 120, 0.4) - )} - -SliderLabel { - font-size: 12px; - color: white; -} -""" - - -class QtPopup(QDialog): - """A generic popup window.""" - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self.setModal(False) # if False, then clicking anywhere else closes it - self.setWindowFlags(Qt.WindowType.Popup | Qt.WindowType.FramelessWindowHint) - - self.frame = QFrame(self) - layout = QVBoxLayout(self) - layout.addWidget(self.frame) - layout.setContentsMargins(0, 0, 0, 0) - - def show_above_mouse(self, *args: Any) -> None: - """Show popup dialog above the mouse cursor position.""" - pos = QCursor().pos() # mouse position - szhint = self.sizeHint() - pos -= QPoint(szhint.width() // 2, szhint.height() + 14) - self.move(pos) - self.resize(self.sizeHint()) - self.show() - - -class PlayButton(QPushButton): - """Just a styled QPushButton that toggles between play and pause icons.""" - - fpsChanged = Signal(float) - - PLAY_ICON = "bi:play-fill" - PAUSE_ICON = "bi:pause-fill" - - def __init__(self, fps: float = 20, parent: QWidget | None = None) -> None: - icn = QIconifyIcon(self.PLAY_ICON, color="#888888") - icn.addKey(self.PAUSE_ICON, state=QIconifyIcon.State.On, color="#4580DD") - super().__init__(icn, "", parent) - self.spin = QDoubleSpinBox(self) - self.spin.setRange(0.5, 100) - self.spin.setValue(fps) - self.spin.valueChanged.connect(self.fpsChanged) - self.setCheckable(True) - self.setFixedSize(14, 18) - self.setIconSize(QSize(16, 16)) - self.setStyleSheet("border: none; padding: 0; margin: 0;") - - self._popup = QtPopup(self) - form = QFormLayout(self._popup.frame) - form.setContentsMargins(6, 6, 6, 6) - form.addRow("FPS", self.spin) - - def mousePressEvent(self, e: Any) -> None: - if e and e.button() == Qt.MouseButton.RightButton: - self._show_fps_dialog(e.globalPosition()) - else: - super().mousePressEvent(e) - - def _show_fps_dialog(self, pos: QPointF) -> None: - self._popup.show_above_mouse() - - -class LockButton(QPushButton): - LOCK_ICON = "uis:unlock" - UNLOCK_ICON = "uis:lock" - - def __init__(self, text: str = "", parent: QWidget | None = None) -> None: - icn = QIconifyIcon(self.LOCK_ICON, color="#888888") - icn.addKey(self.UNLOCK_ICON, state=QIconifyIcon.State.On, color="red") - super().__init__(icn, text, parent) - self.setCheckable(True) - self.setFixedSize(20, 20) - self.setIconSize(QSize(14, 14)) - self.setStyleSheet("border: none; padding: 0; margin: 0;") - - -class DimsSlider(QWidget): - """A single slider in the DimsSliders widget. - - Provides a play/pause button that toggles animation of the slider value. - Has a QLabeledSlider for the actual value. - Adds a label for the maximum value (e.g. "3 / 10") - """ - - valueChanged = Signal(object, object) # where object is int | slice - - def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None: - super().__init__(parent) - self.setStyleSheet(SS) - self._slice_mode = False - self._dim_key = dimension_key - - self._timer_id: int | None = None # timer for play button - self._play_btn = PlayButton(parent=self) - self._play_btn.fpsChanged.connect(self.set_fps) - self._play_btn.toggled.connect(self._toggle_animation) - - self._dim_key = dimension_key - self._dim_label = QElidingLabel(str(dimension_key).upper()) - self._dim_label.setToolTip("Double-click to toggle slice mode") - - # note, this lock button only prevents the slider from updating programmatically - # using self.setValue, it doesn't prevent the user from changing the value. - self._lock_btn = LockButton(parent=self) - - self._pos_label = QSpinBox(self) - self._pos_label.valueChanged.connect(self._on_pos_label_edited) - self._pos_label.setButtonSymbols(QSpinBox.ButtonSymbols.NoButtons) - self._pos_label.setAlignment(Qt.AlignmentFlag.AlignRight) - self._pos_label.setStyleSheet( - "border: none; padding: 0; margin: 0; background: transparent" - ) - self._out_of_label = QLabel(self) - - self._int_slider = QSlider(Qt.Orientation.Horizontal) - self._int_slider.rangeChanged.connect(self._on_range_changed) - self._int_slider.valueChanged.connect(self._on_int_value_changed) - - self._slice_slider = slc = QLabeledRangeSlider(Qt.Orientation.Horizontal) - slc.setHandleLabelPosition(QLabeledRangeSlider.LabelPosition.LabelsOnHandle) - slc.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) - slc.setVisible(False) - slc.rangeChanged.connect(self._on_range_changed) - slc.valueChanged.connect(self._on_slice_value_changed) - - self.installEventFilter(self) - layout = QHBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(2) - layout.addWidget(self._play_btn) - layout.addWidget(self._dim_label) - layout.addWidget(self._int_slider) - layout.addWidget(self._slice_slider) - layout.addWidget(self._pos_label) - layout.addWidget(self._out_of_label) - layout.addWidget(self._lock_btn) - self.setMinimumHeight(22) - - def resizeEvent(self, a0: QResizeEvent | None) -> None: - if isinstance(par := self.parent(), DimsSliders): - par.resizeEvent(None) - - def mouseDoubleClickEvent(self, a0: Any) -> None: - self._set_slice_mode(not self._slice_mode) - super().mouseDoubleClickEvent(a0) - - def containMaximum(self, max_val: int) -> None: - if max_val > self._int_slider.maximum(): - self._int_slider.setMaximum(max_val) - if max_val > self._slice_slider.maximum(): - self._slice_slider.setMaximum(max_val) - - def setMaximum(self, max_val: int) -> None: - self._int_slider.setMaximum(max_val) - self._slice_slider.setMaximum(max_val) - - def setMinimum(self, min_val: int) -> None: - self._int_slider.setMinimum(min_val) - self._slice_slider.setMinimum(min_val) - - def containMinimum(self, min_val: int) -> None: - if min_val < self._int_slider.minimum(): - self._int_slider.setMinimum(min_val) - if min_val < self._slice_slider.minimum(): - self._slice_slider.setMinimum(min_val) - - def setRange(self, min_val: int, max_val: int) -> None: - self._int_slider.setRange(min_val, max_val) - self._slice_slider.setRange(min_val, max_val) - - def value(self) -> Index: - if not self._slice_mode: - return self._int_slider.value() # type: ignore - start, *_, stop = cast("tuple[int, ...]", self._slice_slider.value()) - if start == stop: - return start - return slice(start, stop) - - def setValue(self, val: Index) -> None: - # variant of setValue that always updates the maximum - self._set_slice_mode(isinstance(val, slice)) - if self._lock_btn.isChecked(): - return - if isinstance(val, slice): - start = int(val.start) if val.start is not None else 0 - stop = ( - int(val.stop) if val.stop is not None else self._slice_slider.maximum() - ) - self._slice_slider.setValue((start, stop)) - else: - self._int_slider.setValue(val) - # self._slice_slider.setValue((val, val + 1)) - - def forceValue(self, val: Index) -> None: - """Set value and increase range if necessary.""" - if isinstance(val, slice): - if isinstance(val.start, int): - self.containMinimum(val.start) - if isinstance(val.stop, int): - self.containMaximum(val.stop) - else: - self.containMinimum(val) - self.containMaximum(val) - self.setValue(val) - - def _set_slice_mode(self, mode: bool = True) -> None: - if mode == self._slice_mode: - return - self._slice_mode = bool(mode) - self._slice_slider.setVisible(self._slice_mode) - self._int_slider.setVisible(not self._slice_mode) - # self._pos_label.setVisible(not self._slice_mode) - self.valueChanged.emit(self._dim_key, self.value()) - - def set_fps(self, fps: float) -> None: - self._play_btn.spin.setValue(fps) - self._toggle_animation(self._play_btn.isChecked()) - - def _toggle_animation(self, checked: bool) -> None: - if checked: - if self._timer_id is not None: - self.killTimer(self._timer_id) - interval = int(1000 / self._play_btn.spin.value()) - self._timer_id = self.startTimer(interval) - elif self._timer_id is not None: - self.killTimer(self._timer_id) - self._timer_id = None - - def timerEvent(self, event: Any) -> None: - """Handle timer event for play button, move to the next frame.""" - # TODO - # for now just increment the value by 1, but we should be able to - # take FPS into account better and skip additional frames if the timerEvent - # is delayed for some reason. - inc = 1 - if self._slice_mode: - val = cast(tuple[int, int], self._slice_slider.value()) - next_val = [v + inc for v in val] - if next_val[1] > self._slice_slider.maximum(): - # wrap around, without going below the min handle - next_val = [v - val[0] for v in val] - self._slice_slider.setValue(next_val) - else: - ival = self._int_slider.value() - ival = (ival + inc) % (self._int_slider.maximum() + 1) - self._int_slider.setValue(ival) - - def _on_pos_label_edited(self) -> None: - if self._slice_mode: - self._slice_slider.setValue( - (self._slice_slider.value()[0], self._pos_label.value()) - ) - else: - self._int_slider.setValue(self._pos_label.value()) - - def _on_range_changed(self, min: int, max: int) -> None: - self._out_of_label.setText(f"| {max}") - self._pos_label.setRange(min, max) - self.resizeEvent(None) - self.setVisible(min != max) - - def setVisible(self, visible: bool) -> None: - if self._has_no_range(): - visible = False - super().setVisible(visible) - - def _has_no_range(self) -> bool: - if self._slice_mode: - return bool(self._slice_slider.minimum() == self._slice_slider.maximum()) - return bool(self._int_slider.minimum() == self._int_slider.maximum()) - - def _on_int_value_changed(self, value: int) -> None: - self._pos_label.setValue(value) - if not self._slice_mode: - self.valueChanged.emit(self._dim_key, value) - - def _on_slice_value_changed(self, value: tuple[int, int]) -> None: - self._pos_label.setValue(int(value[1])) - with signals_blocked(self._int_slider): - self._int_slider.setValue(int(value[0])) - if self._slice_mode: - self.valueChanged.emit(self._dim_key, slice(*value)) - - -class DimsSliders(QWidget): - """A Collection of DimsSlider widgets for each dimension in the data. - - Maintains the global current index and emits a signal when it changes. - """ - - valueChanged = Signal(dict) # dict is of type Indices - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self._locks_visible: bool | Mapping[DimKey, bool] = False - self._sliders: dict[DimKey, DimsSlider] = {} - self._current_index: dict[DimKey, Index] = {} - self._invisible_dims: set[DimKey] = set() - - self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) - - layout = QVBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(0) - - def __contains__(self, key: DimKey) -> bool: - """Return True if the dimension key is present in the DimsSliders.""" - return key in self._sliders - - def slider(self, key: DimKey) -> DimsSlider: - """Return the DimsSlider widget for the given dimension key.""" - return self._sliders[key] - - def value(self) -> Indices: - """Return mapping of {dim_key -> current index} for each dimension.""" - return self._current_index.copy() - - def setValue(self, values: Indices) -> None: - """Set the current index for each dimension. - - Parameters - ---------- - values : Mapping[Hashable, int | slice] - Mapping of {dim_key -> index} for each dimension. If value is a slice, - the slider will be in slice mode. If the dimension is not present in the - DimsSliders, it will be added. - """ - if self._current_index == values: - return - with signals_blocked(self): - for dim, index in values.items(): - self.add_or_update_dimension(dim, index) - # FIXME: i don't know why this this is ever empty ... only happens on pyside6 - if val := self.value(): - self.valueChanged.emit(val) - - def minima(self) -> Sizes: - """Return mapping of {dim_key -> minimum value} for each dimension.""" - return {k: v._int_slider.minimum() for k, v in self._sliders.items()} - - def setMinima(self, values: Sizes) -> None: - """Set the minimum value for each dimension. - - Parameters - ---------- - values : Mapping[Hashable, int] - Mapping of {dim_key -> minimum value} for each dimension. - """ - for name, min_val in values.items(): - if name not in self._sliders: - self.add_dimension(name) - self._sliders[name].setMinimum(min_val) - - def maxima(self) -> Sizes: - """Return mapping of {dim_key -> maximum value} for each dimension.""" - return {k: v._int_slider.maximum() for k, v in self._sliders.items()} - - def setMaxima(self, values: Sizes) -> None: - """Set the maximum value for each dimension. - - Parameters - ---------- - values : Mapping[Hashable, int] - Mapping of {dim_key -> maximum value} for each dimension. - """ - for name, max_val in values.items(): - if name not in self._sliders: - self.add_dimension(name) - self._sliders[name].setMaximum(max_val) - - def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: - """Set the visibility of the lock buttons for all dimensions.""" - self._locks_visible = visible - for dim, slider in self._sliders.items(): - viz = visible if isinstance(visible, bool) else visible.get(dim, False) - slider._lock_btn.setVisible(viz) - - def add_dimension(self, key: DimKey, val: Index | None = None) -> None: - """Add a new dimension to the DimsSliders widget. - - Parameters - ---------- - key : Hashable - The name of the dimension. - val : int | slice, optional - The initial value for the dimension. If a slice, the slider will be in - slice mode. - """ - self._sliders[key] = slider = DimsSlider(dimension_key=key, parent=self) - if isinstance(self._locks_visible, dict) and key in self._locks_visible: - slider._lock_btn.setVisible(self._locks_visible[key]) - else: - slider._lock_btn.setVisible(bool(self._locks_visible)) - - val_int = val.start if isinstance(val, slice) else val - slider.setVisible(key not in self._invisible_dims) - if isinstance(val_int, int): - slider.setRange(val_int, val_int) - elif isinstance(val_int, slice): - slider.setRange(val_int.start or 0, val_int.stop or 1) - - val = val if val is not None else 0 - self._current_index[key] = val - slider.forceValue(val) - slider.valueChanged.connect(self._on_dim_slider_value_changed) - cast("QVBoxLayout", self.layout()).addWidget(slider) - - def set_dimension_visible(self, key: DimKey, visible: bool) -> None: - """Set the visibility of a dimension in the DimsSliders widget. - - Once a dimension is hidden, it will not be shown again until it is explicitly - made visible again with this method. - """ - if visible: - self._invisible_dims.discard(key) - if key in self._sliders: - self._current_index[key] = self._sliders[key].value() - else: - self.add_dimension(key) - else: - self._invisible_dims.add(key) - self._current_index.pop(key, None) - if key in self._sliders: - self._sliders[key].setVisible(visible) - - def remove_dimension(self, key: DimKey) -> None: - """Remove a dimension from the DimsSliders widget.""" - try: - slider = self._sliders.pop(key) - except KeyError: - warn(f"Dimension {key} not found in DimsSliders", stacklevel=2) - return - cast("QVBoxLayout", self.layout()).removeWidget(slider) - slider.deleteLater() - - def _on_dim_slider_value_changed(self, key: DimKey, value: Index) -> None: - self._current_index[key] = value - self.valueChanged.emit(self.value()) - - def add_or_update_dimension(self, key: DimKey, value: Index) -> None: - """Add a dimension if it doesn't exist, otherwise update the value.""" - if key in self._sliders: - self._sliders[key].forceValue(value) - else: - self.add_dimension(key, value) - - def resizeEvent(self, a0: QResizeEvent | None) -> None: - # align all labels - if sliders := list(self._sliders.values()): - for lbl in ("_dim_label", "_pos_label", "_out_of_label"): - lbl_width = max(getattr(s, lbl).sizeHint().width() for s in sliders) - for s in sliders: - getattr(s, lbl).setFixedWidth(lbl_width) - - super().resizeEvent(a0) - - def sizeHint(self) -> QSize: - return super().sizeHint().boundedTo(QSize(9999, 0)) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py b/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py deleted file mode 100644 index 6328708d6..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_indexing.py +++ /dev/null @@ -1,292 +0,0 @@ -from __future__ import annotations - -import sys -import warnings -from abc import abstractmethod -from concurrent.futures import Future, ThreadPoolExecutor -from contextlib import suppress -from typing import ( - TYPE_CHECKING, - Generic, - Hashable, - Iterable, - Mapping, - Sequence, - TypeVar, - cast, -) - -import numpy as np - -if TYPE_CHECKING: - from pathlib import Path - from typing import Any, Protocol, TypeGuard - - import dask.array as da - import numpy.typing as npt - import tensorstore as ts - import xarray as xr - from pymmcore_plus.mda.handlers import TensorStoreHandler - from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase - - from ._dims_slider import Index, Indices - - class SupportsIndexing(Protocol): - def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... - @property - def shape(self) -> tuple[int, ...]: ... - - -ArrayT = TypeVar("ArrayT") -MAX_CHANNELS = 16 -# Create a global executor -_EXECUTOR = ThreadPoolExecutor(max_workers=1) - - -class DataWrapper(Generic[ArrayT]): - def __init__(self, data: ArrayT) -> None: - self._data = data - - @classmethod - def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: - if isinstance(data, DataWrapper): - return data - if MMTensorStoreWrapper.supports(data): - return MMTensorStoreWrapper(data) - if MM5DWriter.supports(data): - return MM5DWriter(data) - if XarrayWrapper.supports(data): - return XarrayWrapper(data) - if DaskWrapper.supports(data): - return DaskWrapper(data) - if TensorstoreWrapper.supports(data): - return TensorstoreWrapper(data) - if ArrayLikeWrapper.supports(data): - return ArrayLikeWrapper(data) - raise NotImplementedError(f"Don't know how to wrap type {type(data)}") - - @abstractmethod - def isel(self, indexers: Indices) -> np.ndarray: - """Select a slice from a data store using (possibly) named indices. - - For xarray.DataArray, use the built-in isel method. - For any other duck-typed array, use numpy-style indexing, where indexers - is a mapping of axis to slice objects or indices. - """ - raise NotImplementedError - - def isel_async( - self, indexers: list[Indices] - ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: - """Asynchronous version of isel.""" - return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) - - @classmethod - @abstractmethod - def supports(cls, obj: Any) -> bool: - """Return True if this wrapper can handle the given object.""" - raise NotImplementedError - - def guess_channel_axis(self) -> Hashable | None: - """Return the (best guess) axis name for the channel dimension.""" - if isinstance(shp := getattr(self._data, "shape", None), Sequence): - # for numpy arrays, use the smallest dimension as the channel axis - if min(shp) <= MAX_CHANNELS: - return shp.index(min(shp)) - return None - - def save_as_zarr(self, save_loc: str | Path) -> None: - raise NotImplementedError("save_as_zarr not implemented for this data type.") - - def sizes(self) -> Mapping[Hashable, int]: - if (shape := getattr(self._data, "shape", None)) and isinstance(shape, tuple): - _sizes: dict[Hashable, int] = {} - for i, val in enumerate(shape): - if isinstance(val, int): - _sizes[i] = val - elif isinstance(val, Sequence) and len(val) == 2: - _sizes[val[0]] = int(val[1]) - else: - raise ValueError( - f"Invalid size: {val}. Must be an int or a 2-tuple." - ) - return _sizes - raise NotImplementedError(f"Cannot determine sizes for {type(self._data)}") - - def summary_info(self) -> str: - """Return info label with information about the data.""" - package = getattr(self._data, "__module__", "").split(".")[0] - info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" - - if sizes := self.sizes(): - # if all of the dimension keys are just integers, omit them from size_str - if all(isinstance(x, int) for x in sizes): - size_str = repr(tuple(sizes.values())) - # otherwise, include the keys in the size_str - else: - size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) - size_str = f"({size_str})" - info += f" {size_str}" - if dtype := getattr(self._data, "dtype", ""): - info += f", {dtype}" - if nbytes := getattr(self._data, "nbytes", 0) / 1e6: - info += f", {nbytes:.2f}MB" - return info - - -class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): - def sizes(self) -> Mapping[Hashable, int]: - with suppress(Exception): - return self._data.current_sequence.sizes - return {} - - def guess_channel_axis(self) -> Hashable | None: - return "c" - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: - from pymmcore_plus.mda.handlers import TensorStoreHandler - - return isinstance(obj, TensorStoreHandler) - - def isel(self, indexers: Indices) -> np.ndarray: - return self._data.isel(indexers) # type: ignore - - def save_as_zarr(self, save_loc: str | Path) -> None: - if (store := self._data.store) is None: - return - import tensorstore as ts - - new_spec = store.spec().to_json() - new_spec["kvstore"] = {"driver": "file", "path": str(save_loc)} - new_ts = ts.open(new_spec, create=True).result() - new_ts[:] = store.read().result() - - -class MM5DWriter(DataWrapper["_5DWriterBase"]): - def guess_channel_axis(self) -> Hashable | None: - return "c" - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[_5DWriterBase]: - try: - from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase - except ImportError: - from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter - - _5DWriterBase = (OMETiffWriter, OMEZarrWriter) # type: ignore - if isinstance(obj, _5DWriterBase): - return True - return False - - def save_as_zarr(self, save_loc: str | Path) -> None: - import zarr - from pymmcore_plus.mda.handlers import OMEZarrWriter - - if isinstance(self._data, OMEZarrWriter): - zarr.copy_store(self._data.group.store, zarr.DirectoryStore(save_loc)) - raise NotImplementedError(f"Cannot save {type(self._data)} data to Zarr.") - - def isel(self, indexers: Indices) -> np.ndarray: - p_index = indexers.get("p", 0) - if isinstance(p_index, slice): - warnings.warn("Cannot slice over position index", stacklevel=2) # TODO - p_index = p_index.start - p_index = cast(int, p_index) - - try: - sizes = [*list(self._data.position_sizes[p_index]), "y", "x"] - except IndexError as e: - raise IndexError( - f"Position index {p_index} out of range for " - f"{len(self._data.position_sizes)}" - ) from e - - data = self._data.position_arrays[self._data.get_position_key(p_index)] - full = slice(None, None) - index = tuple(indexers.get(k, full) for k in sizes) - return data[index] # type: ignore - - -class XarrayWrapper(DataWrapper["xr.DataArray"]): - def isel(self, indexers: Indices) -> np.ndarray: - return np.asarray(self._data.isel(indexers)) - - def sizes(self) -> Mapping[Hashable, int]: - return {k: int(v) for k, v in self._data.sizes.items()} - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: - if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): - return True - return False - - def guess_channel_axis(self) -> Hashable | None: - for d in self._data.dims: - if str(d).lower() in ("channel", "ch", "c"): - return cast("Hashable", d) - return None - - def save_as_zarr(self, save_loc: str | Path) -> None: - self._data.to_zarr(save_loc) - - -class DaskWrapper(DataWrapper["da.Array"]): - def isel(self, indexers: Indices) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) - return np.asarray(self._data[idx].compute()) - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[da.Array]: - if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): - return True - return False - - def save_as_zarr(self, save_loc: str | Path) -> None: - self._data.to_zarr(url=str(save_loc)) - - -class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): - def __init__(self, data: Any) -> None: - super().__init__(data) - import tensorstore as ts - - self._ts = ts - - def sizes(self) -> Mapping[Hashable, int]: - return {dim.label: dim.size for dim in self._data.domain} - - def isel(self, indexers: Indices) -> np.ndarray: - result = self._data[self._ts.d[*indexers][*indexers.values()]].read().result() - return np.asarray(result) - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: - if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): - return True - return False - - -class ArrayLikeWrapper(DataWrapper): - def isel(self, indexers: Indices) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) - return np.asarray(self._data[idx]) - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: - if ( - isinstance(obj, np.ndarray) - or hasattr(obj, "__array_function__") - or hasattr(obj, "__array_namespace__") - or (hasattr(obj, "__getitem__") and hasattr(obj, "__array__")) - ): - return True - return False - - def save_as_zarr(self, save_loc: str | Path) -> None: - import zarr - - if isinstance(self._data, zarr.Array): - self._data.store = zarr.DirectoryStore(save_loc) - else: - zarr.save(str(save_loc), self._data) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py b/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py deleted file mode 100644 index 65ba69772..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_lut_control.py +++ /dev/null @@ -1,121 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Iterable, cast - -import numpy as np -from qtpy.QtCore import Qt -from qtpy.QtWidgets import QCheckBox, QFrame, QHBoxLayout, QPushButton, QWidget -from superqt import QLabeledRangeSlider -from superqt.cmap import QColormapComboBox -from superqt.utils import signals_blocked - -from ._dims_slider import SS - -if TYPE_CHECKING: - import cmap - - from ._protocols import PImageHandle - - -class CmapCombo(QColormapComboBox): - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent, allow_user_colormaps=True, add_colormap_text="Add...") - self.setMinimumSize(120, 21) - # self.setStyleSheet("background-color: transparent;") - - def showPopup(self) -> None: - super().showPopup() - popup = self.findChild(QFrame) - popup.setMinimumWidth(self.width() + 100) - popup.move(popup.x(), popup.y() - self.height() - popup.height()) - - -class LutControl(QWidget): - def __init__( - self, - name: str = "", - handles: Iterable[PImageHandle] = (), - parent: QWidget | None = None, - cmaplist: Iterable[Any] = (), - ) -> None: - super().__init__(parent) - self._handles = handles - self._name = name - - self._visible = QCheckBox(name) - self._visible.setChecked(True) - self._visible.toggled.connect(self._on_visible_changed) - - self._cmap = CmapCombo() - self._cmap.currentColormapChanged.connect(self._on_cmap_changed) - for handle in handles: - self._cmap.addColormap(handle.cmap) - for color in cmaplist: - self._cmap.addColormap(color) - - self._clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) - self._clims.setStyleSheet(SS) - self._clims.setHandleLabelPosition( - QLabeledRangeSlider.LabelPosition.LabelsOnHandle - ) - self._clims.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) - self._clims.setRange(0, 2**8) - self._clims.valueChanged.connect(self._on_clims_changed) - - self._auto_clim = QPushButton("Auto") - self._auto_clim.setMaximumWidth(42) - self._auto_clim.setCheckable(True) - self._auto_clim.setChecked(True) - self._auto_clim.toggled.connect(self.update_autoscale) - - layout = QHBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.addWidget(self._visible) - layout.addWidget(self._cmap) - layout.addWidget(self._clims) - layout.addWidget(self._auto_clim) - - self.update_autoscale() - - def autoscaleChecked(self) -> bool: - return cast("bool", self._auto_clim.isChecked()) - - def _on_clims_changed(self, clims: tuple[float, float]) -> None: - self._auto_clim.setChecked(False) - for handle in self._handles: - handle.clim = clims - - def _on_visible_changed(self, visible: bool) -> None: - for handle in self._handles: - handle.visible = visible - if visible: - self.update_autoscale() - - def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: - for handle in self._handles: - handle.cmap = cmap - - def update_autoscale(self) -> None: - if ( - not self._auto_clim.isChecked() - or not self._visible.isChecked() - or not self._handles - ): - return - - # find the min and max values for the current channel - clims = [np.inf, -np.inf] - for handle in self._handles: - clims[0] = min(clims[0], np.nanmin(handle.data)) - clims[1] = max(clims[1], np.nanmax(handle.data)) - - mi, ma = tuple(int(x) for x in clims) - if mi != ma: - for handle in self._handles: - handle.clim = (mi, ma) - - # set the slider values to the new clims - with signals_blocked(self._clims): - self._clims.setMinimum(min(mi, self._clims.minimum())) - self._clims.setMaximum(max(ma, self._clims.maximum())) - self._clims.setValue((mi, ma)) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index cdda4f82f..1a9507445 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -1,21 +1,22 @@ from __future__ import annotations import warnings +from pathlib import Path from typing import TYPE_CHECKING, Any, Mapping import superqt import useq +from ndv import DataWrapper, NDViewer from pymmcore_plus.mda.handlers import TensorStoreHandler - -from ._save_button import SaveButton -from ._stack_viewer import StackViewer +from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget +from superqt.iconify import QIconifyIcon if TYPE_CHECKING: from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase from qtpy.QtWidgets import QWidget -class MDAViewer(StackViewer): +class MDAViewer(NDViewer): """StackViewer specialized for pymmcore-plus MDA acquisitions.""" _data: _5DWriterBase @@ -64,3 +65,27 @@ def _get_channel_name(self, index: Mapping) -> str: if name := self._channel_names.get(index[self._channel_axis]): return name return super()._get_channel_name(index) + + +class SaveButton(QPushButton): + def __init__( + self, + data_wrapper: DataWrapper, + parent: QWidget | None = None, + ): + super().__init__(parent=parent) + self.setIcon(QIconifyIcon("mdi:content-save")) + self.clicked.connect(self._on_click) + + self._data_wrapper = data_wrapper + self._last_loc = str(Path.home()) + + def _on_click(self) -> None: + self._last_loc, _ = QFileDialog.getSaveFileName( + self, "Choose destination", str(self._last_loc), "" + ) + suffix = Path(self._last_loc).suffix + if suffix in (".zarr", ".ome.zarr", ""): + self._data_wrapper.save_as_zarr(self._last_loc) + else: + raise ValueError(f"Unsupported file format: {self._last_loc}") diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py b/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py deleted file mode 100644 index 413038ded..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_protocols.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol - -if TYPE_CHECKING: - import cmap - import numpy as np - from qtpy.QtWidgets import QWidget - - -class PImageHandle(Protocol): - @property - def data(self) -> np.ndarray: ... - @data.setter - def data(self, data: np.ndarray) -> None: ... - @property - def visible(self) -> bool: ... - @visible.setter - def visible(self, visible: bool) -> None: ... - @property - def clim(self) -> Any: ... - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: ... - @property - def cmap(self) -> Any: ... - @cmap.setter - def cmap(self, cmap: Any) -> None: ... - def remove(self) -> None: ... - - -class PCanvas(Protocol): - def __init__(self, set_info: Callable[[str], None]) -> None: ... - def set_ndim(self, ndim: Literal[2, 3]) -> None: ... - def set_range( - self, - x: tuple[float, float] | None = None, - y: tuple[float, float] | None = None, - z: tuple[float, float] | None = None, - margin: float = ..., - ) -> None: ... - def refresh(self) -> None: ... - def qwidget(self) -> QWidget: ... - def add_image( - self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... - ) -> PImageHandle: ... - def add_volume( - self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... - ) -> PImageHandle: ... diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py b/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py deleted file mode 100644 index 85520641a..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_save_button.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import TYPE_CHECKING - -from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget -from superqt.iconify import QIconifyIcon - -if TYPE_CHECKING: - from ._indexing import DataWrapper - - -class SaveButton(QPushButton): - def __init__( - self, - data_wrapper: DataWrapper, - parent: QWidget | None = None, - ): - super().__init__(parent=parent) - self.setIcon(QIconifyIcon("mdi:content-save")) - self.clicked.connect(self._on_click) - - self._data_wrapper = data_wrapper - self._last_loc = str(Path.home()) - - def _on_click(self) -> None: - self._last_loc, _ = QFileDialog.getSaveFileName( - self, "Choose destination", str(self._last_loc), "" - ) - suffix = Path(self._last_loc).suffix - if suffix in (".zarr", ".ome.zarr", ""): - self._data_wrapper.save_as_zarr(self._last_loc) - else: - raise ValueError(f"Unsupported file format: {self._last_loc}") diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py deleted file mode 100644 index 6aa6f28d6..000000000 --- a/src/pymmcore_widgets/_stack_viewer_v2/_stack_viewer.py +++ /dev/null @@ -1,591 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from enum import Enum -from itertools import cycle -from typing import TYPE_CHECKING, Iterable, Literal, Mapping, Sequence, cast - -import cmap -import numpy as np -from qtpy.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget -from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread -from superqt.utils import qthrottled, signals_blocked - -from ._backends import get_canvas -from ._dims_slider import DimsSliders -from ._indexing import DataWrapper -from ._lut_control import LutControl - -if TYPE_CHECKING: - from concurrent.futures import Future - from typing import Any, Callable, Hashable, TypeAlias - - from qtpy.QtGui import QCloseEvent - - from ._dims_slider import DimKey, Indices, Sizes - from ._protocols import PCanvas, PImageHandle - - ImgKey: TypeAlias = Hashable - # any mapping of dimensions to sizes - SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] - -MID_GRAY = "#888888" -GRAYS = cmap.Colormap("gray") -DEFAULT_COLORMAPS = [ - cmap.Colormap("green"), - cmap.Colormap("magenta"), - cmap.Colormap("cyan"), - cmap.Colormap("yellow"), - cmap.Colormap("red"), - cmap.Colormap("blue"), - cmap.Colormap("cubehelix"), - cmap.Colormap("gray"), -] -ALL_CHANNELS = slice(None) - - -class ChannelMode(str, Enum): - COMPOSITE = "composite" - MONO = "mono" - - def __str__(self) -> str: - return self.value - - -class ChannelModeButton(QPushButton): - def __init__(self, parent: QWidget | None = None): - super().__init__(parent) - self.setCheckable(True) - self.toggled.connect(self.next_mode) - - # set minimum width to the width of the larger string 'composite' - self.setMinimumWidth(92) # FIXME: magic number - - def next_mode(self) -> None: - if self.isChecked(): - self.setMode(ChannelMode.MONO) - else: - self.setMode(ChannelMode.COMPOSITE) - - def mode(self) -> ChannelMode: - return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE - - def setMode(self, mode: ChannelMode) -> None: - # we show the name of the next mode, not the current one - other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO - self.setText(str(other)) - self.setChecked(mode == ChannelMode.MONO) - - -class DimToggleButton(QPushButton): - def __init__(self, parent: QWidget | None = None): - icn = QIconifyIcon("f7:view-2d", color="#333333") - icn.addKey("f7:view-3d", state=QIconifyIcon.State.On, color="white") - super().__init__(icn, "", parent) - self.setCheckable(True) - self.setChecked(True) - - -# @dataclass -# class LutModel: -# name: str = "" -# autoscale: bool = True -# min: float = 0.0 -# max: float = 1.0 -# colormap: cmap.Colormap = GRAYS -# visible: bool = True - - -# @dataclass -# class ViewerModel: -# data: Any = None -# # dimensions of the data that will *not* be sliced. -# visualized_dims: Container[DimKey] = (-2, -1) -# # the axis that represents the channels in the data -# channel_axis: DimKey | None = None -# # the mode for displaying the channels -# # if MONO, only the current selection of channel_axis is displayed -# # if COMPOSITE, the full channel_axis is sliced, and luts determine display -# channel_mode: ChannelMode = ChannelMode.MONO -# # map of index in the channel_axis to LutModel -# luts: Mapping[int, LutModel] = {} - - -class StackViewer(QWidget): - """A viewer for ND arrays. - - This widget displays a single slice from an ND array (or a composite of slices in - different colormaps). The widget provides sliders to select the slice to display, - and buttons to control the display mode of the channels. - - An important concept in this widget is the "index". The index is a mapping of - dimensions to integers or slices that define the slice of the data to display. For - example, a numpy slice of `[0, 1, 5:10]` would be represented as - `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be named, e.g. - `{'t': 0, 'c': 1, 'z': slice(5, 10)}`. The index is used to select the data from - the datastore, and to determine the position of the sliders. - - The flow of data is as follows: - - - The user sets the data using the `set_data` method. This will set the number - and range of the sliders to the shape of the data, and display the first slice. - - The user can then use the sliders to select the slice to display. The current - slice is defined as a `Mapping` of `{dim -> int|slice}` and can be retrieved - with the `_dims_sliders.value()` method. To programmatically set the current - position, use the `setIndex` method. This will set the values of the sliders, - which in turn will trigger the display of the new slice via the - `_update_data_for_index` method. - - `_update_data_for_index` is an asynchronous method that retrieves the data for - the given index from the datastore (using `_isel`) and queues the - `_on_data_slice_ready` method to be called when the data is ready. The logic - for extracting data from the datastore is defined in `_indexing.py`, which handles - idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). - - `_on_data_slice_ready` is called when the data is ready, and updates the image. - Note that if the slice is multidimensional, the data will be reduced to 2D using - max intensity projection (and double-clicking on any given dimension slider will - turn it into a range slider allowing a projection to be made over that dimension). - - The image is displayed on the canvas, which is an object that implements the - `PCanvas` protocol (mostly, it has an `add_image` method that returns a handle - to the added image that can be used to update the data and display). This - small abstraction allows for various backends to be used (e.g. vispy, pygfx, etc). - - Parameters - ---------- - data : Any - The data to display. This can be an ND array, an xarray DataArray, or any - object that supports numpy-style indexing. - parent : QWidget, optional - The parent widget of this widget. - channel_axis : Hashable, optional - The axis that represents the channels in the data. If not provided, this will - be guessed from the data. - channel_mode : ChannelMode, optional - The initial mode for displaying the channels. If not provided, this will be - set to ChannelMode.MONO. - """ - - def __init__( - self, - data: Any, - *, - colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, - parent: QWidget | None = None, - channel_axis: DimKey | None = None, - channel_mode: ChannelMode | str = ChannelMode.MONO, - ): - super().__init__(parent=parent) - - # ATTRIBUTES ---------------------------------------------------- - - # dimensions of the data in the datastore - self._sizes: Sizes = {} - # mapping of key to a list of objects that control image nodes in the canvas - self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) - # mapping of same keys to the LutControl objects control image display props - self._lut_ctrls: dict[ImgKey, LutControl] = {} - # the set of dimensions we are currently visualizing (e.g. XY) - # this is used to control which dimensions have sliders and the behavior - # of isel when selecting data from the datastore - self._visualized_dims: set[DimKey] = set() - # the axis that represents the channels in the data - self._channel_axis = channel_axis - self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode - # colormaps that will be cycled through when displaying composite images - # TODO: allow user to set this - if colormaps is not None: - self._cmaps = [cmap.Colormap(c) for c in colormaps] - else: - self._cmaps = DEFAULT_COLORMAPS - self._cmap_cycle = cycle(self._cmaps) - # the last future that was created by _update_data_for_index - self._last_future: Future | None = None - - # number of dimensions to display - self._ndims: Literal[2, 3] = 2 - - # WIDGETS ---------------------------------------------------- - - # the button that controls the display mode of the channels - self._channel_mode_btn = ChannelModeButton(self) - self._channel_mode_btn.clicked.connect(self.set_channel_mode) - # button to reset the zoom of the canvas - self._set_range_btn = QPushButton( - QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self - ) - self._set_range_btn.clicked.connect(self._on_set_range_clicked) - - # button to change number of displayed dimensions - self._ndims_btn = DimToggleButton(self) - self._ndims_btn.clicked.connect(self.toggle_3d) - - # place to display dataset summary - self._data_info_label = QElidingLabel("", parent=self) - # place to display arbitrary text - self._hover_info_label = QLabel("", self) - # the canvas that displays the images - self._canvas: PCanvas = get_canvas()(self._hover_info_label.setText) - self._canvas.set_ndim(self._ndims) - - # the sliders that control the index of the displayed image - self._dims_sliders = DimsSliders(self) - self._dims_sliders.valueChanged.connect( - qthrottled(self._update_data_for_index, 20, leading=True) - ) - - self._lut_drop = QCollapsible("LUTs", self) - self._lut_drop.setCollapsedIcon(QIconifyIcon("bi:chevron-down", color=MID_GRAY)) - self._lut_drop.setExpandedIcon(QIconifyIcon("bi:chevron-up", color=MID_GRAY)) - lut_layout = cast("QVBoxLayout", self._lut_drop.layout()) - lut_layout.setContentsMargins(0, 1, 0, 1) - lut_layout.setSpacing(0) - if ( - hasattr(self._lut_drop, "_content") - and (layout := self._lut_drop._content.layout()) is not None - ): - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(0) - - # LAYOUT ----------------------------------------------------- - - self._btns = btns = QHBoxLayout() - btns.setContentsMargins(0, 0, 0, 0) - btns.setSpacing(0) - btns.addStretch() - btns.addWidget(self._channel_mode_btn) - btns.addWidget(self._ndims_btn) - btns.addWidget(self._set_range_btn) - - layout = QVBoxLayout(self) - layout.setSpacing(2) - layout.setContentsMargins(6, 6, 6, 6) - layout.addWidget(self._data_info_label) - layout.addWidget(self._canvas.qwidget(), 1) - layout.addWidget(self._hover_info_label) - layout.addWidget(self._dims_sliders) - layout.addWidget(self._lut_drop) - layout.addLayout(btns) - - # SETUP ------------------------------------------------------ - - self.set_channel_mode(channel_mode) - if data is not None: - self.set_data(data) - - # ------------------- PUBLIC API ---------------------------- - @property - def data(self) -> Any: - """Return the data backing the view.""" - return self._data_wrapper._data - - @data.setter - def data(self, data: Any) -> None: - """Set the data backing the view.""" - raise AttributeError("Cannot set data directly. Use `set_data` method.") - - @property - def dims_sliders(self) -> DimsSliders: - """Return the DimsSliders widget.""" - return self._dims_sliders - - @property - def sizes(self) -> Sizes: - """Return sizes {dimkey: int} of the dimensions in the datastore.""" - return self._sizes - - def set_data( - self, - data: Any, - sizes: SizesLike | None = None, - channel_axis: int | None = None, - visualized_dims: Iterable[DimKey] | None = None, - ) -> None: - """Set the datastore, and, optionally, the sizes of the data.""" - # store the data - self._data_wrapper = DataWrapper.create(data) - - # determine sizes of the data - self._sizes = self._data_wrapper.sizes() if sizes is None else _to_sizes(sizes) - - # set channel axis - if channel_axis is not None: - self._channel_axis = channel_axis - elif self._channel_axis is None: - self._channel_axis = self._data_wrapper.guess_channel_axis() - - # update the dimensions we are visualizing - if visualized_dims is None: - visualized_dims = list(self._sizes)[-self._ndims :] - self.set_visualized_dims(visualized_dims) - - # update the range of all the sliders to match the sizes we set above - with signals_blocked(self._dims_sliders): - self.update_slider_ranges() - # redraw - self.setIndex({}) - # update the data info label - self._data_info_label.setText(self._data_wrapper.summary_info()) - - def set_visualized_dims(self, dims: Iterable[DimKey]) -> None: - """Set the dimensions that will be visualized. - - This dims will NOT have sliders associated with them. - """ - self._visualized_dims = set(dims) - for d in self._dims_sliders._sliders: - self._dims_sliders.set_dimension_visible(d, d not in self._visualized_dims) - for d in self._visualized_dims: - self._dims_sliders.set_dimension_visible(d, False) - - def update_slider_ranges( - self, mins: SizesLike | None = None, maxes: SizesLike | None = None - ) -> None: - """Set the maximum values of the sliders. - - If `sizes` is not provided, sizes will be inferred from the datastore. - This is mostly here as a public way to reset the - """ - if maxes is None: - maxes = self._sizes - maxes = _to_sizes(maxes) - self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) - if mins is not None: - self._dims_sliders.setMinima(_to_sizes(mins)) - - # FIXME: this needs to be moved and made user-controlled - for dim in list(maxes.keys())[-self._ndims :]: - self._dims_sliders.set_dimension_visible(dim, False) - - def toggle_3d(self) -> None: - self.set_ndim(3 if self._ndims == 2 else 2) - - def set_ndim(self, ndim: Literal[2, 3]) -> None: - """Set the number of dimensions to display.""" - self._ndims = ndim - self._canvas.set_ndim(ndim) - - # set the visibility of the last non-channel dimension - sizes = list(self._sizes) - if self._channel_axis is not None: - sizes = [x for x in sizes if x != self._channel_axis] - if len(sizes) >= 3: - dim3 = sizes[-3] - self._dims_sliders.set_dimension_visible(dim3, True if ndim == 2 else False) - - # clear image handles and redraw - if self._img_handles: - self._clear_images() - self._update_data_for_index(self._dims_sliders.value()) - - def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: - """Set the mode for displaying the channels. - - In "composite" mode, the channels are displayed as a composite image, using - self._channel_axis as the channel axis. In "grayscale" mode, each channel is - displayed separately. (If mode is None, the current value of the - channel_mode_picker button is used) - """ - if mode is None or isinstance(mode, bool): - mode = self._channel_mode_btn.mode() - else: - mode = ChannelMode(mode) - self._channel_mode_btn.setMode(mode) - if mode == getattr(self, "_channel_mode", None): - return - - self._channel_mode = mode - self._cmap_cycle = cycle(self._cmaps) # reset the colormap cycle - if self._channel_axis is not None: - # set the visibility of the channel slider - self._dims_sliders.set_dimension_visible( - self._channel_axis, mode != ChannelMode.COMPOSITE - ) - - if self._img_handles: - self._clear_images() - self._update_data_for_index(self._dims_sliders.value()) - - def setIndex(self, index: Indices) -> None: - """Set the index of the displayed image.""" - self._dims_sliders.setValue(index) - - # ------------------- PRIVATE METHODS ---------------------------- - - def _on_set_range_clicked(self) -> None: - # using method to swallow the parameter passed by _set_range_btn.clicked - self._canvas.set_range() - - def _image_key(self, index: Indices) -> ImgKey: - """Return the key for image handle(s) corresponding to `index`.""" - if self._channel_mode == ChannelMode.COMPOSITE: - val = index.get(self._channel_axis, 0) - if isinstance(val, slice): - return (val.start, val.stop) - return val - return 0 - - def _update_data_for_index(self, index: Indices) -> None: - """Retrieve data for `index` from datastore and update canvas image(s). - - This will pull the data from the datastore using the given index, and update - the image handle(s) with the new data. This method is *asynchronous*. It - makes a request for the new data slice and queues _on_data_future_done to be - called when the data is ready. - """ - if ( - self._channel_axis is not None - and self._channel_mode == ChannelMode.COMPOSITE - ): - indices: list[Indices] = [ - {**index, self._channel_axis: i} - for i in range(self._sizes[self._channel_axis]) - ] - else: - indices = [index] - - if self._last_future: - self._last_future.cancel() - - # don't request any dimensions that are not visualized - indices = [ - {k: v for k, v in idx.items() if k not in self._visualized_dims} - for idx in indices - ] - self._last_future = f = self._isel(indices) - f.add_done_callback(self._on_data_slice_ready) - - def closeEvent(self, a0: QCloseEvent | None) -> None: - if self._last_future is not None: - self._last_future.cancel() - self._last_future = None - super().closeEvent(a0) - - def _isel( - self, indices: list[Indices] - ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: - """Select data from the datastore using the given index.""" - try: - return self._data_wrapper.isel_async(indices) - except Exception as e: - raise type(e)(f"Failed to index data with {indices}: {e}") from e - - @ensure_main_thread # type: ignore - def _on_data_slice_ready( - self, future: Future[Iterable[tuple[Indices, np.ndarray]]] - ) -> None: - """Update the displayed image for the given index. - - Connected to the future returned by _isel. - """ - # NOTE: removing the reference to the last future here is important - # because the future has a reference to this widget in its _done_callbacks - # which will prevent the widget from being garbage collected if the future - self._last_future = None - if future.cancelled(): - return - - data = future.result() - # FIXME: - # `self._channel_axis: i` is a bug; we assume channel indices start at 0 - # but the actual values used for indices are up to the user. - for idx, datum in data: - self._update_canvas_data(datum, idx) - self._canvas.refresh() - - def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: - """Actually update the image handle(s) with the (sliced) data. - - By this point, data should be sliced from the underlying datastore. Any - dimensions remaining that are more than the number of visualized dimensions - (currently just 2D) will be reduced using max intensity projection (currently). - """ - imkey = self._image_key(index) - datum = self._reduce_data_for_display(data) - if handles := self._img_handles[imkey]: - for handle in handles: - handle.data = datum - if ctrl := self._lut_ctrls.get(imkey, None): - ctrl.update_autoscale() - else: - cm = ( - next(self._cmap_cycle) - if self._channel_mode == ChannelMode.COMPOSITE - else GRAYS - ) - if datum.ndim == 2: - handles.append(self._canvas.add_image(datum, cmap=cm)) - elif datum.ndim == 3: - handles.append(self._canvas.add_volume(datum, cmap=cm)) - if imkey not in self._lut_ctrls: - channel_name = self._get_channel_name(index) - self._lut_ctrls[imkey] = c = LutControl( - channel_name, - handles, - self, - cmaplist=self._cmaps + DEFAULT_COLORMAPS, - ) - self._lut_drop.addWidget(c) - - def _get_channel_name(self, index: Indices) -> str: - c = index.get(self._channel_axis, 0) - return f"Ch {c}" # TODO: get name from user - - def _reduce_data_for_display( - self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max - ) -> np.ndarray: - """Reduce the number of dimensions in the data for display. - - This function takes a data array and reduces the number of dimensions to - the max allowed for display. The default behavior is to reduce the smallest - dimensions, using np.max. This can be improved in the future. - - This also coerces 64-bit data to 32-bit data. - """ - # TODO - # - allow dimensions to control how they are reduced (as opposed to just max) - # - for better way to determine which dims need to be reduced (currently just - # the smallest dims) - data = data.squeeze() - visualized_dims = self._ndims - if extra_dims := data.ndim - visualized_dims: - shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) - smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) - data = reductor(data, axis=smallest_dims) - - if data.dtype.itemsize > 4: # More than 32 bits - if np.issubdtype(data.dtype, np.integer): - data = data.astype(np.int32) - else: - data = data.astype(np.float32) - return data - - def _clear_images(self) -> None: - """Remove all images from the canvas.""" - for handles in self._img_handles.values(): - for handle in handles: - handle.remove() - self._img_handles.clear() - - # clear the current LutControls as well - for c in self._lut_ctrls.values(): - cast("QVBoxLayout", self.layout()).removeWidget(c) - c.deleteLater() - self._lut_ctrls.clear() - - -def _to_sizes(sizes: SizesLike | None) -> Sizes: - """Coerce `sizes` to a {dimKey -> int} mapping.""" - if sizes is None: - return {} - if isinstance(sizes, Mapping): - return {k: int(v) for k, v in sizes.items()} - if not isinstance(sizes, Iterable): - raise TypeError(f"SizeLike must be an iterable or mapping, not: {type(sizes)}") - _sizes: dict[Hashable, int] = {} - for i, val in enumerate(sizes): - if isinstance(val, int): - _sizes[i] = val - elif isinstance(val, Sequence) and len(val) == 2: - _sizes[val[0]] = int(val[1]) - else: - raise ValueError(f"Invalid size: {val}. Must be an int or a 2-tuple.") - return _sizes From fc2a792b5dded2a6761c3b94920c7772fccd31b4 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 8 Jun 2024 18:05:02 -0400 Subject: [PATCH 66/73] add data wrapper --- .../_stack_viewer_v2/_data_wrapper.py | 89 +++++++++++++++++++ .../_stack_viewer_v2/_mda_viewer.py | 2 + 2 files changed, 91 insertions(+) create mode 100644 src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py b/src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py new file mode 100644 index 000000000..1797f1433 --- /dev/null +++ b/src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import warnings +from contextlib import suppress +from typing import TYPE_CHECKING, cast + +from ndv import DataWrapper +from pymmcore_plus.mda.handlers import TensorStoreHandler + +if TYPE_CHECKING: + from collections.abc import Hashable, Mapping + from pathlib import Path + from typing import Any, TypeGuard + + import numpy as np + from ndv import Indices + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase + + +class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): + def sizes(self) -> Mapping[Hashable, int]: + with suppress(Exception): + return self._data.current_sequence.sizes # type: ignore [return-value] + return {} + + def guess_channel_axis(self) -> Hashable | None: + return "c" + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: + return isinstance(obj, TensorStoreHandler) + + def isel(self, indexers: Indices) -> np.ndarray: + return self._data.isel({str(k): v for k, v in indexers.items()}) + + def save_as_zarr(self, save_loc: str | Path) -> None: + if (store := self._data.store) is None: + return + import tensorstore as ts + + new_spec = store.spec().to_json() + new_spec["kvstore"] = {"driver": "file", "path": str(save_loc)} + new_ts = ts.open(new_spec, create=True).result() + new_ts[:] = store.read().result() + + +class MM5DWriter(DataWrapper["_5DWriterBase"]): + def guess_channel_axis(self) -> Hashable | None: + return "c" + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[_5DWriterBase]: + try: + from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase + except ImportError: + from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter + + _5DWriterBase = (OMETiffWriter, OMEZarrWriter) + if isinstance(obj, _5DWriterBase): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + import zarr + from pymmcore_plus.mda.handlers import OMEZarrWriter + + if isinstance(self._data, OMEZarrWriter): + zarr.copy_store(self._data.group.store, zarr.DirectoryStore(save_loc)) + raise NotImplementedError(f"Cannot save {type(self._data)} data to Zarr.") + + def isel(self, indexers: Indices) -> np.ndarray: + p_index = indexers.get("p", 0) + if isinstance(p_index, slice): + warnings.warn("Cannot slice over position index", stacklevel=2) # TODO + p_index = p_index.start + p_index = cast(int, p_index) + + try: + sizes = [*list(self._data.position_sizes[p_index]), "y", "x"] + except IndexError as e: + raise IndexError( + f"Position index {p_index} out of range for " + f"{len(self._data.position_sizes)}" + ) from e + + data = self._data.position_arrays[self._data.get_position_key(p_index)] + full = slice(None, None) + index = tuple(indexers.get(k, full) for k in sizes) + return data[index] # type: ignore [no-any-return] diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 1a9507445..60123e542 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -11,6 +11,8 @@ from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget from superqt.iconify import QIconifyIcon +from . import _data_wrapper # noqa: F401 + if TYPE_CHECKING: from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase from qtpy.QtWidgets import QWidget From 346e66e7d95d9d8d368ea7a9214dc8d423bc1539 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 8 Jun 2024 18:07:14 -0400 Subject: [PATCH 67/73] add comment --- src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 60123e542..47db4a2f2 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -11,6 +11,7 @@ from qtpy.QtWidgets import QFileDialog, QPushButton, QWidget from superqt.iconify import QIconifyIcon +# this import is necessary so that ndv can find our custom DataWrapper from . import _data_wrapper # noqa: F401 if TYPE_CHECKING: From 064d72d964deba48a18eb9a27b503b22df315271 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sat, 8 Jun 2024 18:12:04 -0400 Subject: [PATCH 68/73] add dep --- pyproject.toml | 1 + src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 53769051c..9b7cccf15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ 'qtpy >=2.0', 'superqt[quantity] >=0.6.5', 'useq-schema >=0.4.7', + 'ndv', ] # extras diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index 47db4a2f2..dbc6df7f7 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -58,16 +58,17 @@ def _patched_frame_ready(self, *args: Any) -> None: @superqt.ensure_main_thread # type: ignore def _on_frame_ready(self, event: useq.MDAEvent) -> None: - c = event.index.get(self._channel_axis) # type: ignore + c = event.index.get(self._channel_axis) if c not in self._channel_names and c is not None and event.channel: self._channel_names[c] = event.channel.config - self.setIndex(event.index) # type: ignore + self.setIndex(event.index) def _get_channel_name(self, index: Mapping) -> str: if self._channel_axis in index: if name := self._channel_names.get(index[self._channel_axis]): return name - return super()._get_channel_name(index) + c = index.get(self._channel_axis, 0) + return f"Ch {c}" class SaveButton(QPushButton): From 3d80b01d67721ee6f5810043e683598e1fe8db3c Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 9 Jun 2024 12:15:57 -0400 Subject: [PATCH 69/73] bump deps --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9b7cccf15..3bbc31b46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,11 +48,11 @@ classifiers = [ dynamic = ["version"] dependencies = [ 'fonticon-materialdesignicons6', - 'pymmcore-plus[cli] >=0.9.5', + 'pymmcore-plus[cli] >=0.10.0', 'qtpy >=2.0', 'superqt[quantity] >=0.6.5', 'useq-schema >=0.4.7', - 'ndv', + 'ndv >=0.0.3', ] # extras From 0f359c2db02dc91bb6767411b5492cee4264e40e Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 9 Jun 2024 12:35:54 -0400 Subject: [PATCH 70/73] lint --- .pre-commit-config.yaml | 2 +- src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4f95e725..6eb598f3d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,5 +27,5 @@ repos: - id: mypy files: "^src/" additional_dependencies: - - pymmcore-plus >=0.9.5 + - pymmcore-plus >=0.10.0 - useq-schema >=0.4.7 diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py b/src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py index 1797f1433..d35689e95 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_data_wrapper.py @@ -20,7 +20,7 @@ class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): def sizes(self) -> Mapping[Hashable, int]: with suppress(Exception): - return self._data.current_sequence.sizes # type: ignore [return-value] + return self.data.current_sequence.sizes # type: ignore return {} def guess_channel_axis(self) -> Hashable | None: @@ -31,10 +31,10 @@ def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: return isinstance(obj, TensorStoreHandler) def isel(self, indexers: Indices) -> np.ndarray: - return self._data.isel({str(k): v for k, v in indexers.items()}) + return self.data.isel({str(k): v for k, v in indexers.items()}) # type: ignore [no-any-return] def save_as_zarr(self, save_loc: str | Path) -> None: - if (store := self._data.store) is None: + if (store := self.data.store) is None: return import tensorstore as ts @@ -55,7 +55,7 @@ def supports(cls, obj: Any) -> TypeGuard[_5DWriterBase]: except ImportError: from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter - _5DWriterBase = (OMETiffWriter, OMEZarrWriter) + _5DWriterBase = (OMETiffWriter, OMEZarrWriter) # type: ignore if isinstance(obj, _5DWriterBase): return True return False From be728cd438d699f649ba5cab552f45c76aceb9ae Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 9 Jun 2024 12:37:15 -0400 Subject: [PATCH 71/73] remove test --- tests/test_stack_viewer2.py | 52 ------------------------------------- 1 file changed, 52 deletions(-) delete mode 100644 tests/test_stack_viewer2.py diff --git a/tests/test_stack_viewer2.py b/tests/test_stack_viewer2.py deleted file mode 100644 index 3dfa17926..000000000 --- a/tests/test_stack_viewer2.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import dask.array as da -import numpy as np - -from pymmcore_widgets._stack_viewer_v2 import StackViewer - -if TYPE_CHECKING: - from pytestqt.qtbot import QtBot - - -def make_lazy_array(shape: tuple[int, ...]) -> da.Array: - rest_shape = shape[:-2] - frame_shape = shape[-2:] - - def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: - if isinstance(block_id, np.ndarray): - return None - size = (1,) * len(rest_shape) + frame_shape - return np.random.randint(0, 255, size=size, dtype=np.uint8) - - chunks = [(1,) * x for x in rest_shape] + [(x,) for x in frame_shape] - return da.map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) # type: ignore - - -# this test is still leaking widgets and it's hard to track down... I think -# it might have to do with the cmapComboBox -# @pytest.mark.allow_leaks -def test_stack_viewer2(qtbot: QtBot) -> None: - dask_arr = make_lazy_array((1000, 64, 3, 256, 256)) - v = StackViewer(dask_arr) - qtbot.addWidget(v) - v.show() - - # wait until there are no running jobs, because the callbacks - # in the futures hold a strong reference to the viewer - qtbot.waitUntil(lambda: v._last_future is None, timeout=1000) - - -def test_dims_sliders(qtbot: QtBot) -> None: - from superqt import QLabeledRangeSlider - - from pymmcore_widgets._stack_viewer_v2._dims_slider import DimsSlider - - # temporary debugging - ds = DimsSlider(dimension_key="t") - qtbot.addWidget(ds) - - rs = QLabeledRangeSlider() - qtbot.addWidget(rs) From 8957e8715abd24b4c6e64833b4bbb65c346c0b46 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 9 Jun 2024 12:52:39 -0400 Subject: [PATCH 72/73] add test --- pyproject.toml | 2 +- .../_stack_viewer_v2/_mda_viewer.py | 2 +- tests/conftest.py | 9 ++++- tests/test_mda_viewer.py | 38 +++++++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 tests/test_mda_viewer.py diff --git a/pyproject.toml b/pyproject.toml index 3bbc31b46..b90a6013a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,9 +136,9 @@ docstring-code-format = true # https://docs.pytest.org/en/6.2.x/customize.html [tool.pytest.ini_options] -markers = ["allow_leaks"] minversion = "6.0" testpaths = ["tests"] +markers = ["allow_leaks: mark test to allow widget leaks"] filterwarnings = [ "error", "ignore:distutils Version classes are deprecated", diff --git a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py index dbc6df7f7..e290631e8 100644 --- a/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py +++ b/src/pymmcore_widgets/_stack_viewer_v2/_mda_viewer.py @@ -61,7 +61,7 @@ def _on_frame_ready(self, event: useq.MDAEvent) -> None: c = event.index.get(self._channel_axis) if c not in self._channel_names and c is not None and event.channel: self._channel_names[c] = event.channel.config - self.setIndex(event.index) + self.set_current_index(event.index) def _get_channel_name(self, index: Mapping) -> str: if self._channel_axis in index: diff --git a/tests/conftest.py b/tests/conftest.py index f3e979e10..f52aed172 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import gc from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterator from unittest.mock import patch import pytest @@ -16,7 +16,7 @@ # to create a new CMMCorePlus() for every test @pytest.fixture(autouse=True) -def global_mmcore(): +def global_mmcore() -> Iterator[CMMCorePlus]: mmc = CMMCorePlus() mmc.loadSystemConfiguration(TEST_CONFIG) with patch.object(_mmcore_plus, "_instance", mmc): @@ -32,6 +32,11 @@ def _run_after_each_test(request: "FixtureRequest", qapp: "QApplication") -> Non `functools.partial(self._method)` or `lambda: self._method` being used in that widget's code. """ + # check for the "allow_leaks" marker + if "allow_leaks" in request.node.keywords: + yield + return + nbefore = len(qapp.topLevelWidgets()) failures_before = request.session.testsfailed yield diff --git a/tests/test_mda_viewer.py b/tests/test_mda_viewer.py new file mode 100644 index 000000000..4e6fa5b4c --- /dev/null +++ b/tests/test_mda_viewer.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pymmcore_plus import CMMCorePlus +from pymmcore_plus.mda.handlers import TensorStoreHandler +import pytest +from useq import MDASequence + +from pymmcore_widgets._stack_viewer_v2._mda_viewer import MDAViewer + +if TYPE_CHECKING: + from pymmcore_plus import CMMCorePlus + from pytestqt.qtbot import QtBot + + +@pytest.mark.allow_leaks +def test_mda_viewer(qtbot: QtBot, global_mmcore: CMMCorePlus) -> None: + + core = global_mmcore + core.defineConfig("Channel", "DAPI", "Camera", "Mode", "Artificial Waves") + core.defineConfig("Channel", "DAPI", "Camera", "StripeWidth", "1") + core.defineConfig("Channel", "FITC", "Camera", "Mode", "Artificial Waves") + core.defineConfig("Channel", "FITC", "Camera", "StripeWidth", "4") + + sequence = MDASequence( + channels=({"config": "DAPI", "exposure": 1}, {"config": "FITC", "exposure": 1}), + stage_positions=[(0, 0), (1, 1)], + z_plan={"range": 9, "step": 0.4}, + time_plan={"interval": 0.2, "loops": 4}, + # grid_plan={"rows": 2, "columns": 1}, + ) + v = MDAViewer() + qtbot.addWidget(v) + assert isinstance(v.data, TensorStoreHandler) + + core.mda.run(sequence, output=v.data) + assert v.data.current_sequence is sequence From ee879a933c574fe8863808065d7e3db5fcabaa36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 9 Jun 2024 16:53:13 +0000 Subject: [PATCH 73/73] style(pre-commit.ci): auto fixes [...] --- tests/test_mda_viewer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_mda_viewer.py b/tests/test_mda_viewer.py index 4e6fa5b4c..b3d4659d1 100644 --- a/tests/test_mda_viewer.py +++ b/tests/test_mda_viewer.py @@ -2,9 +2,9 @@ from typing import TYPE_CHECKING +import pytest from pymmcore_plus import CMMCorePlus from pymmcore_plus.mda.handlers import TensorStoreHandler -import pytest from useq import MDASequence from pymmcore_widgets._stack_viewer_v2._mda_viewer import MDAViewer @@ -16,7 +16,6 @@ @pytest.mark.allow_leaks def test_mda_viewer(qtbot: QtBot, global_mmcore: CMMCorePlus) -> None: - core = global_mmcore core.defineConfig("Channel", "DAPI", "Camera", "Mode", "Artificial Waves") core.defineConfig("Channel", "DAPI", "Camera", "StripeWidth", "1")