Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance iRail class documentation and refactor code for clarity #45

Merged
merged 6 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions pyrail/irail.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import time
from types import TracebackType
from typing import Any, Dict, Type
from typing import Any, Type

from aiohttp import ClientError, ClientResponse, ClientSession

Expand Down Expand Up @@ -45,7 +45,7 @@ class iRail:
# Available iRail API endpoints and their parameter requirements.
# Each endpoint is configured with required parameters, optional parameters, and XOR
# parameter groups (where exactly one parameter from the group must be provided).
endpoints: Dict[str, Dict[str, Any]] = {
endpoints: dict[str, dict[str, Any]] = {
"stations": {},
"liveboard": {"xor": ["station", "id"], "optional": ["date", "time", "arrdep", "alerts"]},
"connections": {
Expand All @@ -71,8 +71,8 @@ def __init__(self, lang: str = "en", session: ClientSession | None = None) -> No
self.last_request_time: float = time.time()
self.lock: Lock = Lock()
self.session: ClientSession | None = session
self._owns_session = session is None # Track ownership
self.etag_cache: Dict[str, str] = {}
self._owns_session = session is None # Track ownership
self.etag_cache: dict[str, str] = {}
logger.info("iRail instance created")

async def __aenter__(self) -> "iRail":
Expand Down Expand Up @@ -130,11 +130,10 @@ def clear_etag_cache(self) -> None:
logger.info("ETag cache cleared")

def _refill_tokens(self) -> None:
"""Refill tokens for rate limiting based on elapsed time.

This method refills both standard tokens (max 3) and burst tokens (max 5)
using a token bucket algorithm. The refill rate is 3 tokens per second.
"""Refill tokens for rate limiting using a token bucket algorithm.

- Standard tokens: Refill rate of 3 tokens/second, max 3 tokens.
- Burst tokens: Refilled only when standard tokens are full, max 5 tokens.
"""
logger.debug("Refilling tokens")
current_time: float = time.time()
Expand All @@ -151,9 +150,9 @@ def _refill_tokens(self) -> None:
async def _handle_rate_limit(self) -> None:
"""Handle rate limiting using a token bucket algorithm.

The implementation uses two buckets:
- Normal bucket: 3 tokens/second
- Burst bucket: 5 tokens/second
- Standard tokens: 3 requests/second.
- Burst tokens: Additional 5 requests/second for spikes.
- Waits and refills tokens if both are exhausted.
"""
logger.debug("Handling rate limit")
self._refill_tokens()
Expand All @@ -169,18 +168,18 @@ async def _handle_rate_limit(self) -> None:
else:
self.tokens -= 1

def _add_etag_header(self, method: str) -> Dict[str, str]:
def _add_etag_header(self, method: str) -> dict[str, str]:
"""Add ETag header for the given method if a cached ETag is available.

Args:
method (str): The API endpoint for which the header is being generated.

Returns:
Dict[str, str]: A dictionary containing HTTP headers, including the ETag header
dict[str, str]: A dictionary containing HTTP headers, including the ETag header
if a cached value exists.

"""
headers: Dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; tielemans.jorim@gmail.com)"}
headers: dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; tielemans.jorim@gmail.com)"}
if method in self.etag_cache:
logger.debug("Adding If-None-Match header with value: %s", self.etag_cache[method])
headers["If-None-Match"] = self.etag_cache[method]
Expand Down Expand Up @@ -226,12 +225,12 @@ def _validate_time(self, time: str | None) -> bool:
logger.error("Invalid time format. Expected HHMM (e.g., 1430 for 2:30 PM), got: %s", time)
return False

def _validate_params(self, method: str, params: Dict[str, Any] | None = None) -> bool:
def _validate_params(self, method: str, params: dict[str, Any] | None = None) -> bool:
"""Validate parameters for a specific iRail API endpoint based on predefined requirements.

Args:
method (str): The API endpoint method to validate parameters for.
params (Dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None.
params (dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None.

Returns:
bool: True if parameters are valid, False otherwise.
Expand Down Expand Up @@ -289,12 +288,12 @@ def _validate_params(self, method: str, params: Dict[str, Any] | None = None) ->

return True

async def _handle_success_response(self, response: ClientResponse, method: str) -> Dict[str, Any] | None:
async def _handle_success_response(self, response: ClientResponse, method: str) -> dict[str, Any] | None:
"""Handle a successful API response."""
if "Etag" in response.headers:
self.etag_cache[method] = response.headers["Etag"]
try:
json_data: Dict[str, Any] | None = await response.json()
json_data: dict[str, Any] | None = await response.json()
if not json_data:
logger.warning("Empty response received")
return json_data
Expand All @@ -303,8 +302,8 @@ async def _handle_success_response(self, response: ClientResponse, method: str)
return None

async def _handle_response(
self, response: ClientResponse, method: str, args: Dict[str, Any] | None = None
) -> Dict[str, Any] | None:
self, response: ClientResponse, method: str, args: dict[str, Any] | None = None
) -> dict[str, Any] | None:
"""Handle the API response based on status code."""
if response.status == 429:
retry_after: int = int(response.headers.get("Retry-After", 1))
Expand All @@ -326,7 +325,7 @@ async def _handle_response(
logger.error("Request failed with status code: %s, response: %s", response.status, await response.text())
return None

async def _do_request(self, method: str, args: Dict[str, Any] | None = None) -> Dict[str, Any] | None:
async def _do_request(self, method: str, args: dict[str, Any] | None = None) -> dict[str, Any] | None:
"""Send an asynchronous request to the specified iRail API endpoint.

This method handles API requests with rate limiting, parameter validation,
Expand Down Expand Up @@ -368,7 +367,7 @@ async def _do_request(self, method: str, args: Dict[str, Any] | None = None) ->
if args:
params.update(args)

request_headers: Dict[str, str] = self._add_etag_header(method)
request_headers: dict[str, str] = self._add_etag_header(method)

try:
async with self.session.get(url, params=params, headers=request_headers) as response:
Expand Down Expand Up @@ -429,7 +428,7 @@ async def get_liveboard(
print(f"Liveboard for Brussels-South: {liveboard}")

"""
extra_params: Dict[str, Any] = {
extra_params: dict[str, Any] = {
"station": station,
"id": id,
"date": date,
Expand Down Expand Up @@ -473,7 +472,7 @@ async def get_connections(
print(f"Connections from Antwerpen-Centraal to Brussel-Centraal: {connections}")

"""
extra_params: Dict[str, Any] = {
extra_params: dict[str, Any] = {
"from": from_station,
"to": to_station,
"date": date,
Expand Down Expand Up @@ -504,7 +503,7 @@ async def get_vehicle(self, id: str, date: str | None = None, alerts: bool = Fal
vehicle_info = await client.get_vehicle("BE.NMBS.IC1832")

"""
extra_params: Dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"}
extra_params: dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"}
vehicle_response_dict = await self._do_request(
"vehicle", {k: v for k, v in extra_params.items() if v is not None}
)
Expand All @@ -527,7 +526,7 @@ async def get_composition(self, id: str, data: str | None = None) -> Composition
composition = await client.get_composition('S51507')

"""
extra_params: Dict[str, str | None] = {"id": id, "data": data}
extra_params: dict[str, str | None] = {"id": id, "data": data}
composition_response_dict = await self._do_request(
"composition", {k: v for k, v in extra_params.items() if v is not None}
)
Expand Down
Loading
Loading