Skip to content

Commit

Permalink
Refactor type hints for improved clarity and consistency in iRail cla…
Browse files Browse the repository at this point in the history
…ss methods
  • Loading branch information
tjorim committed Jan 6, 2025
1 parent f298244 commit d63284b
Showing 1 changed file with 62 additions and 27 deletions.
89 changes: 62 additions & 27 deletions pyrail/irail.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import time
from types import TracebackType
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Type

from aiohttp import ClientError, ClientResponse, ClientSession

Expand Down Expand Up @@ -58,17 +58,17 @@ def __init__(self, lang: str = "en") -> None:
self.burst_tokens: int = 5
self.last_request_time: float = time.time()
self.lock: Lock = Lock()
self.session: Optional[ClientSession] = None
self.session: ClientSession | None = None
self.etag_cache: Dict[str, str] = {}
logger.info("iRail instance created")

async def __aenter__(self):
async def __aenter__(self) -> "iRail":
"""Initialize and return the aiohttp client session when entering the async context."""
self.session = ClientSession()
return self

async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType]
self, exc_type: Type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None
) -> None:
"""Close the aiohttp client session when exiting the async context."""
if self.session:
Expand Down Expand Up @@ -102,7 +102,12 @@ def lang(self, value: str) -> None:
self.__lang = "en"

def _refill_tokens(self) -> None:
"""Refill rate limit tokens based on elapsed time."""
"""Refill tokens for rate limiting based on elapsed time.
This method refills both standard tokens (max 3) and burst tokens (max 5)
using a token bucket algorithm. The refill rate is 3 tokens per second.
"""
logger.debug("Refilling tokens")
current_time: float = time.time()
elapsed: float = current_time - self.last_request_time
Expand Down Expand Up @@ -135,6 +140,16 @@ async def _handle_rate_limit(self) -> None:
self.tokens -= 1

def _add_etag_header(self, method: str) -> Dict[str, str]:
"""Add ETag header for the given method if a cached ETag is available.
Args:
method (str): The API endpoint for which the header is being generated.
Returns:
Dict[str, str]: A dictionary containing HTTP headers, including the ETag header
if a cached value exists.
"""
headers: Dict[str, str] = {
"User-Agent": "pyRail (https://github.com/tjorim/pyrail; tielemans.jorim@gmail.com)"}
if method in self.etag_cache:
Expand All @@ -143,8 +158,17 @@ def _add_etag_header(self, method: str) -> Dict[str, str]:
headers["If-None-Match"] = self.etag_cache[method]
return headers

def _validate_date(self, date: Optional[str]) -> bool:
"""Validate date format (DDMMYY)."""
def _validate_date(self, date: str | None) -> bool:
"""Validate the date format (DDMMYY).
Args:
date (str, optional): The date string to validate. Expected format is DDMMYY,
e.g., '150923' for September 15, 2023.
Returns:
bool: True if the date is valid or None is provided, False otherwise.
"""
if not date:
return True
try:
Expand All @@ -155,8 +179,17 @@ def _validate_date(self, date: Optional[str]) -> bool:
"Invalid date format. Expected DDMMYY (e.g., 150923 for September 15, 2023), got: %s", date)
return False

def _validate_time(self, time: Optional[str]) -> bool:
"""Validate time format (HHMM)."""
def _validate_time(self, time: str | None) -> bool:
"""Validate the time format (HHMM).
Args:
time (str, optional): The time string to validate. Expected format is HHMM,
e.g., '1430' for 2:30 PM.
Returns:
bool: True if the time is valid or None is provided, False otherwise.
"""
if not time:
return True
try:
Expand Down Expand Up @@ -323,13 +356,13 @@ async def _do_request(self, method: str, args: Dict[str, Any] | None = None) ->
logger.error("Request failed due to an exception: %s", e)
return None

async def get_stations(self) -> Optional[Dict[str, Any]]:
async def get_stations(self) -> Dict[str, Any] | None:
"""Retrieve a list of all train stations from the iRail API.
This method fetches the complete list of available train stations without any additional filtering parameters.
Returns:
Optional[Dict[str, Any]]: A dictionary containing station information, or None if the request fails.
Dict[str, Any] or None: A dictionary containing station information, or None if the request fails.
The returned dictionary typically includes details about all train stations supported by the iRail API.
Example:
Expand All @@ -343,13 +376,13 @@ async def get_stations(self) -> Optional[Dict[str, Any]]:

async def get_liveboard(
self,
station: Optional[str] = None,
id: Optional[str] = None,
date: Optional[str] = None,
time: Optional[str] = None,
station: str | None = None,
id: str | None = None,
date: str | None = None,
time: str | None = None,
arrdep: str = "departure",
alerts: bool = False,
) -> Optional[Dict[str, Any]]:
) -> Dict[str, Any] | None:
"""Retrieve a liveboard for a specific train station.
Asynchronously fetches live departure or arrival information for a given station.
Expand All @@ -373,7 +406,7 @@ async def get_liveboard(
print(f"Liveboard for Brussels-South: {liveboard}")
"""
extra_params: Dict[str, Optional[Any]] = {
extra_params: Dict[str, Any] = {
"station": station,
"id": id,
"date": date,
Expand All @@ -387,11 +420,11 @@ async def get_connections(
self,
from_station: str,
to_station: str,
date: Optional[str] = None,
time: Optional[str] = None,
date: str | None = None,
time: str | None = None,
timesel: str = "departure",
type_of_transport: str = "automatic",
) -> Optional[Dict[str, Any]]:
) -> Dict[str, Any] | None:
"""Retrieve train connections between two stations using the iRail API.
Args:
Expand All @@ -412,7 +445,7 @@ async def get_connections(
print(f"Connections from Antwerpen-Centraal to Brussel-Centraal: {connections}")
"""
extra_params: Dict[str, Optional[Any]] = {
extra_params: Dict[str, Any] = {
"from": from_station,
"to": to_station,
"date": date,
Expand All @@ -422,7 +455,7 @@ async def get_connections(
}
return await self._do_request("connections", {k: v for k, v in extra_params.items() if v is not None})

async def get_vehicle(self, id: str, date: Optional[str] = None, alerts: bool = False) -> Optional[Dict[str, Any]]:
async def get_vehicle(self, id: str, date: str | None = None, alerts: bool = False) -> Dict[str, Any] | None:
"""Retrieve detailed information about a specific train vehicle.
Args:
Expand All @@ -438,34 +471,36 @@ async def get_vehicle(self, id: str, date: Optional[str] = None, alerts: bool =
vehicle_info = await client.get_vehicle("BE.NMBS.IC1832")
"""
extra_params: Dict[str, Any] = {
"id": id, "date": date, "alerts": "true" if alerts else "false"}
return await self._do_request("vehicle", {k: v for k, v in extra_params.items() if v is not None})

async def get_composition(self, id: str, data: Optional[str] = None) -> Optional[Dict[str, Any]]:
async def get_composition(self, id: str, data: str | None = None) -> Dict[str, Any] | None:
"""Retrieve the composition details of a specific train.
Args:
id (str): The unique identifier of the train for which composition details are requested.
data (str, optional): Additional data parameter to get all raw unfiltered data as iRail fetches it from the NMBS (set to 'all'). Defaults to '' (filtered data).
Returns:
Optional[Dict[str, Any]]: A dictionary containing the train composition details, or None if the request fails.
Dict[str, Any] or None: A dictionary containing the train composition details, or None if the request fails.
Example:
async with iRail() as client:
composition = await client.get_composition('S51507')
"""
extra_params: Dict[str, Optional[str]] = {"id": id, "data": data}
extra_params: Dict[str, str | None] = {"id": id, "data": data}
return await self._do_request("composition", {k: v for k, v in extra_params.items() if v is not None})

async def get_disturbances(self, line_break_character: Optional[str] = None) -> Optional[Dict[str, Any]]:
async def get_disturbances(self, line_break_character: str | None = None) -> Dict[str, Any] | None:
"""Retrieve information about current disturbances on the rail network.
Args:
line_break_character (str, optional): A custom character to use for line breaks in the disturbance description. Defaults to ''.
Returns:
Optional[Dict[str, Any]]: A dictionary containing disturbance information from the iRail API, or None if no disturbances are found.
Dict[str, Any] or None: A dictionary containing disturbance information from the iRail API, or None if no disturbances are found.
Example:
async with iRail() as client:
Expand Down

0 comments on commit d63284b

Please sign in to comment.