From 161e29506a84c65a35920621fb336d9b24d939e7 Mon Sep 17 00:00:00 2001 From: Jorim Tielemans Date: Tue, 28 Jan 2025 01:32:46 +0100 Subject: [PATCH] Enhance iRail class documentation and refactor code for clarity (#45) --- pyrail/irail.py | 51 ++++---- pyrail/models.py | 231 +++++++++++++++------------------- tests/test_irail.py | 298 +++++++++++++++++++++----------------------- 3 files changed, 270 insertions(+), 310 deletions(-) diff --git a/pyrail/irail.py b/pyrail/irail.py index 4ac055f..13fbdc9 100644 --- a/pyrail/irail.py +++ b/pyrail/irail.py @@ -6,7 +6,7 @@ import logging import time from types import TracebackType -from typing import Any, Dict, Type +from typing import Any, Type from aiohttp import ClientError, ClientResponse, ClientSession @@ -45,7 +45,7 @@ class iRail: # Available iRail API endpoints and their parameter requirements. # Each endpoint is configured with required parameters, optional parameters, and XOR # parameter groups (where exactly one parameter from the group must be provided). - endpoints: Dict[str, Dict[str, Any]] = { + endpoints: dict[str, dict[str, Any]] = { "stations": {}, "liveboard": {"xor": ["station", "id"], "optional": ["date", "time", "arrdep", "alerts"]}, "connections": { @@ -71,8 +71,8 @@ def __init__(self, lang: str = "en", session: ClientSession | None = None) -> No self.last_request_time: float = time.time() self.lock: Lock = Lock() self.session: ClientSession | None = session - self._owns_session = session is None # Track ownership - self.etag_cache: Dict[str, str] = {} + self._owns_session = session is None # Track ownership + self.etag_cache: dict[str, str] = {} logger.info("iRail instance created") async def __aenter__(self) -> "iRail": @@ -130,11 +130,10 @@ def clear_etag_cache(self) -> None: logger.info("ETag cache cleared") def _refill_tokens(self) -> None: - """Refill tokens for rate limiting based on elapsed time. - - This method refills both standard tokens (max 3) and burst tokens (max 5) - using a token bucket algorithm. The refill rate is 3 tokens per second. + """Refill tokens for rate limiting using a token bucket algorithm. + - Standard tokens: Refill rate of 3 tokens/second, max 3 tokens. + - Burst tokens: Refilled only when standard tokens are full, max 5 tokens. """ logger.debug("Refilling tokens") current_time: float = time.time() @@ -151,9 +150,9 @@ def _refill_tokens(self) -> None: async def _handle_rate_limit(self) -> None: """Handle rate limiting using a token bucket algorithm. - The implementation uses two buckets: - - Normal bucket: 3 tokens/second - - Burst bucket: 5 tokens/second + - Standard tokens: 3 requests/second. + - Burst tokens: Additional 5 requests/second for spikes. + - Waits and refills tokens if both are exhausted. """ logger.debug("Handling rate limit") self._refill_tokens() @@ -169,18 +168,18 @@ async def _handle_rate_limit(self) -> None: else: self.tokens -= 1 - def _add_etag_header(self, method: str) -> Dict[str, str]: + def _add_etag_header(self, method: str) -> dict[str, str]: """Add ETag header for the given method if a cached ETag is available. Args: method (str): The API endpoint for which the header is being generated. Returns: - Dict[str, str]: A dictionary containing HTTP headers, including the ETag header + dict[str, str]: A dictionary containing HTTP headers, including the ETag header if a cached value exists. """ - headers: Dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; tielemans.jorim@gmail.com)"} + headers: dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; tielemans.jorim@gmail.com)"} if method in self.etag_cache: logger.debug("Adding If-None-Match header with value: %s", self.etag_cache[method]) headers["If-None-Match"] = self.etag_cache[method] @@ -226,12 +225,12 @@ def _validate_time(self, time: str | None) -> bool: logger.error("Invalid time format. Expected HHMM (e.g., 1430 for 2:30 PM), got: %s", time) return False - def _validate_params(self, method: str, params: Dict[str, Any] | None = None) -> bool: + def _validate_params(self, method: str, params: dict[str, Any] | None = None) -> bool: """Validate parameters for a specific iRail API endpoint based on predefined requirements. Args: method (str): The API endpoint method to validate parameters for. - params (Dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None. + params (dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None. Returns: bool: True if parameters are valid, False otherwise. @@ -289,12 +288,12 @@ def _validate_params(self, method: str, params: Dict[str, Any] | None = None) -> return True - async def _handle_success_response(self, response: ClientResponse, method: str) -> Dict[str, Any] | None: + async def _handle_success_response(self, response: ClientResponse, method: str) -> dict[str, Any] | None: """Handle a successful API response.""" if "Etag" in response.headers: self.etag_cache[method] = response.headers["Etag"] try: - json_data: Dict[str, Any] | None = await response.json() + json_data: dict[str, Any] | None = await response.json() if not json_data: logger.warning("Empty response received") return json_data @@ -303,8 +302,8 @@ async def _handle_success_response(self, response: ClientResponse, method: str) return None async def _handle_response( - self, response: ClientResponse, method: str, args: Dict[str, Any] | None = None - ) -> Dict[str, Any] | None: + self, response: ClientResponse, method: str, args: dict[str, Any] | None = None + ) -> dict[str, Any] | None: """Handle the API response based on status code.""" if response.status == 429: retry_after: int = int(response.headers.get("Retry-After", 1)) @@ -326,7 +325,7 @@ async def _handle_response( logger.error("Request failed with status code: %s, response: %s", response.status, await response.text()) return None - async def _do_request(self, method: str, args: Dict[str, Any] | None = None) -> Dict[str, Any] | None: + async def _do_request(self, method: str, args: dict[str, Any] | None = None) -> dict[str, Any] | None: """Send an asynchronous request to the specified iRail API endpoint. This method handles API requests with rate limiting, parameter validation, @@ -368,7 +367,7 @@ async def _do_request(self, method: str, args: Dict[str, Any] | None = None) -> if args: params.update(args) - request_headers: Dict[str, str] = self._add_etag_header(method) + request_headers: dict[str, str] = self._add_etag_header(method) try: async with self.session.get(url, params=params, headers=request_headers) as response: @@ -429,7 +428,7 @@ async def get_liveboard( print(f"Liveboard for Brussels-South: {liveboard}") """ - extra_params: Dict[str, Any] = { + extra_params: dict[str, Any] = { "station": station, "id": id, "date": date, @@ -473,7 +472,7 @@ async def get_connections( print(f"Connections from Antwerpen-Centraal to Brussel-Centraal: {connections}") """ - extra_params: Dict[str, Any] = { + extra_params: dict[str, Any] = { "from": from_station, "to": to_station, "date": date, @@ -504,7 +503,7 @@ async def get_vehicle(self, id: str, date: str | None = None, alerts: bool = Fal vehicle_info = await client.get_vehicle("BE.NMBS.IC1832") """ - extra_params: Dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"} + extra_params: dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"} vehicle_response_dict = await self._do_request( "vehicle", {k: v for k, v in extra_params.items() if v is not None} ) @@ -527,7 +526,7 @@ async def get_composition(self, id: str, data: str | None = None) -> Composition composition = await client.get_composition('S51507') """ - extra_params: Dict[str, str | None] = {"id": id, "data": data} + extra_params: dict[str, str | None] = {"id": id, "data": data} composition_response_dict = await self._do_request( "composition", {k: v for k, v in extra_params.items() if v is not None} ) diff --git a/pyrail/models.py b/pyrail/models.py index f4b57af..5460117 100644 --- a/pyrail/models.py +++ b/pyrail/models.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import List +from typing import Any from mashumaro import field_options from mashumaro.mixins.orjson import DataClassORJSONMixin @@ -19,13 +19,6 @@ def _str_to_bool(strbool: str) -> bool: return strbool == "1" -class Orientation(Enum): - """Enum for the orientation of the material type of a train unit, either 'LEFT' or 'RIGHT'.""" - - LEFT = "LEFT" - RIGHT = "RIGHT" - - class DisturbanceType(Enum): """Enum for the type of disturbance, either 'disturbance' or 'planned'.""" @@ -33,14 +26,28 @@ class DisturbanceType(Enum): PLANNED = "planned" +class OccupancyName(Enum): + """Enum for the occupancy, either 'low', 'medium', 'high', or 'unknown'.""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + UNKNOWN = "unknown" + + +class Orientation(Enum): + """Enum for the orientation of the material type of a train unit, either 'LEFT' or 'RIGHT'.""" + + LEFT = "LEFT" + RIGHT = "RIGHT" + + @dataclass class ApiResponse(DataClassORJSONMixin): """Base class for API responses, including schema version and timestamp.""" version: str # Version of the response schema - timestamp: datetime = field( - metadata=field_options(deserialize=_timestamp_to_datetime) - ) # Timestamp of the response + timestamp: datetime = field(metadata=field_options(deserialize=_timestamp_to_datetime)) # Timestamp of the response @dataclass @@ -59,7 +66,7 @@ class StationDetails(DataClassORJSONMixin): class StationsApiResponse(ApiResponse): """Holds a list of station objects returned by the 'stations' endpoint.""" - stations: List[StationDetails] = field( + stations: list[StationDetails] = field( metadata=field_options(alias="station"), default_factory=list ) # List of stations information @@ -90,7 +97,7 @@ class Occupancy(DataClassORJSONMixin): """Represents occupancy details for a specific departure.""" at_id: str = field(metadata=field_options(alias="@id")) # Identifier for the occupancy level - name: str # Occupancy level (e.g., low, high) + name: OccupancyName # Occupancy level (e.g., low, high) @dataclass @@ -100,9 +107,7 @@ class LiveboardDeparture(DataClassORJSONMixin): id: str # ID of the departure station: str # Station name station_info: StationDetails = field(metadata=field_options(alias="stationinfo")) # Detailed station info - time: datetime = field( - metadata=field_options(deserialize=_timestamp_to_datetime) - ) # Departure time (timestamp) + time: datetime = field(metadata=field_options(deserialize=_timestamp_to_datetime)) # Departure time (timestamp) delay: int # Delay in seconds canceled: bool = field(metadata=field_options(deserialize=_str_to_bool)) # Whether the departure is canceled left: bool = field(metadata=field_options(deserialize=_str_to_bool)) # Whether the train has left @@ -117,14 +122,6 @@ class LiveboardDeparture(DataClassORJSONMixin): departure_connection: str = field(metadata=field_options(alias="departureConnection")) # Departure connection link -@dataclass -class LiveboardDepartures(DataClassORJSONMixin): - """Holds the number of departures and a list of detailed departure information.""" - - number: int # Number of departures - departure: List[LiveboardDeparture] = field(default_factory=list) # List of departure details - - @dataclass class LiveboardArrival(DataClassORJSONMixin): """Details of a single arrival in the liveboard response.""" @@ -132,9 +129,7 @@ class LiveboardArrival(DataClassORJSONMixin): id: str # ID of the arrival station: str # Station name station_info: StationDetails = field(metadata=field_options(alias="stationinfo")) # Detailed station info - time: datetime = field( - metadata=field_options(deserialize=_timestamp_to_datetime) - ) # Arrival time (timestamp) + time: datetime = field(metadata=field_options(deserialize=_timestamp_to_datetime)) # Arrival time (timestamp) delay: int # Delay in seconds canceled: bool # Whether the arrival is canceled arrived: bool # Whether the train has arrived @@ -147,23 +142,30 @@ class LiveboardArrival(DataClassORJSONMixin): @dataclass -class LiveboardArrivals(DataClassORJSONMixin): - """Holds the number of arrivals and a list of detailed arrival information.""" +class LiveboardApiResponse(ApiResponse): + """Represents a liveboard response containing station details and departure/arrival information.""" - number: int # Number of arrivals - arrival: List[LiveboardArrival] = field(default_factory=list) # List of arrival details + station: str # Name of the station + station_info: StationDetails = field(metadata=field_options(alias="stationinfo")) # Detailed station info + departures: list[LiveboardDeparture] | None = field(default=None) # Departures information + arrivals: list[LiveboardArrival] | None = field(default=None) # Arrivals information + @classmethod + def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]: + """Pre-deserialization hook to safely flatten departures and arrivals.""" + # Safely flatten departures + if "departures" in d and d["departures"] is not None: + d["departures"] = d["departures"].get("departure", []) + else: + d["departures"] = None # Set to None if missing or null -@dataclass -class LiveboardApiResponse(ApiResponse): - """Represents a liveboard response containing station details and departures.""" + # Safely flatten arrivals + if "arrivals" in d and d["arrivals"] is not None: + d["arrivals"] = d["arrivals"].get("arrival", []) + else: + d["arrivals"] = None # Set to None if missing or null - station: str # Name of the station - station_info: StationDetails = field( - metadata=field_options(alias="stationinfo") - ) # Detailed station info - departures: LiveboardDepartures | None = field(default=None) # Departures information - arrivals: LiveboardArrivals | None = field(default=None) # Arrivals information + return d @dataclass @@ -196,14 +198,6 @@ class ConnectionStop(DataClassORJSONMixin): platform_info: PlatformInfo = field(metadata=field_options(alias="platforminfo")) # Detailed platform info -@dataclass -class ConnectionStops(DataClassORJSONMixin): - """Holds the number of stops and a list of detailed stop information for connections.""" - - number: int # Number of stops - stop: List[ConnectionStop] = field(default_factory=list) # List of stop details - - @dataclass class Direction(DataClassORJSONMixin): """Represents the direction of a train connection.""" @@ -218,15 +212,12 @@ class ConnectionDeparture(DataClassORJSONMixin): delay: int # Delay in seconds station: str # Station name station_info: StationDetails = field(metadata=field_options(alias="stationinfo")) # Detailed station info - time: datetime = field( - metadata=field_options(deserialize=_timestamp_to_datetime) - ) # Departure time (timestamp) + time: datetime = field(metadata=field_options(deserialize=_timestamp_to_datetime)) # Departure time (timestamp) vehicle: str # Vehicle identifier vehicle_info: VehicleInfo = field(metadata=field_options(alias="vehicleinfo")) # Vehicle details platform: str # Platform name platform_info: PlatformInfo = field(metadata=field_options(alias="platforminfo")) # Detailed platform info canceled: bool = field(metadata=field_options(deserialize=_str_to_bool)) # Whether the departure is canceled - stops: ConnectionStops # Stops along the journey departure_connection: str = field(metadata=field_options(alias="departureConnection")) # Departure connection link direction: Direction # Direction of the connection left: bool = field(metadata=field_options(deserialize=_str_to_bool)) # Whether the train has left @@ -234,6 +225,13 @@ class ConnectionDeparture(DataClassORJSONMixin): metadata=field_options(deserialize=_str_to_bool) ) # Indicates if the connection requires walking occupancy: Occupancy # Occupancy level + stops: list[ConnectionStop] = field(default_factory=list) # List of stop along the journey + + @classmethod + def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]: + """Extract 'stop' list from 'stops' before deserialization.""" + d["stops"] = d["stops"]["stop"] + return d @dataclass @@ -243,9 +241,7 @@ class ConnectionArrival(DataClassORJSONMixin): delay: int # Delay in seconds station: str # Station name station_info: StationDetails = field(metadata=field_options(alias="stationinfo")) # Detailed station info - time: datetime = field( - metadata=field_options(deserialize=_timestamp_to_datetime) - ) # Arrival time (timestamp) + time: datetime = field(metadata=field_options(deserialize=_timestamp_to_datetime)) # Arrival time (timestamp) vehicle: str # Vehicle identifier vehicle_info: VehicleInfo = field(metadata=field_options(alias="vehicleinfo")) # Vehicle details platform: str # Platform name @@ -273,14 +269,6 @@ class Via(DataClassORJSONMixin): vehicle_info: VehicleInfo = field(metadata=field_options(alias="vehicleinfo")) # Vehicle details -@dataclass -class Vias(DataClassORJSONMixin): - """Holds the number of vias and a list of detailed via information for connections.""" - - number: int # Number of vias - via: List[Via] = field(default_factory=list) # List of via details - - @dataclass class Remark(DataClassORJSONMixin): """Represents a single remark for a train connection, including type and content.""" @@ -290,14 +278,6 @@ class Remark(DataClassORJSONMixin): content: str # Remark content -@dataclass -class Remarks(DataClassORJSONMixin): - """Represents remarks for a train connection, including the type and content.""" - - number: int # Number of remarks - remark: List[Remark] = field(default_factory=list) # List of remarks - - @dataclass class Alert(DataClassORJSONMixin): """Represents a single alert for a train connection, including type and content.""" @@ -314,13 +294,6 @@ class Alert(DataClassORJSONMixin): ) # End time of the alert link: str | None = field(default=None) # Link to more information -@dataclass -class Alerts(DataClassORJSONMixin): - """Represents alerts for a train connection, including the type and content.""" - - number: int # Number of alerts - alert: List[Alert] = field(default_factory=list) # List of alerts - @dataclass class ConnectionDetails(DataClassORJSONMixin): @@ -330,16 +303,31 @@ class ConnectionDetails(DataClassORJSONMixin): departure: ConnectionDeparture # Departure details arrival: ConnectionArrival # Arrival details duration: int # Duration of the connection in minutes - remarks: Remarks # Remarks for the connection - alerts: Alerts # Alerts for the connection - vias: Vias | None = field(default=None) # Vias information + remarks: list[Remark] = field(default_factory=list) # List of remarks for the connection + alerts: list[Alert] = field(default_factory=list) # List of alerts for the connection + vias: list[Via] | None = field(default=None) # List of via details + + @classmethod + def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]: + """Flatten the structure of the 'connections' response by extracting lists of remarks, alerts, and optionally flatten vias before deserialization.""" + # Extract 'remark' and 'alert' list from 'remarks' and 'alerts' before deserialization. + d["remarks"] = d["remarks"]["remark"] + d["alerts"] = d["alerts"]["alert"] + + # Safely flatten vias + if "vias" in d and d["vias"] is not None: + d["vias"] = d["vias"].get("via", []) + else: + d["vias"] = None # Set to None if missing or null + + return d @dataclass class ConnectionsApiResponse(ApiResponse): """Holds a list of connections returned by the connections endpoint.""" - connections: List[ConnectionDetails] = field( + connections: list[ConnectionDetails] = field( metadata=field_options(alias="connection"), default_factory=list ) # List of connections @@ -383,21 +371,19 @@ class VehicleStop(DataClassORJSONMixin): ) # Departure connection link, not present in the last stop -@dataclass -class VehicleStops(DataClassORJSONMixin): - """Holds the number of stops and a list of detailed stop information for vehicles.""" - - number: int # Number of stops - stop: List[VehicleStop] = field(default_factory=list) # List of stop details - - @dataclass class VehicleApiResponse(ApiResponse): - """Provides detailed data about a particular vehicle, including its stops.""" + """Provides detailed data about a particular vehicle, including a list of its stops.""" vehicle: str # Vehicle identifier vehicle_info: VehicleInfo = field(metadata=field_options(alias="vehicleinfo")) # Vehicle information - stops: VehicleStops # Stops information + stops: list[VehicleStop] = field(default_factory=list) # List of stop details + + @classmethod + def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]: + """Extract 'stop' list from 'stops' before deserialization.""" + d["stops"] = d["stops"]["stop"] + return d @dataclass @@ -464,20 +450,18 @@ class Unit(DataClassORJSONMixin): ) # Whether the unit has a bike section -@dataclass -class CompositionUnits(DataClassORJSONMixin): - """Holds the number of units and a list of detailed unit information.""" - - number: int # Number of units - unit: List[Unit] = field(default_factory=list) # List of units - - @dataclass class SegmentComposition(DataClassORJSONMixin): """Describes a collection of train units and related metadata.""" source: str # Source of the composition - units: CompositionUnits # Units information + units: list[Unit] = field(default_factory=list) # List of units + + @classmethod + def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]: + """Extract 'unit' list from 'units' before deserialization.""" + d["units"] = d["units"]["unit"] + return d @dataclass @@ -490,26 +474,17 @@ class Segment(DataClassORJSONMixin): composition: SegmentComposition # Composition details of the segment -@dataclass -class Segments(DataClassORJSONMixin): - """Holds the number of segments and a list of detailed segment information.""" - - number: int # Number of segments - segment: List[Segment] = field(default_factory=list) # List of segments - - -@dataclass -class CompositionSegments(DataClassORJSONMixin): - """Encapsulated the composition segments of a specific train.""" - - segments: Segments # Segments information - - @dataclass class CompositionApiResponse(ApiResponse): """Encapsulates the response containing composition details of a specific train.""" - composition: CompositionSegments # Composition details + composition: list[Segment] = field(default_factory=list) # Composition details of the train + + @classmethod + def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]: + """Pre-deserialization hook to safely flatten composition segments.""" + d["composition"] = d["composition"]["segments"]["segment"] + return d @dataclass @@ -521,16 +496,6 @@ class DescriptionLink(DataClassORJSONMixin): text: str # Text displayed for the link -@dataclass -class DescriptionLinks(DataClassORJSONMixin): - """Holds the number of description links and a list of detailed description link information.""" - - number: int # Number of description links - description_link: List[DescriptionLink] = field( - metadata=field_options(alias="descriptionLink"), default_factory=list - ) # List of description links - - @dataclass class Disturbance(DataClassORJSONMixin): """Represents a railway system disturbance, including description and metadata.""" @@ -544,13 +509,21 @@ class Disturbance(DataClassORJSONMixin): metadata=field_options(deserialize=_timestamp_to_datetime) ) # Timestamp of the disturbance richtext: str # Rich-text description (HTML-like) - description_links: DescriptionLinks = field(metadata=field_options(alias="descriptionLinks")) # Description links + description_links: list[DescriptionLink] = field( + metadata=field_options(alias="descriptionLinks"), default_factory=list + ) # List of description links + + @classmethod + def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]: + """Extract 'descriptionLink' list from 'descriptionLinks' before deserialization.""" + d["descriptionLinks"] = d["descriptionLinks"]["descriptionLink"] + return d @dataclass class DisturbancesApiResponse(ApiResponse): """Encapsulates multiple disturbances returned by the disturbances endpoint.""" - disturbances: List[Disturbance] = field( + disturbances: list[Disturbance] = field( metadata=field_options(alias="disturbance"), default_factory=list ) # List of disturbances diff --git a/tests/test_irail.py b/tests/test_irail.py index 4c845b1..9856043 100644 --- a/tests/test_irail.py +++ b/tests/test_irail.py @@ -104,16 +104,16 @@ async def test_get_liveboard(): assert isinstance(liveboard, LiveboardApiResponse), "Expected response to be a dictionary" # Validate the structure of departure data - departure_list = liveboard.departures.departure + departure_list = liveboard.departures assert isinstance(departure_list, list), "Expected 'departure' to be a list" assert len(departure_list) > 0, "Expected at least one departure in the response" - assert isinstance( - departure_list[0], LiveboardDeparture - ), "Expected the first departure to be a LiveboardDeparture object" + assert isinstance(departure_list[0], LiveboardDeparture), ( + "Expected the first departure to be a LiveboardDeparture object" + ) # Test VehicleInfo dataclass - assert isinstance( - departure_list[0].vehicle_info, VehicleInfo - ), "Expected vehicle_info to be a VehicleInfo object" + assert isinstance(departure_list[0].vehicle_info, VehicleInfo), ( + "Expected vehicle_info to be a VehicleInfo object" + ) @pytest.mark.asyncio @@ -139,9 +139,9 @@ async def test_get_connections(): connection_list = connections.connections assert isinstance(connection_list, list), "Expected 'connection' to be a list" assert len(connection_list) > 0, "Expected at least one connection in the response" - assert isinstance( - connection_list[0], ConnectionDetails - ), "Expected the first connection to be a ConnectionDetails object" + assert isinstance(connection_list[0], ConnectionDetails), ( + "Expected the first connection to be a ConnectionDetails object" + ) @pytest.mark.asyncio @@ -160,11 +160,12 @@ async def test_get_vehicle(): assert vehicle is not None, "The response should not be None" assert isinstance(vehicle, VehicleApiResponse), "Expected response to be a VehicleApiResponse object" assert isinstance(vehicle.vehicle_info, VehicleInfo), "Expected vehicle_info to be a VehicleInfo object" - assert isinstance(vehicle.stops.stop, list), "Expected 'stop' to be a list" - assert vehicle.stops.number >= 0, "Expected 'number' to be a non-negative integer" - stop = vehicle.stops.stop[0] - assert isinstance(stop.platform_info, PlatformInfo), "Expected platform_info to be a PlatformInfo object" - assert isinstance(stop.occupancy, Occupancy), "Expected occupancy to be an Occupancy object" + assert isinstance(vehicle.stops, list), "Expected 'stop' to be a list" + assert len(vehicle.stops) > 0, "Expected at least one stop" + if len(vehicle.stops) > 0: + stop = vehicle.stops[0] + assert isinstance(stop.platform_info, PlatformInfo), "Expected platform_info to be a PlatformInfo object" + assert isinstance(stop.occupancy, Occupancy), "Expected occupancy to be an Occupancy object" @pytest.mark.asyncio @@ -184,26 +185,26 @@ async def test_get_composition(): composition = await api.get_composition("IC538") assert composition is not None, "The response should not be None" - assert isinstance( - composition, CompositionApiResponse - ), "Expected response to be a CompositionApiResponse object" + assert isinstance(composition, CompositionApiResponse), ( + "Expected response to be a CompositionApiResponse object" + ) # Test segments structure - segments = composition.composition.segments - assert isinstance(segments.segment, list), "Expected 'segment' to be a list" - assert segments.number >= 0, "Expected 'number' to be a non-negative integer" + segments = composition.composition + assert isinstance(segments, list), "Expected 'segments' to be a list" + assert len(segments) > 0, "Expected 'number' to be a non-negative integer" - if segments.number > 0: - segment = segments.segment[0] + if len(segments) > 0: + segment = segments[0] assert isinstance(segment.origin, StationDetails), "Expected origin to be a StationDetails object" assert isinstance(segment.destination, StationDetails), "Expected destination to be a StationDetails object" # Test units in composition units = segment.composition.units - assert units.number >= 0, "Expected 'number' to be a non-negative integer" + assert len(units) > 0, "Expected 'number' to be a non-negative integer" - if units.number > 0: - unit = units.unit[0] + if len(units) > 0: + unit = units[0] assert isinstance(unit.has_toilets, bool), "Expected 'has_toilets' to be a boolean" assert isinstance(unit.seats_first_class, int), "Expected 'seats_first_class' to be an integer" assert isinstance(unit.length_in_meter, int), "Expected 'length_in_meter' to be an integer" @@ -224,9 +225,9 @@ async def test_get_disturbances(): disturbances = await api.get_disturbances() assert disturbances is not None, "The response should not be None" - assert isinstance( - disturbances, DisturbancesApiResponse - ), "Expected response to be a DisturbancesApiResponse object" + assert isinstance(disturbances, DisturbancesApiResponse), ( + "Expected response to be a DisturbancesApiResponse object" + ) assert isinstance(disturbances.disturbances, list), "Expected 'disturbances' to be a list" # Test disturbance attributes @@ -343,99 +344,96 @@ async def test_timestamp_to_datetime(): async def test_timestamp_field_deserialization(): """Test timestamp field deserialization in various models.""" # Test ApiResponse timestamp - api_response = ApiResponse.from_dict({ - "version": "1.0", - "timestamp": "1705593600" - }) + api_response = ApiResponse.from_dict({"version": "1.0", "timestamp": "1705593600"}) assert api_response.timestamp == datetime(2024, 1, 18, 17, 0) # Test LiveboardDeparture time - departure = LiveboardDeparture.from_dict({ - "id": "0", - "station": "Brussels-South/Brussels-Midi", - "stationinfo": { - "@id": "http://irail.be/stations/NMBS/008814001", - "id": "BE.NMBS.008814001", - "name": "Brussels-South/Brussels-Midi", - "locationX": "4.336531", - "locationY": "50.835707", - "standardname": "Brussel-Zuid/Bruxelles-Midi" - }, - "time": "1705593600", - "delay": "0", - "canceled": "0", - "left": "0", - "isExtra": "0", - "vehicle": "BE.NMBS.EC9272", - "vehicleinfo": { - "name": "BE.NMBS.EC9272", - "shortname": "EC 9272", - "number": "9272", - "type": "EC", - "locationX": "0", - "locationY": "0", - "@id": "http://irail.be/vehicle/EC9272" - }, - "platform": "23", - "platforminfo": { - "name": "23", - "normal": "1" - }, - "occupancy": { - "@id": "http://api.irail.be/terms/low", - "name": "low" - }, - "departureConnection": "http://irail.be/connections/8821006/20250106/EC9272" - }) + departure = LiveboardDeparture.from_dict( + { + "id": "0", + "station": "Brussels-South/Brussels-Midi", + "stationinfo": { + "@id": "http://irail.be/stations/NMBS/008814001", + "id": "BE.NMBS.008814001", + "name": "Brussels-South/Brussels-Midi", + "locationX": "4.336531", + "locationY": "50.835707", + "standardname": "Brussel-Zuid/Bruxelles-Midi", + }, + "time": "1705593600", + "delay": "0", + "canceled": "0", + "left": "0", + "isExtra": "0", + "vehicle": "BE.NMBS.EC9272", + "vehicleinfo": { + "name": "BE.NMBS.EC9272", + "shortname": "EC 9272", + "number": "9272", + "type": "EC", + "locationX": "0", + "locationY": "0", + "@id": "http://irail.be/vehicle/EC9272", + }, + "platform": "23", + "platforminfo": {"name": "23", "normal": "1"}, + "occupancy": {"@id": "http://api.irail.be/terms/low", "name": "low"}, + "departureConnection": "http://irail.be/connections/8821006/20250106/EC9272", + } + ) assert departure.time == datetime(2024, 1, 18, 17, 0) # Test Alert start_time and end_time - alert = Alert.from_dict({ - "id": "0", - "header": "Anvers-Central / Antwerpen-Centraal - Anvers-Berchem / Antwerpen-Berchem", - "description": "During the weekends, from 4 to 19/01 Infrabel is working on the track. The departure times of this train change. The travel planner takes these changes into account.", - "lead": "During the weekends, from 4 to 19/01 Infrabel is working on the track", - "startTime": "1705593600", - "endTime": "1705597200" - }) + alert = Alert.from_dict( + { + "id": "0", + "header": "Anvers-Central / Antwerpen-Centraal - Anvers-Berchem / Antwerpen-Berchem", + "description": "During the weekends, from 4 to 19/01 Infrabel is working on the track. The departure times of this train change. The travel planner takes these changes into account.", + "lead": "During the weekends, from 4 to 19/01 Infrabel is working on the track", + "startTime": "1705593600", + "endTime": "1705597200", + } + ) assert alert.start_time == datetime(2024, 1, 18, 17, 0) assert alert.end_time == datetime(2024, 1, 18, 18, 0) # Test Disturbance timestamp - disturbance = Disturbance.from_dict({ - "id": "1", - "title": "Mouscron / Moeskroen - Lille Flandres (FR)", - "description": "On weekdays from 6 to 17/01 works will take place on the French rail network.An SNCB bus replaces some IC trains Courtrai / Kortrijk - Mouscron / Moeskroen - Lille Flandres (FR) between Mouscron / Moeskroen and Lille Flandres (FR).The travel planner takes these changes into account.Meer info over de NMBS-bussen (FAQ)En savoir plus sur les bus SNCB (FAQ)Où prendre mon bus ?Waar is mijn bushalte?", - "type": "planned", - "link": "https://www.belgiantrain.be/nl/support/faq/faq-routes-schedules/faq-bus", - "timestamp": "1705593600", - "richtext": "On weekdays from 6 to 17/01 works will take place on the French rail network.An SNCB bus replaces some IC trains Courtrai / Kortrijk - Mouscron / Moeskroen - Lille Flandres (FR) between Mouscron / Moeskroen and Lille Flandres (FR).The travel planner takes these changes into account.
Meer info over de NMBS-bussen (FAQ)
En savoir plus sur les bus SNCB (FAQ)
Où prendre mon bus ?
Waar is mijn bushalte?", - "descriptionLinks": { - "number": "4", - "descriptionLink": [ - { - "id": "0", - "link": "https://www.belgiantrain.be/nl/support/faq/faq-routes-schedules/faq-bus", - "text": "Meer info over de NMBS-bussen (FAQ)" - }, - { - "id": "1", - "link": "https://www.belgiantrain.be/fr/support/faq/faq-routes-schedules/faq-bus", - "text": "En savoir plus sur les bus SNCB (FAQ)" - }, - { - "id": "2", - "link": "https://www.belgianrail.be/jp/download/brail_him/1736172333792_FR_2501250_S.pdf", - "text": "Où prendre mon bus ?" - }, - { - "id": "3", - "link": "https://www.belgianrail.be/jp/download/brail_him/1736172333804_NL_2501250_S.pdf", - "text": "Waar is mijn bushalte?" - } - ] + disturbance = Disturbance.from_dict( + { + "id": "1", + "title": "Mouscron / Moeskroen - Lille Flandres (FR)", + "description": "On weekdays from 6 to 17/01 works will take place on the French rail network.An SNCB bus replaces some IC trains Courtrai / Kortrijk - Mouscron / Moeskroen - Lille Flandres (FR) between Mouscron / Moeskroen and Lille Flandres (FR).The travel planner takes these changes into account.Meer info over de NMBS-bussen (FAQ)En savoir plus sur les bus SNCB (FAQ)Où prendre mon bus ?Waar is mijn bushalte?", + "type": "planned", + "link": "https://www.belgiantrain.be/nl/support/faq/faq-routes-schedules/faq-bus", + "timestamp": "1705593600", + "richtext": "On weekdays from 6 to 17/01 works will take place on the French rail network.An SNCB bus replaces some IC trains Courtrai / Kortrijk - Mouscron / Moeskroen - Lille Flandres (FR) between Mouscron / Moeskroen and Lille Flandres (FR).The travel planner takes these changes into account.
Meer info over de NMBS-bussen (FAQ)
En savoir plus sur les bus SNCB (FAQ)
Où prendre mon bus ?
Waar is mijn bushalte?", + "descriptionLinks": { + "number": "4", + "descriptionLink": [ + { + "id": "0", + "link": "https://www.belgiantrain.be/nl/support/faq/faq-routes-schedules/faq-bus", + "text": "Meer info over de NMBS-bussen (FAQ)", + }, + { + "id": "1", + "link": "https://www.belgiantrain.be/fr/support/faq/faq-routes-schedules/faq-bus", + "text": "En savoir plus sur les bus SNCB (FAQ)", + }, + { + "id": "2", + "link": "https://www.belgianrail.be/jp/download/brail_him/1736172333792_FR_2501250_S.pdf", + "text": "Où prendre mon bus ?", + }, + { + "id": "3", + "link": "https://www.belgianrail.be/jp/download/brail_him/1736172333804_NL_2501250_S.pdf", + "text": "Waar is mijn bushalte?", + }, + ], + }, } - }) + ) assert disturbance.timestamp == datetime(2024, 1, 18, 17, 0) @@ -451,56 +449,46 @@ async def test_str_to_bool(): async def test_boolean_field_deserialization(): """Test the deserialization of boolean fields in models.""" # Test PlatformInfo boolean field - platform = PlatformInfo.from_dict({ - "name": "1", - "normal": "1" - }) + platform = PlatformInfo.from_dict({"name": "1", "normal": "1"}) assert platform.normal is True, "Platform normal field should be True when '1'" - platform = PlatformInfo.from_dict({ - "name": "1", - "normal": "0" - }) + platform = PlatformInfo.from_dict({"name": "1", "normal": "0"}) assert platform.normal is False, "Platform normal field should be False when '0'" # Test LiveboardDeparture multiple boolean fields - departure = LiveboardDeparture.from_dict({ - "id": "1", - "station": "Brussels", - "stationinfo": { - "@id": "1", + departure = LiveboardDeparture.from_dict( + { "id": "1", - "name": "Brussels", - "locationX": 4.3517, - "locationY": 50.8503, - "standardname": "Brussels-Central" - }, - "time": "1705593600", # Example timestamp - "delay": 0, - "canceled": "1", - "left": "0", - "isExtra": "1", - "vehicle": "BE.NMBS.IC1234", - "vehicleinfo": { - "name": "IC1234", - "shortname": "IC1234", - "number": "1234", - "type": "IC", - "locationX": 4.3517, - "locationY": 50.8503, - "@id": "1" - }, - "platform": "1", - "platforminfo": { - "name": "1", - "normal": "1" - }, - "occupancy": { - "@id": "1", - "name": "LOW" - }, - "departureConnection": "1" - }) + "station": "Brussels", + "stationinfo": { + "@id": "1", + "id": "1", + "name": "Brussels", + "locationX": 4.3517, + "locationY": 50.8503, + "standardname": "Brussels-Central", + }, + "time": "1705593600", # Example timestamp + "delay": 0, + "canceled": "1", + "left": "0", + "isExtra": "1", + "vehicle": "BE.NMBS.IC1234", + "vehicleinfo": { + "name": "IC1234", + "shortname": "IC1234", + "number": "1234", + "type": "IC", + "locationX": 4.3517, + "locationY": 50.8503, + "@id": "1", + }, + "platform": "1", + "platforminfo": {"name": "1", "normal": "1"}, + "occupancy": {"@id": "http://api.irail.be/terms/low", "name": "low"}, + "departureConnection": "1", + } + ) # Verify boolean fields are correctly deserialized assert departure.canceled is True, "Departure canceled field should be True when '1'"