Skip to content

Commit

Permalink
Merge pull request #90 from andrewsayre/use_syrupy
Browse files Browse the repository at this point in the history
Add snapshot testing
  • Loading branch information
andrewsayre authored Jan 24, 2025
2 parents 8a652dc + 79311fe commit bff198c
Show file tree
Hide file tree
Showing 25 changed files with 3,277 additions and 523 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ This class encapsulates the options and configuration for connecting to a HEOS s
#### `pyheos.HeosOptions(host, *, timeout, heart_beat, heart_beat_interval, dispatcher, auto_reconnect, auto_reconnect_delay, auto_reconnect_max_attempts, credentials)`

- `host: str`: A host name or IP address of a HEOS-capable device. This parameter is required.
- `timeout: float`: The timeout in seconds for opening a connectoin and issuing commands to the device. Default is `pyheos.const.DEFAULT_TIMEOUT = 10.0`. This parameter is required.
- `timeout: float`: The timeout in seconds for opening a connection and issuing commands to the device. Default is `pyheos.const.DEFAULT_TIMEOUT = 10.0`. This parameter is required.
- `heart_beat: bool`: Set to `True` to enable heart beat messages, `False` to disable. Used in conjunction with `heart_beat_delay`. The default is `True`.
- `heart_beat_interval: float`: The interval in seconds between heart beat messages. Used in conjunction with `heart_beat`. Default is `pyheos.const.DEFAULT_HEART_BEAT = 10.0`
- `events: bool`: Set to `True` to enable event updates, `False` to disable. The default is `True`.
Expand Down
15 changes: 15 additions & 0 deletions pyheos/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Define abstract base classes for HEOS."""

from abc import ABC
from typing import Any


class RemoveHeosFieldABC(ABC):
"""Define an abstract base class that removes the 'heos' from dataclass's fields list to prevent serialization."""

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
"""Post initialize the player."""
# Prevent the heos instance from being serialized
fields = self.__dataclass_fields__.copy() # type: ignore[has-type] # pylint: disable=access-member-before-definition
del fields["heos"]
self.__dataclass_fields__ = fields
3 changes: 2 additions & 1 deletion pyheos/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional

from pyheos.abc import RemoveHeosFieldABC
from pyheos.const import DEFAULT_STEP, EVENT_GROUP_VOLUME_CHANGED
from pyheos.dispatch import DisconnectType, EventCallbackType, callback_wrapper
from pyheos.message import HeosMessage
Expand All @@ -17,7 +18,7 @@


@dataclass
class HeosGroup:
class HeosGroup(RemoveHeosFieldABC):
"""A group of players."""

name: str
Expand Down
2 changes: 1 addition & 1 deletion pyheos/heos.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def create_and_connect(cls, host: str, **kwargs: Any) -> "Heos":
Args:
host: A host name or IP address of a HEOS-capable device.
timeout: The timeout in seconds for opening a connectoin and issuing commands to the device.
timeout: The timeout in seconds for opening a connection and issuing commands to the device.
events: Set to True to enable event updates, False to disable. The default is True.
all_progress_events: Set to True to receive media progress events, False to only receive media changed events. The default is True.
dispatcher: The dispatcher instance to use for event callbacks. If not provided, an internally created instance will be used.
Expand Down
33 changes: 3 additions & 30 deletions pyheos/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, Optional, cast

from pyheos import command as c
from pyheos.abc import RemoveHeosFieldABC
from pyheos.message import HeosMessage
from pyheos.types import AddCriteriaType, MediaType

Expand Down Expand Up @@ -39,7 +40,7 @@ def from_data(cls, data: dict[str, str]) -> "QueueItem":


@dataclass(init=False)
class Media:
class Media(RemoveHeosFieldABC):
"""
Define a base media item.
Expand Down Expand Up @@ -90,18 +91,6 @@ def _update_from_data(self, data: dict[str, Any]) -> None:
self.available = data[c.ATTR_AVAILABLE] == c.VALUE_TRUE
self.service_username = data.get(c.ATTR_SERVICE_USER_NAME)

def clone(self) -> "MediaMusicSource":
"""Create a new instance from the current instance."""
return MediaMusicSource(
source_id=self.source_id,
name=self.name,
type=self.type,
image_url=self.image_url,
available=self.available,
service_username=self.service_username,
heos=self.heos,
)

async def refresh(self) -> None:
"""Refresh the instance with the latest data."""
assert self.heos, "Heos instance not set"
Expand Down Expand Up @@ -162,22 +151,6 @@ def from_data(
heos=heos,
)

def clone(self) -> "MediaItem":
return MediaItem(
source_id=self.source_id,
name=self.name,
type=self.type,
image_url=self.image_url,
playable=self.playable,
browsable=self.browsable,
container_id=self.container_id,
media_id=self.media_id,
artist=self.artist,
album=self.album,
album_id=self.album_id,
heos=self.heos,
)

async def browse(
self,
range_start: int | None = None,
Expand Down Expand Up @@ -246,7 +219,7 @@ def __from_data(context: str, data: dict[str, str]) -> "ServiceOption":


@dataclass
class BrowseResult:
class BrowseResult(RemoveHeosFieldABC):
"""Define the result of a browse operation."""

count: int
Expand Down
4 changes: 2 additions & 2 deletions pyheos/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class HeosCommand:
parameters: dict[str, Any] = field(default_factory=dict)

def __repr__(self) -> str:
"""Get a string representaton of the message."""
"""Get a string representation of the message."""
return self.uri_masked

@cached_property
Expand Down Expand Up @@ -79,7 +79,7 @@ class HeosMessage:
)

def __repr__(self) -> str:
"""Get a string representaton of the message."""
"""Get a string representation of the message."""
return self._raw_message or f"{self.command} {self.message}"

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion pyheos/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class HeosOptions:
Args:
host: A host name or IP address of a HEOS-capable device.
timeout: The timeout in seconds for opening a connectoin and issuing commands to the device.
timeout: The timeout in seconds for opening a connection and issuing commands to the device.
events: Set to True to enable event updates, False to disable. The default is True.
heart_beat: Set to True to enable heart beat messages, False to disable. Used in conjunction with heart_beat_delay. The default is True.
heart_beat_interval: The interval in seconds between heart beat messages. Used in conjunction with heart_beat.
Expand Down
3 changes: 2 additions & 1 deletion pyheos/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Final, Optional, cast

from pyheos.abc import RemoveHeosFieldABC
from pyheos.command import optional_int, parse_enum
from pyheos.dispatch import DisconnectType, EventCallbackType, callback_wrapper
from pyheos.media import MediaItem, QueueItem, ServiceOption
Expand Down Expand Up @@ -184,7 +185,7 @@ def _from_data(data: HeosMessage) -> "PlayMode":


@dataclass
class HeosPlayer:
class HeosPlayer(RemoveHeosFieldABC):
"""Define a HEOS player."""

name: str = field(repr=True, hash=False, compare=False)
Expand Down
11 changes: 6 additions & 5 deletions pyheos/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Final, Optional, cast

from pyheos import command as c
from pyheos.abc import RemoveHeosFieldABC
from pyheos.media import MediaItem
from pyheos.message import HeosMessage

Expand Down Expand Up @@ -38,7 +39,7 @@ def _from_data(data: dict[str, str]) -> "SearchCriteria":


@dataclass
class SearchResult:
class SearchResult(RemoveHeosFieldABC):
"""Define the search result."""

source_id: int
Expand Down Expand Up @@ -71,7 +72,7 @@ def _from_message(message: HeosMessage, heos: "Heos") -> "SearchResult":


@dataclass
class MultiSearchResult:
class MultiSearchResult(RemoveHeosFieldABC):
"""Define the results of a multi-search."""

source_ids: Sequence[int]
Expand All @@ -91,14 +92,14 @@ def _from_message(message: HeosMessage, heos: "Heos") -> "MultiSearchResult":
"""Create a new instance from a message."""
source_ids = message.get_message_value(c.ATTR_SOURCE_ID).split(",")
criteria_ids = message.get_message_value(c.ATTR_SEARCH_CRITERIA_ID).split(",")
statisics = SearchStatistic._from_string(
statistics = SearchStatistic._from_string(
message.get_message_value(c.ATTR_STATS)
)
items: list[MediaItem] = []
# In order to determine the source_id of the result, we match up the index with how many items were returned for a given source
payload = cast(list[dict[str, str]], message.payload)
index = 0
for stat in statisics:
for stat in statistics:
assert stat.returned is not None
for _ in range(stat.returned):
items.append(
Expand All @@ -114,7 +115,7 @@ def _from_message(message: HeosMessage, heos: "Heos") -> "MultiSearchResult":
returned=message.get_message_value_int(c.ATTR_RETURNED),
count=message.get_message_value_int(c.ATTR_COUNT),
items=items,
statistics=statisics,
statistics=statistics,
errors=SearchStatistic._from_string(
message.get_message_value(c.ATTR_ERROR_NUMBER)
),
Expand Down
29 changes: 12 additions & 17 deletions pyheos/system.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Define the System module."""

from dataclasses import dataclass
from functools import cached_property
from dataclasses import dataclass, field

from pyheos import command as c
from pyheos.types import NetworkType
Expand Down Expand Up @@ -51,18 +50,14 @@ class HeosSystem:
signed_in_username: str | None
host: HeosHost
hosts: list[HeosHost]

@property
def is_signed_in(self) -> bool:
"""Return whether the system is signed in."""
return self.signed_in_username is not None

@cached_property
def preferred_hosts(self) -> list[HeosHost]:
"""Return the preferred hosts."""
return list([host for host in self.hosts if host.network == NetworkType.WIRED])

@cached_property
def connected_to_preferred_host(self) -> bool:
"""Return whether the system is connected to a host."""
return self.host in self.preferred_hosts
is_signed_in: bool = field(init=False)
preferred_hosts: list[HeosHost] = field(init=False)
connected_to_preferred_host: bool = field(init=False)

def __post_init__(self) -> None:
"""Post initialize the system."""
self.is_signed_in = self.signed_in_username is not None
self.preferred_hosts = list(
[host for host in self.hosts if host.network == NetworkType.WIRED]
)
self.connected_to_preferred_host = self.host in self.preferred_hosts
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -367,4 +367,4 @@ skip_empty = true
sort = "Name"

[tool.codespell]
skip = "./tests/fixtures/*"
skip = "./tests/fixtures/*,./tests/snapshots/*"
13 changes: 7 additions & 6 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
codespell==2.3.0
codespell==2.4.0
coveralls==4.0.1
mypy-dev==1.15.0a1
pydantic==2.10.4
mypy==1.14.1
pydantic==2.10.5
pylint==3.3.3
pylint-per-file-ignores==1.3.2
pylint-per-file-ignores==1.4.0
pytest==8.3.4
pytest-asyncio==0.25.1
pytest-asyncio==0.25.2
pytest-cov==6.0.0
pytest-timeout==2.3.1
ruff==0.8.6
ruff==0.9.3
syrupy==4.8.1
12 changes: 6 additions & 6 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ def assert_command_called(
if matcher.is_match(target_command, target_args, increment=False):
matcher.assert_called()
return
assert (
False
), f"Command was not registered: {target_command} with args {target_args}."
assert False, (
f"Command was not registered: {target_command} with args {target_args}."
)

async def _handle_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
Expand Down Expand Up @@ -483,9 +483,9 @@ async def _get_response(self, response: str, query: dict) -> str:

def assert_called(self) -> None:
"""Assert that the command was called."""
assert (
self.match_count
), f"Command {self.command} was not called with arguments {self._args}."
assert self.match_count, (
f"Command {self.command} was not called with arguments {self._args}."
)


class ConnectionLog:
Expand Down
Loading

0 comments on commit bff198c

Please sign in to comment.