From f276226f9e1377319310e34b89c4c0c685e1cd98 Mon Sep 17 00:00:00 2001 From: Dominic Oram Date: Tue, 23 Jul 2024 12:16:53 +0100 Subject: [PATCH] Made the API for the caching and moving wrapper more generic --- src/dodal/plans/motor_util_plans.py | 42 +++++++---- tests/plans/test_motor_util_plans.py | 101 ++++++++++++++++----------- 2 files changed, 88 insertions(+), 55 deletions(-) diff --git a/src/dodal/plans/motor_util_plans.py b/src/dodal/plans/motor_util_plans.py index cf36529dd3..cfca14e9fb 100644 --- a/src/dodal/plans/motor_util_plans.py +++ b/src/dodal/plans/motor_util_plans.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Dict, Generator +from typing import Any, Dict, Generator, TypeVar from bluesky import plan_stubs as bps from bluesky.preprocessors import finalize_wrapper, pchain @@ -9,10 +9,12 @@ from dodal.common import MsgGenerator +AnyDevice = TypeVar("AnyDevice", bound=Device) + class MoveTooLarge(Exception): def __init__( - self, axis: Motor, maximum_move: float, position: float, *args: object + self, axis: Device, maximum_move: float, position: float, *args: object ) -> None: self.axis = axis self.maximum_move = maximum_move @@ -21,22 +23,20 @@ def __init__( def _check_and_cache_values( - device: Device, + devices_and_positions: Dict[AnyDevice, float], smallest_move: float, maximum_move: float, - home_position: float = 0, -) -> Generator[Msg, Any, Dict[Motor, float]]: +) -> Generator[Msg, Any, Dict[AnyDevice, float]]: """Caches the positions of all Motors on specified device if they are within smallest_move of home_position. Throws MoveTooLarge if they are outside maximum_move of the home_position """ positions = {} - motors_to_move = [axis for _, axis in device.children() if isinstance(axis, Motor)] - for axis in motors_to_move: + for axis, new_position in devices_and_positions.items(): position = yield from bps.rd(axis) - if abs(position - home_position) > maximum_move: + if abs(position - new_position) > maximum_move: raise MoveTooLarge(axis, maximum_move, position) - if abs(position - home_position) > smallest_move: + if abs(position - new_position) > smallest_move: positions[axis] = position return positions @@ -44,7 +44,22 @@ def _check_and_cache_values( def home_and_reset_wrapper( plan: MsgGenerator, device: Device, - home_position: float, + smallest_move: float, + maximum_move: float, + group: str | None = None, + wait_for_all: bool = True, +) -> MsgGenerator: + home_positions = { + axis: 0.0 for _, axis in device.children() if isinstance(axis, Motor) + } + return move_and_reset_wrapper( + plan, home_positions, smallest_move, maximum_move, group, wait_for_all + ) + + +def move_and_reset_wrapper( + plan: MsgGenerator, + device_and_positions: Dict[AnyDevice, float], smallest_move: float, maximum_move: float, group: str | None = None, @@ -72,13 +87,14 @@ def home_and_reset_wrapper( them. Defaults to True. """ initial_positions = yield from _check_and_cache_values( - device, smallest_move, maximum_move, home_position + device_and_positions, smallest_move, maximum_move ) def move_to_home(): home_group = f"home-{group if group else str(uuid.uuid4())[:6]}" - for axis, _ in initial_positions.items(): - yield from bps.abs_set(axis, home_position, group=home_group) + for axis, position in device_and_positions.items(): + if axis in initial_positions.keys(): + yield from bps.abs_set(axis, position, group=home_group) if wait_for_all: yield from bps.wait(home_group) diff --git a/tests/plans/test_motor_util_plans.py b/tests/plans/test_motor_util_plans.py index 9a176b1ebb..7a874f047d 100644 --- a/tests/plans/test_motor_util_plans.py +++ b/tests/plans/test_motor_util_plans.py @@ -1,5 +1,5 @@ from typing import Dict, List -from unittest.mock import ANY, call +from unittest.mock import ANY, MagicMock, call, patch import pytest from bluesky import plan_stubs as bps @@ -56,23 +56,33 @@ def my_device(RE): @pytest.mark.parametrize( - "device", - [ - DeviceWithOnlyMotors, - DeviceWithNoMotors, - DeviceWithSomeMotors, - ], + "device_type", + [DeviceWithOnlyMotors, DeviceWithNoMotors, DeviceWithSomeMotors], ) -def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(device): - RE = RunEngine(call_returns_result=True) +@patch("dodal.plans.motor_util_plans.move_and_reset_wrapper") +def test_given_types_of_device_when_home_and_reset_wrapper_called_then_motors_and_zeros_passed_to_move_and_reset_wrapper( + patch_move_and_reset, device_type, RE +): with DeviceCollector(mock=True): - my_device: TestMotorDevice = device() + device = device_type() + home_and_reset_wrapper(MagicMock(), device, 0, 0) + + home_positions = patch_move_and_reset.call_args.args[1] + assert home_positions == {motor_obj: 0.0 for motor_obj in device.motors} + + +def test_given_a_device_when_check_and_cache_values_then_motor_values_returned( + my_device, +): + RE = RunEngine(call_returns_result=True) for i, motor in enumerate(my_device.motors, start=1): set_mock_value(motor.user_readback, i * 100) motors_and_positions: Dict[Motor, float] = RE( - _check_and_cache_values(my_device, 0, 1000) + _check_and_cache_values( + {motor_obj: 0.0 for motor_obj in my_device.motors}, 0, 1000 + ) ).plan_result # type: ignore cached_positions = motors_and_positions.values() @@ -82,7 +92,7 @@ def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(d @pytest.mark.parametrize( - "initial, max, home", + "initial, max, new_position", [ (200, 100, 0), (-200, 100, 0), @@ -91,19 +101,21 @@ def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(d ], ) def test_given_a_device_with_a_too_large_move_when_check_and_cache_values_then_exception_thrown( - RE, my_device, initial, max, home + RE, my_device, initial, max, new_position ): set_mock_value(my_device.x.user_readback, 10) set_mock_value(my_device.y.user_readback, initial) + motors_and_positions = {motor_obj: new_position for motor_obj in my_device.motors} + with pytest.raises(MoveTooLarge) as e: - RE(_check_and_cache_values(my_device, 0, max, home_position=home)) + RE(_check_and_cache_values(motors_and_positions, 0, max)) assert e.value.axis == my_device.y assert e.value.maximum_move == max @pytest.mark.parametrize( - "initial, min, home", + "initial, min, new_position", [ (50, 5, 49), (48, 5, 49), @@ -112,15 +124,19 @@ def test_given_a_device_with_a_too_large_move_when_check_and_cache_values_then_e ], ) def test_given_a_device_where_one_move_too_small_when_check_and_cache_values_then_other_positions_returned( - my_device, initial, min, home + my_device, initial, min, new_position ): RE = RunEngine(call_returns_result=True) set_mock_value(my_device.x.user_readback, initial) set_mock_value(my_device.y.user_readback, 200) + motors_and_new_positions = { + motor_obj: new_position for motor_obj in my_device.motors + } + motors_and_positions: Dict[Motor, float] = RE( - _check_and_cache_values(my_device, min, 1000, home_position=home) + _check_and_cache_values(motors_and_new_positions, min, 1000) ).plan_result # type: ignore cached_positions = motors_and_positions.values() @@ -137,8 +153,10 @@ def test_given_a_device_where_all_moves_too_small_when_check_and_cache_values_th set_mock_value(my_device.x.user_readback, 10) set_mock_value(my_device.y.user_readback, 20) + motors_and_new_positions = {motor_obj: 0.0 for motor_obj in my_device.motors} + motors_and_positions: Dict[Motor, float] = RE( - _check_and_cache_values(my_device, 40, 1000) + _check_and_cache_values(motors_and_new_positions, 40, 1000) ).plan_result # type: ignore cached_positions = motors_and_positions.values() @@ -146,16 +164,19 @@ def test_given_a_device_where_all_moves_too_small_when_check_and_cache_values_th @pytest.mark.parametrize( - "initial_x, initial_y, home", + "initial_x, initial_y", [ - (10, 20, 0), - (150, 40, 95), - (-56, 50, 78), - (74, -89, -2), + (10, 20), + (150, 40), + (-56, 50), + (74, -89), ], ) def test_when_home_and_reset_wrapper_called_with_null_plan_then_motors_homed_and_reset( - RE, my_device, initial_x, initial_y, home + RE, + my_device, + initial_x, + initial_y, ): def my_plan(): yield from bps.null() @@ -167,32 +188,31 @@ def my_plan(): home_and_reset_wrapper( my_plan(), my_device, - home, 0, 1000, ) ) get_mock_put(my_device.x.user_setpoint).assert_has_calls( - [call(home, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)] + [call(0, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)] ) get_mock_put(my_device.y.user_setpoint).assert_has_calls( - [call(home, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)] + [call(0, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)] ) @pytest.mark.parametrize( - "initial, min, home", + "initial, min", [ - (50, 5, 49), - (48, 5, 49), - (100, 50, 105), - (5, 10, -2), + (1, 5), + (-1, 5), + (-5, 50), + (7, 10), ], ) def test_given_motors_already_close_to_home_when_home_and_reset_wrapper_called_then_motors_do_not_move( - RE, my_device, initial, home, min + RE, my_device, initial, min ): def my_plan(): yield from bps.null() @@ -204,7 +224,6 @@ def my_plan(): home_and_reset_wrapper( my_plan(), my_device, - home, min, 1000, ) @@ -250,7 +269,7 @@ def my_plan(): def test_given_home_and_reset_inner_plan_fails_reset_still(RE, my_device): - initial_x, initial_y, home = 10, 20, 6 + initial_x, initial_y = 10, 20 def my_plan(): yield from bps.null() @@ -264,18 +283,17 @@ def my_plan(): home_and_reset_wrapper( my_plan(), my_device, - home, 0, 1000, ) ) get_mock_put(my_device.x.user_setpoint).assert_has_calls( - [call(home, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)] + [call(0.0, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)] ) get_mock_put(my_device.y.user_setpoint).assert_has_calls( - [call(home, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)] + [call(0.0, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)] ) @@ -284,7 +302,7 @@ def my_plan(): ["x", "y"], ) def test_given_move_to_home_fails_reset_still(RE, my_device, move_that_failed): - initial_x, initial_y, home = 10, 20, 6 + initial_x, initial_y = 10, 20 def my_plan(): # This will never get called as fails before @@ -301,16 +319,15 @@ def my_plan(): home_and_reset_wrapper( my_plan(), my_device, - home, 0, 1000, ) ) get_mock_put(my_device.x.user_setpoint).assert_has_calls( - [call(home, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)] + [call(0.0, wait=ANY, timeout=ANY), call(initial_x, wait=ANY, timeout=ANY)] ) get_mock_put(my_device.y.user_setpoint).assert_has_calls( - [call(home, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)] + [call(0.0, wait=ANY, timeout=ANY), call(initial_y, wait=ANY, timeout=ANY)] )