Skip to content

Commit

Permalink
Make more use of AsyncStatus.wrap
Browse files Browse the repository at this point in the history
  • Loading branch information
DominicOram committed Jul 22, 2024
1 parent e4c3d5a commit 9c16e6a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
5 changes: 3 additions & 2 deletions src/dodal/devices/aperturescatterguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,13 @@ def load_aperture_positions(self, positions: AperturePositions):
LOGGER.info(f"{self.name} loaded in {positions}")
self.aperture_positions = positions

def set(self, pos: SingleAperturePosition) -> AsyncStatus:
@AsyncStatus.wrap
async def set(self, pos: SingleAperturePosition):
assert isinstance(self.aperture_positions, AperturePositions)
if pos not in self.aperture_positions.as_list():
raise InvalidApertureMove(f"Unknown aperture: {pos}")

return AsyncStatus(self._safe_move_within_datacollection_range(pos.location))
await self._safe_move_within_datacollection_range(pos.location)

def _get_motor_list(self):
return [
Expand Down
14 changes: 6 additions & 8 deletions src/dodal/devices/undulator_dcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,12 @@ def __init__(
daq_configuration_path + "/domain/beamlineParameters"
)["DCM_Perp_Offset_FIXED"]

def set(self, value: float) -> AsyncStatus:
async def _set():
await asyncio.gather(
self._set_dcm_energy(value),
self._set_undulator_gap_if_required(value),
)

return AsyncStatus(_set())
@AsyncStatus.wrap
async def set(self, value: float):
await asyncio.gather(
self._set_dcm_energy(value),
self._set_undulator_gap_if_required(value),
)

async def _set_dcm_energy(self, energy_kev: float) -> None:
access_level = await self.undulator.gap_access.get_value()
Expand Down
7 changes: 3 additions & 4 deletions src/dodal/devices/zebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,9 @@ async def _set_armed(self, demand: ArmDemand):
if reading == demand.value:
return

def set(self, demand: ArmDemand) -> AsyncStatus:
return AsyncStatus(
asyncio.wait_for(self._set_armed(demand), timeout=self.TIMEOUT)
)
@AsyncStatus.wrap
async def set(self, demand: ArmDemand):
await asyncio.wait_for(self._set_armed(demand), timeout=self.TIMEOUT)


class PositionCompare(StandardReadable):
Expand Down
11 changes: 6 additions & 5 deletions tests/devices/unit_tests/test_aperture_scatterguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from bluesky.run_engine import RunEngine
from ophyd_async.core import (
DeviceCollector,
get_mock_put,
set_mock_value,
)
Expand Down Expand Up @@ -36,10 +37,10 @@ def get_all_motors(ap_sg: ApertureScatterguard):


@pytest.fixture
async def ap_sg_and_call_log(aperture_positions: AperturePositions):
async def ap_sg_and_call_log(RE: RunEngine, aperture_positions: AperturePositions):
call_log = MagicMock()
ap_sg = ApertureScatterguard(name="test_ap_sg")
await ap_sg.connect(mock=True)
async with DeviceCollector(mock=True):
ap_sg = ApertureScatterguard(name="test_ap_sg")
ap_sg.load_aperture_positions(aperture_positions)
with ExitStack() as motor_patch_stack:
for motor in get_all_motors(ap_sg):
Expand Down Expand Up @@ -121,11 +122,11 @@ def _assert_patched_ap_sg_has_call(
)


def test_aperture_scatterguard_rejects_unknown_position(aperture_in_medium_pos):
async def test_aperture_scatterguard_rejects_unknown_position(aperture_in_medium_pos):
position_to_reject = ApertureFiveDimensionalLocation(0, 0, 0, 0, 0)

with pytest.raises(InvalidApertureMove):
aperture_in_medium_pos.set(
await aperture_in_medium_pos.set(
SingleAperturePosition("test", "GDA_NAME", 10, position_to_reject)
)

Expand Down

0 comments on commit 9c16e6a

Please sign in to comment.