Skip to content

Commit

Permalink
fix: retain WellPlatePlan type when used for MDASequence.stage_positi…
Browse files Browse the repository at this point in the history
…ons (#190)

This commit refactors the `_validate_stage_positions` method in the `MDASequence` class to improve the validation process for the `stage_positions` field. The method now checks if the value is an instance of `np.ndarray` and handles it accordingly. Additionally, it adds support for validating a `WellPlatePlan` object as the `stage_positions` value. The commit also includes a new test case in `test_well_plate.py` to ensure that the `stage_positions` field in `MDASequence` correctly handles a `WellPlatePlan` object.

Fixes #XYZ
  • Loading branch information
tlambert03 authored Oct 1, 2024
1 parent dea132e commit f689e06
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/useq/_mda_sequence.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -183,8 +184,10 @@ class MDASequence(UseqModel):

metadata: Dict[str, Any] = Field(default_factory=dict)
axis_order: Tuple[str, ...] = AXES
# note that these are BOTH just `Sequence[Position]` but we retain the distinction
# here so that WellPlatePlans are preserved in the model instance.
stage_positions: Union[WellPlatePlan, Tuple[Position, ...]] = Field(
default_factory=tuple
default_factory=tuple, union_mode="left_to_right"
)
grid_plan: Optional[MultiPointPlan] = Field(
default=None, union_mode="left_to_right"
Expand Down Expand Up @@ -240,14 +243,23 @@ def _validate_channels(cls, value: Any) -> Tuple[Channel, ...]:
return tuple(channels)

@field_validator("stage_positions", mode="before")
def _validate_stage_positions(cls, value: Any) -> Tuple[Position, ...]:
def _validate_stage_positions(
cls, value: Any
) -> Union[WellPlatePlan, Tuple[Position, ...]]:
if isinstance(value, np.ndarray):
if value.ndim == 1:
value = [value]
elif value.ndim == 2:
value = list(value)
else:
with suppress(ValueError):
val = WellPlatePlan.model_validate(value)
return val
if not isinstance(value, Sequence): # pragma: no cover
raise ValueError(f"stage_positions must be a sequence, got {type(value)}")
raise ValueError(
"stage_positions must be a WellPlatePlan or Sequence[Position], "
f"got {type(value)}"
)

positions = []
for v in value:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_well_plate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import numpy as np
import pytest

Expand Down Expand Up @@ -169,3 +171,26 @@ def test_plate_repr() -> None:
rpp = repr(pp)
assert "selected_wells=(slice(8), slice(1, 2))" in rpp
assert eval(rpp, vars(useq)) == pp # noqa: S307


@pytest.mark.parametrize(
"pp",
[
useq.WellPlatePlan(
plate=96,
a1_center_xy=(500, 200),
rotation=5,
selected_wells=np.s_[1:5:2, :6:3],
),
{
"plate": 96,
"a1_center_xy": (500, 200),
"rotation": 5,
"selected_wells": np.s_[1:5:2, :6:3],
},
],
)
def test_plate_plan_in_seq(pp: Any) -> None:
seq = useq.MDASequence(stage_positions=pp)
assert isinstance(seq.stage_positions, useq.WellPlatePlan)
assert seq.stage_positions.plate.size == 96

0 comments on commit f689e06

Please sign in to comment.