From 78514b71c729e12a40145222dbb62377d1570d5d Mon Sep 17 00:00:00 2001 From: Andrew Sayre <6730289+andrewsayre@users.noreply.github.com> Date: Sun, 26 Jan 2025 03:39:34 +0000 Subject: [PATCH] Enable mypy strict typing --- .pre-commit-config.yaml | 2 +- pyheos/command/browse.py | 2 +- pyheos/command/group.py | 4 ++-- pyheos/command/player.py | 14 ++++---------- pyheos/command/system.py | 2 +- pyheos/connection.py | 22 +++++++++++----------- pyheos/dispatch.py | 30 +++++++++++++----------------- pyheos/media.py | 2 +- pyheos/system.py | 3 ++- tests/__init__.py | 16 ++++++++-------- tests/conftest.py | 7 ++++--- tests/test_dispatch.py | 18 +++++++++--------- tests/test_heos.py | 5 +++-- tests/test_heos_browse.py | 3 ++- tests/test_heos_callback.py | 3 ++- 15 files changed, 64 insertions(+), 69 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 742cbf3..1ecb2d0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: rev: v1.14.1 hooks: - id: mypy - args: [] + args: [--strict] additional_dependencies: - pydantic==2.10.5 - pylint==3.3.3 diff --git a/pyheos/command/browse.py b/pyheos/command/browse.py index 92538ec..5751a6e 100644 --- a/pyheos/command/browse.py +++ b/pyheos/command/browse.py @@ -80,7 +80,7 @@ async def get_music_sources( HeosCommand(c.COMMAND_BROWSE_GET_SOURCES, params) ) self._music_sources.clear() - for data in cast(Sequence[dict], message.payload): + for data in cast(Sequence[dict[str, Any]], message.payload): source = MediaMusicSource.from_data(data, cast("Heos", self)) self._music_sources[source.source_id] = source self._music_sources_loaded = True diff --git a/pyheos/command/group.py b/pyheos/command/group.py index bfb4b71..1d2c12a 100644 --- a/pyheos/command/group.py +++ b/pyheos/command/group.py @@ -38,9 +38,9 @@ async def get_groups(self, *, refresh: bool = False) -> dict[int, HeosGroup]: References: 4.3.1 Get Groups""" if not self._groups_loaded or refresh: - groups = {} + groups: dict[int, HeosGroup] = {} result = await self._connection.command(HeosCommand(c.COMMAND_GET_GROUPS)) - payload = cast(Sequence[dict], result.payload) + payload = cast(Sequence[dict[str, Any]], result.payload) for data in payload: group = HeosGroup._from_data(data, cast("Heos", self)) groups[group.group_id] = group diff --git a/pyheos/command/player.py b/pyheos/command/player.py index 41f40ff..7e149cc 100644 --- a/pyheos/command/player.py +++ b/pyheos/command/player.py @@ -13,14 +13,8 @@ from pyheos.command.connection import ConnectionMixin from pyheos.media import QueueItem from pyheos.message import HeosCommand -from pyheos.player import ( - HeosNowPlayingMedia, - HeosPlayer, - PlayerUpdateResult, - PlayMode, - PlayState, -) -from pyheos.types import RepeatType +from pyheos.player import HeosNowPlayingMedia, HeosPlayer, PlayerUpdateResult, PlayMode +from pyheos.types import PlayState, RepeatType if TYPE_CHECKING: from pyheos.heos import Heos @@ -102,7 +96,7 @@ async def load_players(self) -> PlayerUpdateResult: players: dict[int, HeosPlayer] = {} response = await self._connection.command(HeosCommand(c.COMMAND_GET_PLAYERS)) - payload = cast(Sequence[dict], response.payload) + payload = cast(Sequence[dict[str, str]], response.payload) existing = list(self._players.values()) for player_data in payload: player_id = int(player_data[c.ATTR_PLAYER_ID]) @@ -463,7 +457,7 @@ async def player_get_quick_selects(self, player_id: int) -> dict[int, str]: ) return { int(data[c.ATTR_ID]): data[c.ATTR_NAME] - for data in cast(list[dict], result.payload) + for data in cast(list[dict[str, Any]], result.payload) } async def player_check_update(self, player_id: int) -> bool: diff --git a/pyheos/command/system.py b/pyheos/command/system.py index dbf4123..f1fcc4c 100644 --- a/pyheos/command/system.py +++ b/pyheos/command/system.py @@ -132,7 +132,7 @@ async def get_system_info(self) -> HeosSystem: References: 4.2.1 Get Players""" response = await self._connection.command(HeosCommand(c.COMMAND_GET_PLAYERS)) - payload = cast(Sequence[dict], response.payload) + payload = cast(Sequence[dict[str, Any]], response.payload) hosts = list([HeosHost._from_data(item) for item in payload]) host = next(host for host in hosts if host.ip_address == self._options.host) return HeosSystem(self._signed_in_username, host, hosts) diff --git a/pyheos/connection.py b/pyheos/connection.py index 40acfce..7a62fef 100644 --- a/pyheos/connection.py +++ b/pyheos/connection.py @@ -2,7 +2,7 @@ import asyncio import logging -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Awaitable, Callable from contextlib import suppress from datetime import datetime, timedelta from typing import TYPE_CHECKING, Final @@ -35,15 +35,15 @@ def __init__(self, host: str, *, timeout: float) -> None: self._state: ConnectionState = ConnectionState.DISCONNECTED self._writer: asyncio.StreamWriter | None = None self._pending_command_event = ResponseEvent() - self._running_tasks: set[asyncio.Task] = set() + self._running_tasks: set[asyncio.Task[None]] = set() self._last_activity: datetime = datetime.now() self._command_lock = asyncio.Lock() - self._on_event_callbacks: list[Callable[[HeosMessage], Awaitable]] = [] - self._on_connected_callbacks: list[Callable[[], Awaitable]] = [] - self._on_disconnected_callbacks: list[Callable[[bool], Awaitable]] = [] + self._on_event_callbacks: list[Callable[[HeosMessage], Awaitable[None]]] = [] + self._on_connected_callbacks: list[Callable[[], Awaitable[None]]] = [] + self._on_disconnected_callbacks: list[Callable[[bool], Awaitable[None]]] = [] self._on_command_error_callbacks: list[ - Callable[[CommandFailedError], Awaitable] + Callable[[CommandFailedError], Awaitable[None]] ] = [] @property @@ -51,7 +51,7 @@ def state(self) -> ConnectionState: """Get the current state of the connection.""" return self._state - def add_on_event(self, callback: Callable[[HeosMessage], Awaitable]) -> None: + def add_on_event(self, callback: Callable[[HeosMessage], Awaitable[None]]) -> None: """Add a callback to be invoked when an event is received.""" self._on_event_callbacks.append(callback) @@ -60,7 +60,7 @@ async def _on_event(self, message: HeosMessage) -> None: for callback in self._on_event_callbacks: await callback(message) - def add_on_connected(self, callback: Callable[[], Awaitable]) -> None: + def add_on_connected(self, callback: Callable[[], Awaitable[None]]) -> None: """Add a callback to be invoked when connected.""" self._on_connected_callbacks.append(callback) @@ -69,7 +69,7 @@ async def _on_connected(self) -> None: for callback in self._on_connected_callbacks: await callback() - def add_on_disconnected(self, callback: Callable[[bool], Awaitable]) -> None: + def add_on_disconnected(self, callback: Callable[[bool], Awaitable[None]]) -> None: """Add a callback to be invoked when connected.""" self._on_disconnected_callbacks.append(callback) @@ -79,7 +79,7 @@ async def _on_disconnected(self, due_to_error: bool = False) -> None: await callback(due_to_error) def add_on_command_error( - self, callback: Callable[[CommandFailedError], Awaitable] + self, callback: Callable[[CommandFailedError], Awaitable[None]] ) -> None: """Add a callback to be invoked when a command error occurs.""" self._on_command_error_callbacks.append(callback) @@ -89,7 +89,7 @@ async def _on_command_error(self, error: CommandFailedError) -> None: for callback in self._on_command_error_callbacks: await callback(error) - def _register_task(self, future: Coroutine) -> None: + def _register_task(self, future: Awaitable[None]) -> None: """Register a task that is running in the background, so it can be canceled and reset later.""" task = asyncio.ensure_future(future) self._running_tasks.add(task) diff --git a/pyheos/dispatch.py b/pyheos/dispatch.py index 326a645..a5befab 100644 --- a/pyheos/dispatch.py +++ b/pyheos/dispatch.py @@ -5,24 +5,20 @@ import logging from collections import defaultdict from collections.abc import Callable, Sequence -from typing import Any, Final, TypeVar +from typing import Any, Final _LOGGER: Final = logging.getLogger(__name__) TargetType = Callable[..., Any] DisconnectType = Callable[[], None] ConnectType = Callable[[str, TargetType], DisconnectType] -SendType = Callable[..., Sequence[asyncio.Future]] - -TEvent = TypeVar("TEvent", bound=str) -TPlayerId = TypeVar("TPlayerId", bound=int) -TGroupId = TypeVar("TGroupId", bound=int) +SendType = Callable[..., Sequence[asyncio.Future[Any]]] CallbackType = Callable[[], Any] -EventCallbackType = Callable[[TEvent], Any] -ControllerEventCallbackType = Callable[[TEvent, Any], Any] -PlayerEventCallbackType = Callable[[TPlayerId, TEvent], Any] -GroupEventCallbackType = Callable[[TGroupId, TEvent], Any] +EventCallbackType = Callable[[str], Any] +ControllerEventCallbackType = Callable[[str, Any], Any] +PlayerEventCallbackType = Callable[[int, str], Any] +GroupEventCallbackType = Callable[[int, str], Any] def _is_coroutine_function(func: TargetType) -> bool: @@ -81,12 +77,12 @@ def __init__( ) -> None: """Create a new instance of the dispatch component.""" self._signal_prefix = signal_prefix - self._signals: dict[str, list] = defaultdict(list) + self._signals: dict[str, list[TargetType]] = defaultdict(list) self._loop = loop or asyncio.get_running_loop() self._connect = connect or self._default_connect self._send = send or self._default_send - self._disconnects: list[Callable] = [] - self._running_tasks: set[asyncio.Future] = set() + self._disconnects: list[DisconnectType] = [] + self._running_tasks: set[asyncio.Future[Any]] = set() def connect(self, signal: str, target: TargetType) -> DisconnectType: """Connect function to signal. Must be ran in the event loop.""" @@ -94,7 +90,7 @@ def connect(self, signal: str, target: TargetType) -> DisconnectType: self._disconnects.append(disconnect) return disconnect - def send(self, signal: str, *args: Any) -> Sequence[asyncio.Future]: + def send(self, signal: str, *args: Any) -> Sequence[asyncio.Future[Any]]: """Fire a signal. Must be ran in the event loop.""" return self._send(self._signal_prefix + signal, *args) @@ -137,7 +133,7 @@ def remove_dispatcher() -> None: return remove_dispatcher - def _log_target_exception(self, future: asyncio.Future) -> None: + def _log_target_exception(self, future: asyncio.Future[Any]) -> None: """Log the exception from the target, if raised.""" if not future.cancelled() and future.exception(): _LOGGER.exception( @@ -146,7 +142,7 @@ def _log_target_exception(self, future: asyncio.Future) -> None: exc_info=future.exception(), ) - def _default_send(self, signal: str, *args: Any) -> Sequence[asyncio.Future]: + def _default_send(self, signal: str, *args: Any) -> Sequence[asyncio.Future[Any]]: """Fire a signal. Must be ran in the event loop.""" targets = self._signals[signal] futures = [] @@ -158,7 +154,7 @@ def _default_send(self, signal: str, *args: Any) -> Sequence[asyncio.Future]: futures.append(task) return futures - def _call_target(self, target: Callable, *args: Any) -> asyncio.Future: + def _call_target(self, target: TargetType, *args: Any) -> asyncio.Future[Any]: if _is_coroutine_function(target): return self._loop.create_task(target(*args)) return self._loop.run_in_executor(None, target, *args) diff --git a/pyheos/media.py b/pyheos/media.py index 68ee63b..ca7f31a 100644 --- a/pyheos/media.py +++ b/pyheos/media.py @@ -247,7 +247,7 @@ def _from_message( items=list( [ MediaItem.from_data(item, source_id, container_id, heos) - for item in cast(Sequence[dict], message.payload) + for item in cast(Sequence[dict[str, Any]], message.payload) ] ), options=ServiceOption._from_options(message.options), diff --git a/pyheos/system.py b/pyheos/system.py index 00058fe..83ac82e 100644 --- a/pyheos/system.py +++ b/pyheos/system.py @@ -1,6 +1,7 @@ """Define the System module.""" from dataclasses import dataclass, field +from typing import Any from pyheos import command as c from pyheos.types import NetworkType @@ -21,7 +22,7 @@ class HeosHost: network: NetworkType @staticmethod - def _from_data(data: dict[str, str]) -> "HeosHost": + def _from_data(data: dict[str, Any]) -> "HeosHost": """Create a HeosHost object from a dictionary. Args: diff --git a/tests/__init__.py b/tests/__init__.py index ba3d675..3b73ce3 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -68,7 +68,7 @@ def get_value(self, args: dict[str, Any]) -> Any: return arg_value -def calls_commands(*commands: CallCommand) -> Callable: +def calls_commands(*commands: CallCommand) -> Callable[..., Any]: """ Decorator that registers commands prior to test execution. @@ -76,7 +76,7 @@ def calls_commands(*commands: CallCommand) -> Callable: commands: The commands to register. """ - def wrapper(func: Callable) -> Callable: + def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) async def wrapped(*args: Any, **kwargs: Any) -> Any: # Build a list of commands that match the when conditions @@ -118,7 +118,7 @@ async def wrapped(*args: Any, **kwargs: Any) -> Any: ) # Register commands - assert_list: list[Callable] = [] + assert_list: list[Callable[..., None]] = [] for command in matched_commands: # Get the fixture command @@ -174,7 +174,7 @@ def calls_command( when: dict[str, Any] | None = None, replace: bool = False, add_command_under_process: bool = False, -) -> Callable: +) -> Callable[..., Any]: """ Decorator that registers a command prior to test execution. @@ -200,7 +200,7 @@ def calls_command( def calls_player_commands( player_ids: Sequence[int] = (1, 2), *additional: CallCommand -) -> Callable: +) -> Callable[..., Any]: """ Decorator that registers player commands and any optional additional commands. """ @@ -223,7 +223,7 @@ def calls_player_commands( return calls_commands(*commands) -def calls_group_commands(*additional: CallCommand) -> Callable: +def calls_group_commands(*additional: CallCommand) -> Callable[..., Any]: commands = [ CallCommand("group.get_groups"), CallCommand("group.get_volume", {c.ATTR_GROUP_ID: 1}), @@ -456,14 +456,14 @@ def is_match( self.match_count += 1 return True - async def get_response(self, query: dict) -> list[str]: + async def get_response(self, query: dict[str, str]) -> list[str]: """Get the response body.""" responses = [] for fixture in self.responses: responses.append(await self._get_response(fixture, query)) return responses - async def _get_response(self, response: str, query: dict) -> str: + async def _get_response(self, response: str, query: dict[str, str]) -> str: response = await get_fixture(response) keys = { c.ATTR_PLAYER_ID: "{player_id}", diff --git a/tests/conftest.py b/tests/conftest.py index 16385b4..a0e01ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,8 +9,9 @@ from syrupy.assertion import SnapshotAssertion from pyheos.group import HeosGroup -from pyheos.heos import Heos, HeosOptions +from pyheos.heos import Heos from pyheos.media import MediaItem, MediaMusicSource +from pyheos.options import HeosOptions from pyheos.player import HeosPlayer from pyheos.types import LineOutLevelType, NetworkType from tests.common import MediaItems, MediaMusicSources @@ -52,7 +53,7 @@ async def heos_fixture(mock_device: MockHeosDevice) -> AsyncGenerator[Heos]: @pytest.fixture -def handler() -> Callable: +def handler() -> Callable[..., Any]: """Fixture handler to mock in the dispatcher.""" def target(*args: Any, **kwargs: Any) -> None: @@ -65,7 +66,7 @@ def target(*args: Any, **kwargs: Any) -> None: @pytest.fixture -def async_handler() -> Callable[..., Coroutine]: +def async_handler() -> Callable[..., Coroutine[Any, Any, None]]: """Fixture async handler to mock in the dispatcher.""" async def target(*args: Any, **kwargs: Any) -> None: diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index d5ace16..ba88942 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -9,7 +9,7 @@ from pyheos.dispatch import Dispatcher -async def test_connect(handler: Callable) -> None: +async def test_connect(handler: Callable[..., Any]) -> None: """Tests the connect function.""" # Arrange dispatcher = Dispatcher() @@ -19,7 +19,7 @@ async def test_connect(handler: Callable) -> None: assert handler in dispatcher.signals["TEST"] -async def test_disconnect(handler: Callable) -> None: +async def test_disconnect(handler: Callable[..., Any]) -> None: """Tests the disconnect function.""" # Arrange dispatcher = Dispatcher() @@ -30,7 +30,7 @@ async def test_disconnect(handler: Callable) -> None: assert handler not in dispatcher.signals["TEST"] -async def test_disconnect_all(handler: Callable) -> None: +async def test_disconnect_all(handler: Callable[..., Any]) -> None: """Tests the disconnect all function.""" # Arrange dispatcher = Dispatcher() @@ -46,7 +46,7 @@ async def test_disconnect_all(handler: Callable) -> None: assert handler not in dispatcher.signals["TEST3"] -async def test_already_disconnected(handler: Callable) -> None: +async def test_already_disconnected(handler: Callable[..., Any]) -> None: """Tests that disconnect can be called more than once.""" # Arrange dispatcher = Dispatcher() @@ -58,7 +58,7 @@ async def test_already_disconnected(handler: Callable) -> None: assert handler not in dispatcher.signals["TEST"] -async def test_send_async_handler(async_handler: Callable) -> None: +async def test_send_async_handler(async_handler: Callable[..., Any]) -> None: """Tests sending to async handlers.""" # Arrange dispatcher = Dispatcher() @@ -105,7 +105,7 @@ async def async_handler_exception() -> None: assert "Exception in target callback:" not in caplog.text -async def test_send_async_partial_handler(async_handler: Callable) -> None: +async def test_send_async_partial_handler(async_handler: Callable[..., Any]) -> None: """Tests sending to async handlers.""" # Arrange partial = functools.partial(async_handler) @@ -117,7 +117,7 @@ async def test_send_async_partial_handler(async_handler: Callable) -> None: assert async_handler.fired # type: ignore[attr-defined] -async def test_send(handler: Callable) -> None: +async def test_send(handler: Callable[..., Any]) -> None: """Tests sending to async handlers.""" # Arrange dispatcher = Dispatcher() @@ -130,13 +130,13 @@ async def test_send(handler: Callable) -> None: assert handler.args[0] == args # type: ignore[attr-defined] -async def test_custom_connect_and_send(handler: Callable) -> None: +async def test_custom_connect_and_send(handler: Callable[..., Any]) -> None: """Tests using the custom connect and send implementations.""" # Arrange test_signal = "PREFIX_TEST" stored_target = None - def connect(signal: str, target: Callable) -> Callable: + def connect(signal: str, target: Callable[..., Any]) -> Callable[..., Any]: assert signal == test_signal nonlocal stored_target stored_target = target diff --git a/tests/test_heos.py b/tests/test_heos.py index 23f905b..388f8dd 100644 --- a/tests/test_heos.py +++ b/tests/test_heos.py @@ -36,9 +36,10 @@ HeosError, ) from pyheos.group import HeosGroup -from pyheos.heos import Heos, HeosOptions, PlayerUpdateResult +from pyheos.heos import Heos from pyheos.media import MediaItem, MediaMusicSource -from pyheos.player import HeosPlayer +from pyheos.options import HeosOptions +from pyheos.player import HeosPlayer, PlayerUpdateResult from pyheos.types import ( AddCriteriaType, ConnectionState, diff --git a/tests/test_heos_browse.py b/tests/test_heos_browse.py index 0ed2186..c53e0db 100644 --- a/tests/test_heos_browse.py +++ b/tests/test_heos_browse.py @@ -26,8 +26,9 @@ SERVICE_OPTION_THUMBS_DOWN, SERVICE_OPTION_THUMBS_UP, ) -from pyheos.heos import Heos, HeosOptions +from pyheos.heos import Heos from pyheos.media import MediaMusicSource +from pyheos.options import HeosOptions from tests import calls_command, value from tests.common import MediaMusicSources diff --git a/tests/test_heos_callback.py b/tests/test_heos_callback.py index 6bd459c..0324d48 100644 --- a/tests/test_heos_callback.py +++ b/tests/test_heos_callback.py @@ -2,7 +2,8 @@ from typing import Any -from pyheos.heos import Heos, HeosOptions +from pyheos.heos import Heos +from pyheos.options import HeosOptions from pyheos.types import SignalHeosEvent, SignalType