From d63284bf748086061f53abf8c49f7b476847cae5 Mon Sep 17 00:00:00 2001 From: Jorim Tielemans Date: Mon, 6 Jan 2025 12:41:51 +0000 Subject: [PATCH] Refactor type hints for improved clarity and consistency in iRail class methods --- pyrail/irail.py | 89 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 27 deletions(-) diff --git a/pyrail/irail.py b/pyrail/irail.py index cf415be..e34370d 100644 --- a/pyrail/irail.py +++ b/pyrail/irail.py @@ -6,7 +6,7 @@ import logging import time from types import TracebackType -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Type from aiohttp import ClientError, ClientResponse, ClientSession @@ -58,17 +58,17 @@ def __init__(self, lang: str = "en") -> None: self.burst_tokens: int = 5 self.last_request_time: float = time.time() self.lock: Lock = Lock() - self.session: Optional[ClientSession] = None + self.session: ClientSession | None = None self.etag_cache: Dict[str, str] = {} logger.info("iRail instance created") - async def __aenter__(self): + async def __aenter__(self) -> "iRail": """Initialize and return the aiohttp client session when entering the async context.""" self.session = ClientSession() return self async def __aexit__( - self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType] + self, exc_type: Type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None ) -> None: """Close the aiohttp client session when exiting the async context.""" if self.session: @@ -102,7 +102,12 @@ def lang(self, value: str) -> None: self.__lang = "en" def _refill_tokens(self) -> None: - """Refill rate limit tokens based on elapsed time.""" + """Refill tokens for rate limiting based on elapsed time. + + This method refills both standard tokens (max 3) and burst tokens (max 5) + using a token bucket algorithm. The refill rate is 3 tokens per second. + + """ logger.debug("Refilling tokens") current_time: float = time.time() elapsed: float = current_time - self.last_request_time @@ -135,6 +140,16 @@ async def _handle_rate_limit(self) -> None: self.tokens -= 1 def _add_etag_header(self, method: str) -> Dict[str, str]: + """Add ETag header for the given method if a cached ETag is available. + + Args: + method (str): The API endpoint for which the header is being generated. + + Returns: + Dict[str, str]: A dictionary containing HTTP headers, including the ETag header + if a cached value exists. + + """ headers: Dict[str, str] = { "User-Agent": "pyRail (https://github.com/tjorim/pyrail; tielemans.jorim@gmail.com)"} if method in self.etag_cache: @@ -143,8 +158,17 @@ def _add_etag_header(self, method: str) -> Dict[str, str]: headers["If-None-Match"] = self.etag_cache[method] return headers - def _validate_date(self, date: Optional[str]) -> bool: - """Validate date format (DDMMYY).""" + def _validate_date(self, date: str | None) -> bool: + """Validate the date format (DDMMYY). + + Args: + date (str, optional): The date string to validate. Expected format is DDMMYY, + e.g., '150923' for September 15, 2023. + + Returns: + bool: True if the date is valid or None is provided, False otherwise. + + """ if not date: return True try: @@ -155,8 +179,17 @@ def _validate_date(self, date: Optional[str]) -> bool: "Invalid date format. Expected DDMMYY (e.g., 150923 for September 15, 2023), got: %s", date) return False - def _validate_time(self, time: Optional[str]) -> bool: - """Validate time format (HHMM).""" + def _validate_time(self, time: str | None) -> bool: + """Validate the time format (HHMM). + + Args: + time (str, optional): The time string to validate. Expected format is HHMM, + e.g., '1430' for 2:30 PM. + + Returns: + bool: True if the time is valid or None is provided, False otherwise. + + """ if not time: return True try: @@ -323,13 +356,13 @@ async def _do_request(self, method: str, args: Dict[str, Any] | None = None) -> logger.error("Request failed due to an exception: %s", e) return None - async def get_stations(self) -> Optional[Dict[str, Any]]: + async def get_stations(self) -> Dict[str, Any] | None: """Retrieve a list of all train stations from the iRail API. This method fetches the complete list of available train stations without any additional filtering parameters. Returns: - Optional[Dict[str, Any]]: A dictionary containing station information, or None if the request fails. + Dict[str, Any] or None: A dictionary containing station information, or None if the request fails. The returned dictionary typically includes details about all train stations supported by the iRail API. Example: @@ -343,13 +376,13 @@ async def get_stations(self) -> Optional[Dict[str, Any]]: async def get_liveboard( self, - station: Optional[str] = None, - id: Optional[str] = None, - date: Optional[str] = None, - time: Optional[str] = None, + station: str | None = None, + id: str | None = None, + date: str | None = None, + time: str | None = None, arrdep: str = "departure", alerts: bool = False, - ) -> Optional[Dict[str, Any]]: + ) -> Dict[str, Any] | None: """Retrieve a liveboard for a specific train station. Asynchronously fetches live departure or arrival information for a given station. @@ -373,7 +406,7 @@ async def get_liveboard( print(f"Liveboard for Brussels-South: {liveboard}") """ - extra_params: Dict[str, Optional[Any]] = { + extra_params: Dict[str, Any] = { "station": station, "id": id, "date": date, @@ -387,11 +420,11 @@ async def get_connections( self, from_station: str, to_station: str, - date: Optional[str] = None, - time: Optional[str] = None, + date: str | None = None, + time: str | None = None, timesel: str = "departure", type_of_transport: str = "automatic", - ) -> Optional[Dict[str, Any]]: + ) -> Dict[str, Any] | None: """Retrieve train connections between two stations using the iRail API. Args: @@ -412,7 +445,7 @@ async def get_connections( print(f"Connections from Antwerpen-Centraal to Brussel-Centraal: {connections}") """ - extra_params: Dict[str, Optional[Any]] = { + extra_params: Dict[str, Any] = { "from": from_station, "to": to_station, "date": date, @@ -422,7 +455,7 @@ async def get_connections( } return await self._do_request("connections", {k: v for k, v in extra_params.items() if v is not None}) - async def get_vehicle(self, id: str, date: Optional[str] = None, alerts: bool = False) -> Optional[Dict[str, Any]]: + async def get_vehicle(self, id: str, date: str | None = None, alerts: bool = False) -> Dict[str, Any] | None: """Retrieve detailed information about a specific train vehicle. Args: @@ -438,9 +471,11 @@ async def get_vehicle(self, id: str, date: Optional[str] = None, alerts: bool = vehicle_info = await client.get_vehicle("BE.NMBS.IC1832") """ + extra_params: Dict[str, Any] = { + "id": id, "date": date, "alerts": "true" if alerts else "false"} return await self._do_request("vehicle", {k: v for k, v in extra_params.items() if v is not None}) - async def get_composition(self, id: str, data: Optional[str] = None) -> Optional[Dict[str, Any]]: + async def get_composition(self, id: str, data: str | None = None) -> Dict[str, Any] | None: """Retrieve the composition details of a specific train. Args: @@ -448,24 +483,24 @@ async def get_composition(self, id: str, data: Optional[str] = None) -> Optional data (str, optional): Additional data parameter to get all raw unfiltered data as iRail fetches it from the NMBS (set to 'all'). Defaults to '' (filtered data). Returns: - Optional[Dict[str, Any]]: A dictionary containing the train composition details, or None if the request fails. + Dict[str, Any] or None: A dictionary containing the train composition details, or None if the request fails. Example: async with iRail() as client: composition = await client.get_composition('S51507') """ - extra_params: Dict[str, Optional[str]] = {"id": id, "data": data} + extra_params: Dict[str, str | None] = {"id": id, "data": data} return await self._do_request("composition", {k: v for k, v in extra_params.items() if v is not None}) - async def get_disturbances(self, line_break_character: Optional[str] = None) -> Optional[Dict[str, Any]]: + async def get_disturbances(self, line_break_character: str | None = None) -> Dict[str, Any] | None: """Retrieve information about current disturbances on the rail network. Args: line_break_character (str, optional): A custom character to use for line breaks in the disturbance description. Defaults to ''. Returns: - Optional[Dict[str, Any]]: A dictionary containing disturbance information from the iRail API, or None if no disturbances are found. + Dict[str, Any] or None: A dictionary containing disturbance information from the iRail API, or None if no disturbances are found. Example: async with iRail() as client: