Skip to content

Commit

Permalink
Made the API for the caching and moving wrapper more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
DominicOram committed Jul 23, 2024
1 parent 3029cd1 commit f276226
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 55 deletions.
42 changes: 29 additions & 13 deletions src/dodal/plans/motor_util_plans.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Any, Dict, Generator
from typing import Any, Dict, Generator, TypeVar

from bluesky import plan_stubs as bps
from bluesky.preprocessors import finalize_wrapper, pchain
Expand All @@ -9,10 +9,12 @@

from dodal.common import MsgGenerator

AnyDevice = TypeVar("AnyDevice", bound=Device)


class MoveTooLarge(Exception):
def __init__(
self, axis: Motor, maximum_move: float, position: float, *args: object
self, axis: Device, maximum_move: float, position: float, *args: object
) -> None:
self.axis = axis
self.maximum_move = maximum_move
Expand All @@ -21,30 +23,43 @@ def __init__(


def _check_and_cache_values(
device: Device,
devices_and_positions: Dict[AnyDevice, float],
smallest_move: float,
maximum_move: float,
home_position: float = 0,
) -> Generator[Msg, Any, Dict[Motor, float]]:
) -> Generator[Msg, Any, Dict[AnyDevice, float]]:
"""Caches the positions of all Motors on specified device if they are within
smallest_move of home_position. Throws MoveTooLarge if they are outside maximum_move
of the home_position
"""
positions = {}
motors_to_move = [axis for _, axis in device.children() if isinstance(axis, Motor)]
for axis in motors_to_move:
for axis, new_position in devices_and_positions.items():
position = yield from bps.rd(axis)
if abs(position - home_position) > maximum_move:
if abs(position - new_position) > maximum_move:
raise MoveTooLarge(axis, maximum_move, position)
if abs(position - home_position) > smallest_move:
if abs(position - new_position) > smallest_move:
positions[axis] = position
return positions


def home_and_reset_wrapper(
plan: MsgGenerator,
device: Device,
home_position: float,
smallest_move: float,
maximum_move: float,
group: str | None = None,
wait_for_all: bool = True,
) -> MsgGenerator:
home_positions = {
axis: 0.0 for _, axis in device.children() if isinstance(axis, Motor)
}
return move_and_reset_wrapper(
plan, home_positions, smallest_move, maximum_move, group, wait_for_all
)


def move_and_reset_wrapper(
plan: MsgGenerator,
device_and_positions: Dict[AnyDevice, float],
smallest_move: float,
maximum_move: float,
group: str | None = None,
Expand Down Expand Up @@ -72,13 +87,14 @@ def home_and_reset_wrapper(
them. Defaults to True.
"""
initial_positions = yield from _check_and_cache_values(
device, smallest_move, maximum_move, home_position
device_and_positions, smallest_move, maximum_move
)

def move_to_home():
home_group = f"home-{group if group else str(uuid.uuid4())[:6]}"
for axis, _ in initial_positions.items():
yield from bps.abs_set(axis, home_position, group=home_group)
for axis, position in device_and_positions.items():
if axis in initial_positions.keys():
yield from bps.abs_set(axis, position, group=home_group)
if wait_for_all:
yield from bps.wait(home_group)

Expand Down
101 changes: 59 additions & 42 deletions tests/plans/test_motor_util_plans.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, List
from unittest.mock import ANY, call
from unittest.mock import ANY, MagicMock, call, patch

import pytest
from bluesky import plan_stubs as bps
Expand Down Expand Up @@ -56,23 +56,33 @@ def my_device(RE):


@pytest.mark.parametrize(
"device",
[
DeviceWithOnlyMotors,
DeviceWithNoMotors,
DeviceWithSomeMotors,
],
"device_type",
[DeviceWithOnlyMotors, DeviceWithNoMotors, DeviceWithSomeMotors],
)
def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(device):
RE = RunEngine(call_returns_result=True)
@patch("dodal.plans.motor_util_plans.move_and_reset_wrapper")
def test_given_types_of_device_when_home_and_reset_wrapper_called_then_motors_and_zeros_passed_to_move_and_reset_wrapper(
patch_move_and_reset, device_type, RE
):
with DeviceCollector(mock=True):
my_device: TestMotorDevice = device()
device = device_type()
home_and_reset_wrapper(MagicMock(), device, 0, 0)

home_positions = patch_move_and_reset.call_args.args[1]
assert home_positions == {motor_obj: 0.0 for motor_obj in device.motors}


def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(
my_device,
):
RE = RunEngine(call_returns_result=True)

for i, motor in enumerate(my_device.motors, start=1):
set_mock_value(motor.user_readback, i * 100)

motors_and_positions: Dict[Motor, float] = RE(
_check_and_cache_values(my_device, 0, 1000)
_check_and_cache_values(
{motor_obj: 0.0 for motor_obj in my_device.motors}, 0, 1000
)
).plan_result # type: ignore
cached_positions = motors_and_positions.values()

Expand All @@ -82,7 +92,7 @@ def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(d


@pytest.mark.parametrize(
"initial, max, home",
"initial, max, new_position",
[
(200, 100, 0),
(-200, 100, 0),
Expand All @@ -91,19 +101,21 @@ def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(d
],
)
def test_given_a_device_with_a_too_large_move_when_check_and_cache_values_then_exception_thrown(
RE, my_device, initial, max, home
RE, my_device, initial, max, new_position
):
set_mock_value(my_device.x.user_readback, 10)
set_mock_value(my_device.y.user_readback, initial)

motors_and_positions = {motor_obj: new_position for motor_obj in my_device.motors}

with pytest.raises(MoveTooLarge) as e:
RE(_check_and_cache_values(my_device, 0, max, home_position=home))
RE(_check_and_cache_values(motors_and_positions, 0, max))
assert e.value.axis == my_device.y
assert e.value.maximum_move == max


@pytest.mark.parametrize(
"initial, min, home",
"initial, min, new_position",
[
(50, 5, 49),
(48, 5, 49),
Expand All @@ -112,15 +124,19 @@ def test_given_a_device_with_a_too_large_move_when_check_and_cache_values_then_e
],
)
def test_given_a_device_where_one_move_too_small_when_check_and_cache_values_then_other_positions_returned(
my_device, initial, min, home
my_device, initial, min, new_position
):
RE = RunEngine(call_returns_result=True)

set_mock_value(my_device.x.user_readback, initial)
set_mock_value(my_device.y.user_readback, 200)

motors_and_new_positions = {
motor_obj: new_position for motor_obj in my_device.motors
}

motors_and_positions: Dict[Motor, float] = RE(
_check_and_cache_values(my_device, min, 1000, home_position=home)
_check_and_cache_values(motors_and_new_positions, min, 1000)
).plan_result # type: ignore
cached_positions = motors_and_positions.values()

Expand All @@ -137,25 +153,30 @@ def test_given_a_device_where_all_moves_too_small_when_check_and_cache_values_th
set_mock_value(my_device.x.user_readback, 10)
set_mock_value(my_device.y.user_readback, 20)

motors_and_new_positions = {motor_obj: 0.0 for motor_obj in my_device.motors}

motors_and_positions: Dict[Motor, float] = RE(
_check_and_cache_values(my_device, 40, 1000)
_check_and_cache_values(motors_and_new_positions, 40, 1000)
).plan_result # type: ignore
cached_positions = motors_and_positions.values()

assert len(cached_positions) == 0


@pytest.mark.parametrize(
"initial_x, initial_y, home",
"initial_x, initial_y",
[
(10, 20, 0),
(150, 40, 95),
(-56, 50, 78),
(74, -89, -2),
(10, 20),
(150, 40),
(-56, 50),
(74, -89),
],
)
def test_when_home_and_reset_wrapper_called_with_null_plan_then_motors_homed_and_reset(
RE, my_device, initial_x, initial_y, home
RE,
my_device,
initial_x,
initial_y,
):
def my_plan():
yield from bps.null()
Expand All @@ -167,32 +188,31 @@ def my_plan():
home_and_reset_wrapper(
my_plan(),
my_device,
home,
0,
1000,
)
)

get_mock_put(my_device.x.user_setpoint).assert_has_calls(
[call(home, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)]
[call(0, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)]
)

get_mock_put(my_device.y.user_setpoint).assert_has_calls(
[call(home, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)]
[call(0, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)]
)


@pytest.mark.parametrize(
"initial, min, home",
"initial, min",
[
(50, 5, 49),
(48, 5, 49),
(100, 50, 105),
(5, 10, -2),
(1, 5),
(-1, 5),
(-5, 50),
(7, 10),
],
)
def test_given_motors_already_close_to_home_when_home_and_reset_wrapper_called_then_motors_do_not_move(
RE, my_device, initial, home, min
RE, my_device, initial, min
):
def my_plan():
yield from bps.null()
Expand All @@ -204,7 +224,6 @@ def my_plan():
home_and_reset_wrapper(
my_plan(),
my_device,
home,
min,
1000,
)
Expand Down Expand Up @@ -250,7 +269,7 @@ def my_plan():


def test_given_home_and_reset_inner_plan_fails_reset_still(RE, my_device):
initial_x, initial_y, home = 10, 20, 6
initial_x, initial_y = 10, 20

def my_plan():
yield from bps.null()
Expand All @@ -264,18 +283,17 @@ def my_plan():
home_and_reset_wrapper(
my_plan(),
my_device,
home,
0,
1000,
)
)

get_mock_put(my_device.x.user_setpoint).assert_has_calls(
[call(home, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)]
[call(0.0, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)]
)

get_mock_put(my_device.y.user_setpoint).assert_has_calls(
[call(home, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)]
[call(0.0, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)]
)


Expand All @@ -284,7 +302,7 @@ def my_plan():
["x", "y"],
)
def test_given_move_to_home_fails_reset_still(RE, my_device, move_that_failed):
initial_x, initial_y, home = 10, 20, 6
initial_x, initial_y = 10, 20

def my_plan():
# This will never get called as fails before
Expand All @@ -301,16 +319,15 @@ def my_plan():
home_and_reset_wrapper(
my_plan(),
my_device,
home,
0,
1000,
)
)

get_mock_put(my_device.x.user_setpoint).assert_has_calls(
[call(home, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)]
[call(0.0, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)]
)

get_mock_put(my_device.y.user_setpoint).assert_has_calls(
[call(home, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)]
[call(0.0, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)]
)

0 comments on commit f276226

Please sign in to comment.