Skip to content

Commit

Permalink
Move function for getting frame duration from dls-bluesky-core (from …
Browse files Browse the repository at this point in the history
…i22-bluesky) to ScanSpec
  • Loading branch information
DiamondJoseph committed Jan 29, 2024
1 parent c42ced1 commit 51bee05
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/scanspec/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,3 +667,31 @@ def step(spec: Spec[Axis], duration: float, num: int = 1) -> Spec[Axis]:
spec = step(Line("x", 1, 2, 3), 0.1)
"""
return spec * Static.duration(duration, num)


def duration(frames: List[Frames]) -> Optional[float]:
"""
Returns the duration of a number of ScanSpec frames, if known and consistent.
Args:
frames (List[Frames]): A number of Frame objects
Returns:
duration (float): if all frames have a consistent duration
None: otherwise
"""
duration_frame = [
f for f in frames if DURATION in f.axes() and len(f.midpoints[DURATION])
]
if len(duration_frame) != 1 or len(duration_frame[0]) < 1:
# Either no frame has DURATION axis,
# the frame with a DURATION axis has 0 points,
# or multiple frames have DURATION axis
return None
durations = duration_frame[0].midpoints[DURATION]
first_duration = durations[0]
if np.any(durations != first_duration):
# Not all durations are the same
return None
return first_duration
60 changes: 60 additions & 0 deletions tests/test_specs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any, Tuple

import pytest
Expand All @@ -15,6 +16,7 @@
Squash,
Static,
Zip,
duration,
fly,
step,
)
Expand Down Expand Up @@ -558,3 +560,61 @@ def test_multiple_statics_with_grid():
)
def test_shape(spec: Spec, expected_shape: Tuple[int, ...]):
assert expected_shape == spec.shape()


def test_single_frame_single_point():
spec = Static.duration(0.1)
assert duration(spec.calculate()) == 0.1


def test_consistent_points():
spec = Static.duration(0.1).concat(Static.duration(0.1))
assert duration(spec.calculate()) == 0.1


def test_inconsistent_points():
spec = Static.duration(0.1).concat(Static.duration(0.2))
assert duration(spec.calculate()) is None


def test_frame_with_multiple_axes():
spec = Static.duration(0.1).zip(Line.bounded("x", 0, 0, 1))
frames = spec.calculate()
assert len(frames) == 1
assert duration(frames) == 0.1


def test_inconsistent_frame_with_multiple_axes():
spec = (
Static.duration(0.1)
.concat(Static.duration(0.2))
.zip(Line.bounded("x", 0, 0, 2))
)
frames = spec.calculate()
assert len(frames) == 1
assert duration(frames) is None


def test_non_static_spec_duration():
spec = Line.bounded(DURATION, 0, 0, 3)
frames = spec.calculate()
assert len(frames) == 1
assert duration(frames) == 0


def test_multiple_duration_frames():
spec = (
Static.duration(0.1)
.concat(Static.duration(0.2))
.zip(Line.bounded(DURATION, 0, 0, 2))
)
with pytest.raises(
AssertionError, match=re.escape("Zipping would overwrite axes ['DURATION']")
):
spec.calculate()
spec = ( # TODO: refactor when https://github.com/dls-controls/scanspec/issues/90
Static.duration(0.1) * Line.bounded(DURATION, 0, 0, 2)
)
frames = spec.calculate()
assert len(frames) == 2
assert duration(frames) is None

0 comments on commit 51bee05

Please sign in to comment.