Skip to content

Commit

Permalink
Merge pull request #94 from andrewsayre/strict_mypy
Browse files Browse the repository at this point in the history
Enable mypy strict typing
  • Loading branch information
andrewsayre authored Jan 26, 2025
2 parents f74b2d1 + 78514b7 commit e4cd44d
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
rev: v1.14.1
hooks:
- id: mypy
args: []
args: [--strict]
additional_dependencies:
- pydantic==2.10.5
- pylint==3.3.3
Expand Down
2 changes: 1 addition & 1 deletion pyheos/command/browse.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def get_music_sources(
HeosCommand(c.COMMAND_BROWSE_GET_SOURCES, params)
)
self._music_sources.clear()
for data in cast(Sequence[dict], message.payload):
for data in cast(Sequence[dict[str, Any]], message.payload):
source = MediaMusicSource.from_data(data, cast("Heos", self))
self._music_sources[source.source_id] = source
self._music_sources_loaded = True
Expand Down
4 changes: 2 additions & 2 deletions pyheos/command/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ async def get_groups(self, *, refresh: bool = False) -> dict[int, HeosGroup]:
References:
4.3.1 Get Groups"""
if not self._groups_loaded or refresh:
groups = {}
groups: dict[int, HeosGroup] = {}
result = await self._connection.command(HeosCommand(c.COMMAND_GET_GROUPS))
payload = cast(Sequence[dict], result.payload)
payload = cast(Sequence[dict[str, Any]], result.payload)
for data in payload:
group = HeosGroup._from_data(data, cast("Heos", self))
groups[group.group_id] = group
Expand Down
14 changes: 4 additions & 10 deletions pyheos/command/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,8 @@
from pyheos.command.connection import ConnectionMixin
from pyheos.media import QueueItem
from pyheos.message import HeosCommand
from pyheos.player import (
HeosNowPlayingMedia,
HeosPlayer,
PlayerUpdateResult,
PlayMode,
PlayState,
)
from pyheos.types import RepeatType
from pyheos.player import HeosNowPlayingMedia, HeosPlayer, PlayerUpdateResult, PlayMode
from pyheos.types import PlayState, RepeatType

if TYPE_CHECKING:
from pyheos.heos import Heos
Expand Down Expand Up @@ -102,7 +96,7 @@ async def load_players(self) -> PlayerUpdateResult:

players: dict[int, HeosPlayer] = {}
response = await self._connection.command(HeosCommand(c.COMMAND_GET_PLAYERS))
payload = cast(Sequence[dict], response.payload)
payload = cast(Sequence[dict[str, str]], response.payload)
existing = list(self._players.values())
for player_data in payload:
player_id = int(player_data[c.ATTR_PLAYER_ID])
Expand Down Expand Up @@ -463,7 +457,7 @@ async def player_get_quick_selects(self, player_id: int) -> dict[int, str]:
)
return {
int(data[c.ATTR_ID]): data[c.ATTR_NAME]
for data in cast(list[dict], result.payload)
for data in cast(list[dict[str, Any]], result.payload)
}

async def player_check_update(self, player_id: int) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion pyheos/command/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def get_system_info(self) -> HeosSystem:
References:
4.2.1 Get Players"""
response = await self._connection.command(HeosCommand(c.COMMAND_GET_PLAYERS))
payload = cast(Sequence[dict], response.payload)
payload = cast(Sequence[dict[str, Any]], response.payload)
hosts = list([HeosHost._from_data(item) for item in payload])
host = next(host for host in hosts if host.ip_address == self._options.host)
return HeosSystem(self._signed_in_username, host, hosts)
22 changes: 11 additions & 11 deletions pyheos/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import logging
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Awaitable, Callable
from contextlib import suppress
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Final
Expand Down Expand Up @@ -35,23 +35,23 @@ def __init__(self, host: str, *, timeout: float) -> None:
self._state: ConnectionState = ConnectionState.DISCONNECTED
self._writer: asyncio.StreamWriter | None = None
self._pending_command_event = ResponseEvent()
self._running_tasks: set[asyncio.Task] = set()
self._running_tasks: set[asyncio.Task[None]] = set()
self._last_activity: datetime = datetime.now()
self._command_lock = asyncio.Lock()

self._on_event_callbacks: list[Callable[[HeosMessage], Awaitable]] = []
self._on_connected_callbacks: list[Callable[[], Awaitable]] = []
self._on_disconnected_callbacks: list[Callable[[bool], Awaitable]] = []
self._on_event_callbacks: list[Callable[[HeosMessage], Awaitable[None]]] = []
self._on_connected_callbacks: list[Callable[[], Awaitable[None]]] = []
self._on_disconnected_callbacks: list[Callable[[bool], Awaitable[None]]] = []
self._on_command_error_callbacks: list[
Callable[[CommandFailedError], Awaitable]
Callable[[CommandFailedError], Awaitable[None]]
] = []

@property
def state(self) -> ConnectionState:
"""Get the current state of the connection."""
return self._state

def add_on_event(self, callback: Callable[[HeosMessage], Awaitable]) -> None:
def add_on_event(self, callback: Callable[[HeosMessage], Awaitable[None]]) -> None:
"""Add a callback to be invoked when an event is received."""
self._on_event_callbacks.append(callback)

Expand All @@ -60,7 +60,7 @@ async def _on_event(self, message: HeosMessage) -> None:
for callback in self._on_event_callbacks:
await callback(message)

def add_on_connected(self, callback: Callable[[], Awaitable]) -> None:
def add_on_connected(self, callback: Callable[[], Awaitable[None]]) -> None:
"""Add a callback to be invoked when connected."""
self._on_connected_callbacks.append(callback)

Expand All @@ -69,7 +69,7 @@ async def _on_connected(self) -> None:
for callback in self._on_connected_callbacks:
await callback()

def add_on_disconnected(self, callback: Callable[[bool], Awaitable]) -> None:
def add_on_disconnected(self, callback: Callable[[bool], Awaitable[None]]) -> None:
"""Add a callback to be invoked when connected."""
self._on_disconnected_callbacks.append(callback)

Expand All @@ -79,7 +79,7 @@ async def _on_disconnected(self, due_to_error: bool = False) -> None:
await callback(due_to_error)

def add_on_command_error(
self, callback: Callable[[CommandFailedError], Awaitable]
self, callback: Callable[[CommandFailedError], Awaitable[None]]
) -> None:
"""Add a callback to be invoked when a command error occurs."""
self._on_command_error_callbacks.append(callback)
Expand All @@ -89,7 +89,7 @@ async def _on_command_error(self, error: CommandFailedError) -> None:
for callback in self._on_command_error_callbacks:
await callback(error)

def _register_task(self, future: Coroutine) -> None:
def _register_task(self, future: Awaitable[None]) -> None:
"""Register a task that is running in the background, so it can be canceled and reset later."""
task = asyncio.ensure_future(future)
self._running_tasks.add(task)
Expand Down
30 changes: 13 additions & 17 deletions pyheos/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,20 @@
import logging
from collections import defaultdict
from collections.abc import Callable, Sequence
from typing import Any, Final, TypeVar
from typing import Any, Final

_LOGGER: Final = logging.getLogger(__name__)

TargetType = Callable[..., Any]
DisconnectType = Callable[[], None]
ConnectType = Callable[[str, TargetType], DisconnectType]
SendType = Callable[..., Sequence[asyncio.Future]]

TEvent = TypeVar("TEvent", bound=str)
TPlayerId = TypeVar("TPlayerId", bound=int)
TGroupId = TypeVar("TGroupId", bound=int)
SendType = Callable[..., Sequence[asyncio.Future[Any]]]

CallbackType = Callable[[], Any]
EventCallbackType = Callable[[TEvent], Any]
ControllerEventCallbackType = Callable[[TEvent, Any], Any]
PlayerEventCallbackType = Callable[[TPlayerId, TEvent], Any]
GroupEventCallbackType = Callable[[TGroupId, TEvent], Any]
EventCallbackType = Callable[[str], Any]
ControllerEventCallbackType = Callable[[str, Any], Any]
PlayerEventCallbackType = Callable[[int, str], Any]
GroupEventCallbackType = Callable[[int, str], Any]


def _is_coroutine_function(func: TargetType) -> bool:
Expand Down Expand Up @@ -81,20 +77,20 @@ def __init__(
) -> None:
"""Create a new instance of the dispatch component."""
self._signal_prefix = signal_prefix
self._signals: dict[str, list] = defaultdict(list)
self._signals: dict[str, list[TargetType]] = defaultdict(list)
self._loop = loop or asyncio.get_running_loop()
self._connect = connect or self._default_connect
self._send = send or self._default_send
self._disconnects: list[Callable] = []
self._running_tasks: set[asyncio.Future] = set()
self._disconnects: list[DisconnectType] = []
self._running_tasks: set[asyncio.Future[Any]] = set()

def connect(self, signal: str, target: TargetType) -> DisconnectType:
"""Connect function to signal. Must be ran in the event loop."""
disconnect = self._connect(self._signal_prefix + signal, target)
self._disconnects.append(disconnect)
return disconnect

def send(self, signal: str, *args: Any) -> Sequence[asyncio.Future]:
def send(self, signal: str, *args: Any) -> Sequence[asyncio.Future[Any]]:
"""Fire a signal. Must be ran in the event loop."""
return self._send(self._signal_prefix + signal, *args)

Expand Down Expand Up @@ -137,7 +133,7 @@ def remove_dispatcher() -> None:

return remove_dispatcher

def _log_target_exception(self, future: asyncio.Future) -> None:
def _log_target_exception(self, future: asyncio.Future[Any]) -> None:
"""Log the exception from the target, if raised."""
if not future.cancelled() and future.exception():
_LOGGER.exception(
Expand All @@ -146,7 +142,7 @@ def _log_target_exception(self, future: asyncio.Future) -> None:
exc_info=future.exception(),
)

def _default_send(self, signal: str, *args: Any) -> Sequence[asyncio.Future]:
def _default_send(self, signal: str, *args: Any) -> Sequence[asyncio.Future[Any]]:
"""Fire a signal. Must be ran in the event loop."""
targets = self._signals[signal]
futures = []
Expand All @@ -158,7 +154,7 @@ def _default_send(self, signal: str, *args: Any) -> Sequence[asyncio.Future]:
futures.append(task)
return futures

def _call_target(self, target: Callable, *args: Any) -> asyncio.Future:
def _call_target(self, target: TargetType, *args: Any) -> asyncio.Future[Any]:
if _is_coroutine_function(target):
return self._loop.create_task(target(*args))
return self._loop.run_in_executor(None, target, *args)
Expand Down
2 changes: 1 addition & 1 deletion pyheos/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _from_message(
items=list(
[
MediaItem.from_data(item, source_id, container_id, heos)
for item in cast(Sequence[dict], message.payload)
for item in cast(Sequence[dict[str, Any]], message.payload)
]
),
options=ServiceOption._from_options(message.options),
Expand Down
3 changes: 2 additions & 1 deletion pyheos/system.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Define the System module."""

from dataclasses import dataclass, field
from typing import Any

from pyheos import command as c
from pyheos.types import NetworkType
Expand All @@ -21,7 +22,7 @@ class HeosHost:
network: NetworkType

@staticmethod
def _from_data(data: dict[str, str]) -> "HeosHost":
def _from_data(data: dict[str, Any]) -> "HeosHost":
"""Create a HeosHost object from a dictionary.
Args:
Expand Down
16 changes: 8 additions & 8 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def get_value(self, args: dict[str, Any]) -> Any:
return arg_value


def calls_commands(*commands: CallCommand) -> Callable:
def calls_commands(*commands: CallCommand) -> Callable[..., Any]:
"""
Decorator that registers commands prior to test execution.
Args:
commands: The commands to register.
"""

def wrapper(func: Callable) -> Callable:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(func)
async def wrapped(*args: Any, **kwargs: Any) -> Any:
# Build a list of commands that match the when conditions
Expand Down Expand Up @@ -118,7 +118,7 @@ async def wrapped(*args: Any, **kwargs: Any) -> Any:
)

# Register commands
assert_list: list[Callable] = []
assert_list: list[Callable[..., None]] = []

for command in matched_commands:
# Get the fixture command
Expand Down Expand Up @@ -174,7 +174,7 @@ def calls_command(
when: dict[str, Any] | None = None,
replace: bool = False,
add_command_under_process: bool = False,
) -> Callable:
) -> Callable[..., Any]:
"""
Decorator that registers a command prior to test execution.
Expand All @@ -200,7 +200,7 @@ def calls_command(

def calls_player_commands(
player_ids: Sequence[int] = (1, 2), *additional: CallCommand
) -> Callable:
) -> Callable[..., Any]:
"""
Decorator that registers player commands and any optional additional commands.
"""
Expand All @@ -223,7 +223,7 @@ def calls_player_commands(
return calls_commands(*commands)


def calls_group_commands(*additional: CallCommand) -> Callable:
def calls_group_commands(*additional: CallCommand) -> Callable[..., Any]:
commands = [
CallCommand("group.get_groups"),
CallCommand("group.get_volume", {c.ATTR_GROUP_ID: 1}),
Expand Down Expand Up @@ -456,14 +456,14 @@ def is_match(
self.match_count += 1
return True

async def get_response(self, query: dict) -> list[str]:
async def get_response(self, query: dict[str, str]) -> list[str]:
"""Get the response body."""
responses = []
for fixture in self.responses:
responses.append(await self._get_response(fixture, query))
return responses

async def _get_response(self, response: str, query: dict) -> str:
async def _get_response(self, response: str, query: dict[str, str]) -> str:
response = await get_fixture(response)
keys = {
c.ATTR_PLAYER_ID: "{player_id}",
Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from syrupy.assertion import SnapshotAssertion

from pyheos.group import HeosGroup
from pyheos.heos import Heos, HeosOptions
from pyheos.heos import Heos
from pyheos.media import MediaItem, MediaMusicSource
from pyheos.options import HeosOptions
from pyheos.player import HeosPlayer
from pyheos.types import LineOutLevelType, NetworkType
from tests.common import MediaItems, MediaMusicSources
Expand Down Expand Up @@ -52,7 +53,7 @@ async def heos_fixture(mock_device: MockHeosDevice) -> AsyncGenerator[Heos]:


@pytest.fixture
def handler() -> Callable:
def handler() -> Callable[..., Any]:
"""Fixture handler to mock in the dispatcher."""

def target(*args: Any, **kwargs: Any) -> None:
Expand All @@ -65,7 +66,7 @@ def target(*args: Any, **kwargs: Any) -> None:


@pytest.fixture
def async_handler() -> Callable[..., Coroutine]:
def async_handler() -> Callable[..., Coroutine[Any, Any, None]]:
"""Fixture async handler to mock in the dispatcher."""

async def target(*args: Any, **kwargs: Any) -> None:
Expand Down
Loading

0 comments on commit e4cd44d

Please sign in to comment.