From 2969ba5d05ccb9e7218ecf6ddf56df66c2a68f20 Mon Sep 17 00:00:00 2001 From: Niko Kivel Date: Thu, 9 Jan 2025 14:57:23 +0000 Subject: [PATCH] assert removed from StandardReadable.add_readables - added TypeGaurd to make Pylance happy. - added TypeError raised test --- src/ophyd_async/core/_readable.py | 40 ++++++++++++++++++------------- tests/core/test_readable.py | 19 +++++++++++++++ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/src/ophyd_async/core/_readable.py b/src/ophyd_async/core/_readable.py index 4c424d7d9a..3215dcd8d9 100644 --- a/src/ophyd_async/core/_readable.py +++ b/src/ophyd_async/core/_readable.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable, Generator, Sequence from contextlib import contextmanager from enum import Enum -from typing import Any, cast +from typing import Any, TypeGuard, cast from bluesky.protocols import HasHints, Hints, Reading from event_model import DataKey @@ -205,6 +205,14 @@ def add_readables( `StandardReadableFormat` documentation """ + def is_signalr(device: Device) -> TypeGuard[SignalR]: + return isinstance(device, SignalR) + + def assert_device_is_signalr(device: Device) -> SignalR: + if not is_signalr(device): + raise TypeError(f"{device} is not a SignalR") + return device + for device in devices: match format: case StandardReadableFormat.CHILD: @@ -219,24 +227,24 @@ def add_readables( if isinstance(device, HasHints): self._has_hints += (device,) case StandardReadableFormat.CONFIG_SIGNAL: - assert isinstance(device, SignalR), f"{device} is not a SignalR" - self._describe_config_funcs += (device.describe,) - self._read_config_funcs += (device.read,) + signalr_device = assert_device_is_signalr(device=device) + self._describe_config_funcs += (signalr_device.describe,) + self._read_config_funcs += (signalr_device.read,) case StandardReadableFormat.HINTED_SIGNAL: - assert isinstance(device, SignalR), f"{device} is not a SignalR" - self._describe_funcs += (device.describe,) - self._read_funcs += (device.read,) - self._stageables += (device,) - self._has_hints += (_HintsFromName(device),) + signalr_device = assert_device_is_signalr(device=device) + self._describe_funcs += (signalr_device.describe,) + self._read_funcs += (signalr_device.read,) + self._stageables += (signalr_device,) + self._has_hints += (_HintsFromName(signalr_device),) case StandardReadableFormat.UNCACHED_SIGNAL: - assert isinstance(device, SignalR), f"{device} is not a SignalR" - self._describe_funcs += (device.describe,) - self._read_funcs += (_UncachedRead(device),) + signalr_device = assert_device_is_signalr(device=device) + self._describe_funcs += (signalr_device.describe,) + self._read_funcs += (_UncachedRead(signalr_device),) case StandardReadableFormat.HINTED_UNCACHED_SIGNAL: - assert isinstance(device, SignalR), f"{device} is not a SignalR" - self._describe_funcs += (device.describe,) - self._read_funcs += (_UncachedRead(device),) - self._has_hints += (_HintsFromName(device),) + signalr_device = assert_device_is_signalr(device=device) + self._describe_funcs += (signalr_device.describe,) + self._read_funcs += (_UncachedRead(signalr_device),) + self._has_hints += (_HintsFromName(signalr_device),) class _UncachedRead: diff --git a/tests/core/test_readable.py b/tests/core/test_readable.py index 8ab69ed03c..7d2771bd66 100644 --- a/tests/core/test_readable.py +++ b/tests/core/test_readable.py @@ -209,6 +209,25 @@ def test_standard_readable_add_readables_adds_to_expected_attrs( assert_sr_has_attrs(sr, expected_attrs) +@pytest.mark.parametrize( + "format", + [ + Format.CONFIG_SIGNAL, + Format.HINTED_SIGNAL, + Format.UNCACHED_SIGNAL, + Format.HINTED_UNCACHED_SIGNAL, + ], +) +def test_standard_readable_add_readables_raises_signalr_typeerror(format) -> None: + # Mock a Device instance that is not a SignalR + mock_device = MagicMock(spec=Device) + sr = StandardReadable() + + # Ensure it raises TypeError + with pytest.raises(TypeError, match=f"{mock_device} is not a SignalR"): + sr.add_readables([mock_device], format=format) + + def test_standard_readable_config_signal(): signal_r = MagicMock(spec=SignalR) sr = StandardReadable()