diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..a2088f8 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,27 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python Dev Container", + "image": "mcr.microsoft.com/devcontainers/python:3.12", + // Features to add to the dev container. More info: https://containers.dev/features. + "features": { + "ghcr.io/devcontainers-extra/features/poetry:2": {}, + "ghcr.io/devcontainers-extra/features/ruff:1": {}, + "ghcr.io/devcontainers/features/git:1": {}, + "ghcr.io/devcontainers/features/github-cli:1": {}, + "ghcr.io/devcontainers/features/python:1": {} + }, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "poetry install", + // Configure tool-specific properties. + "customizations": { + "vscode": { + "extensions": [ + "charliermarsh.ruff", + "ms-python.python" + ] + } + } +} \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..ca9df49 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,28 @@ +name: Build Package + +on: + push: + branches: + - main + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.12 + + - name: Install dependencies + run: | + pip install poetry + poetry install + + - name: Build the package + run: poetry build diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..31f6333 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,30 @@ +name: Publish to PyPI + +on: + push: + tags: + - "v*" + +jobs: + deploy: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.12 + + - name: Install Poetry + run: pip install poetry + + - name: Install dependencies + run: poetry install --no-dev + + - name: Build and Publish + env: + POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_API_TOKEN }} + run: poetry publish --build diff --git a/.github/workflows/ruff-lint.yml b/.github/workflows/ruff-lint.yml new file mode 100644 index 0000000..80ce22f --- /dev/null +++ b/.github/workflows/ruff-lint.yml @@ -0,0 +1,28 @@ +name: Lint with Ruff + +on: + push: + branches: + - main + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.12 + + - name: Install dependencies + run: | + pip install poetry + poetry install + + - name: Run Ruff + run: poetry run ruff check . diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..b5d869c --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,28 @@ +name: Run Tests + +on: + push: + branches: + - main + pull_request: + +jobs: + tests: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.12 + + - name: Install dependencies + run: | + pip install poetry + poetry install + + - name: Run tests + run: poetry run pytest diff --git a/.gitignore b/.gitignore index a65d046..0db7850 100644 --- a/.gitignore +++ b/.gitignore @@ -6,9 +6,12 @@ __pycache__/ # C extensions *.so +# Virtual environment +env/ +.venv/ + # Distribution / packaging .Python -env/ build/ develop-eggs/ dist/ @@ -24,6 +27,13 @@ var/ .installed.cfg *.egg +# Pytest cache +.pytest_cache/ + +# IDE files +.vscode/ +.idea/ + # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index 1cd9ac7..0000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,21 +0,0 @@ -image: python:alpine - -stages: - - deploy - -before_script: - - pip install twine - -variables: - TWINE_USERNAME: SECURE - TWINE_PASSWORD: SECURE - -deploy: - stage: deploy - script: - - python setup.py sdist bdist_wheel - - twine upload dist/* - only: - - tags - except: - - branches \ No newline at end of file diff --git a/README.md b/README.md index 17fc576..5a891ca 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,79 @@ # pyRail -A Python wrapper for the iRail API +A Python wrapper for the iRail API, designed to make interacting with iRail simple and efficient. + +## Overview +pyRail is a Python library that provides a convenient interface for interacting with the iRail API. It supports various endpoints such as stations, liveboard, vehicle, connections, and disturbances. The library includes features like caching and rate limiting to optimize API usage. + +## Features +- Retrieve real-time train information, including liveboards and vehicle details. +- Access train station data, connections, and disturbances. +- Supports API endpoints: stations, liveboard, vehicle, connections, and disturbances. +- Caching and conditional GET requests using ETags. +- Rate limiting to handle API request limits efficiently. + +## Installation +To install pyRail, use pip: + +```bash +pip install pyrail +``` + +## Usage +Here is an example of how to use pyRail: + +```python +from pyrail.irail import iRail + +# Create an instance of the iRail class +api = iRail(format='json', lang='en') + +# Make a request to the 'stations' endpoint +stations = api.get_stations() + +# Print the response +print(stations) +``` + +## Configuration +You can configure the format and language for the API requests: + +```python +api = iRail(format='json', lang='en') +``` + +- Supported formats: json, xml, jsonp +- Supported languages: nl, fr, en, de + +## Development +1. Clone the repository: + ```bash + git clone https://github.com/tjorim/pyrail.git + ``` +2. Install dependencies using Poetry: + ```bash + poetry install + ``` +3. Run tests: + ```bash + poetry run pytest + ``` + +## Logging +You can set the logging level at runtime to get detailed logs: + +```python +import logging + +api.set_logging_level(logging.DEBUG) +``` + +## Contributing +Contributions are welcome! Please open an issue or submit a pull request. + +## Contributors +- @tjorim +- @jcoetsie + +## License +This project is licensed under the Apache 2.0 License. See the LICENSE file for details. diff --git a/poetry.lock b/poetry.lock index 6255513..f3ff977 100644 --- a/poetry.lock +++ b/poetry.lock @@ -242,6 +242,17 @@ files = [ {file = "frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817"}, ] +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + [[package]] name = "idna" version = "3.10" @@ -257,9 +268,83 @@ files = [ all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] [[package]] -name = "multidict" -version = "6.1.0" -description = "multidict implementation" +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "packaging" +version = "24.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pytest" +version = "8.3.4" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, + {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.5,<2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-mock" +version = "3.14.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, +] + +[package.dependencies] +pytest = ">=6.2.5" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + +[[package]] +name = "requests" +version = "2.32.3" +description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" files = [ @@ -449,9 +534,36 @@ files = [ ] [[package]] -name = "yarl" -version = "1.18.3" -description = "Yet another URL library" +name = "ruff" +version = "0.8.4" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.8.4-py3-none-linux_armv6l.whl", hash = "sha256:58072f0c06080276804c6a4e21a9045a706584a958e644353603d36ca1eb8a60"}, + {file = "ruff-0.8.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ffb60904651c00a1e0b8df594591770018a0f04587f7deeb3838344fe3adabac"}, + {file = "ruff-0.8.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ddf5d654ac0d44389f6bf05cee4caeefc3132a64b58ea46738111d687352296"}, + {file = "ruff-0.8.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e248b1f0fa2749edd3350a2a342b67b43a2627434c059a063418e3d375cfe643"}, + {file = "ruff-0.8.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf197b98ed86e417412ee3b6c893f44c8864f816451441483253d5ff22c0e81e"}, + {file = "ruff-0.8.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c41319b85faa3aadd4d30cb1cffdd9ac6b89704ff79f7664b853785b48eccdf3"}, + {file = "ruff-0.8.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:9f8402b7c4f96463f135e936d9ab77b65711fcd5d72e5d67597b543bbb43cf3f"}, + {file = "ruff-0.8.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4e56b3baa9c23d324ead112a4fdf20db9a3f8f29eeabff1355114dd96014604"}, + {file = "ruff-0.8.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:736272574e97157f7edbbb43b1d046125fce9e7d8d583d5d65d0c9bf2c15addf"}, + {file = "ruff-0.8.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5fe710ab6061592521f902fca7ebcb9fabd27bc7c57c764298b1c1f15fff720"}, + {file = "ruff-0.8.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:13e9ec6d6b55f6da412d59953d65d66e760d583dd3c1c72bf1f26435b5bfdbae"}, + {file = "ruff-0.8.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:97d9aefef725348ad77d6db98b726cfdb075a40b936c7984088804dfd38268a7"}, + {file = "ruff-0.8.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ab78e33325a6f5374e04c2ab924a3367d69a0da36f8c9cb6b894a62017506111"}, + {file = "ruff-0.8.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:8ef06f66f4a05c3ddbc9121a8b0cecccd92c5bf3dd43b5472ffe40b8ca10f0f8"}, + {file = "ruff-0.8.4-py3-none-win32.whl", hash = "sha256:552fb6d861320958ca5e15f28b20a3d071aa83b93caee33a87b471f99a6c0835"}, + {file = "ruff-0.8.4-py3-none-win_amd64.whl", hash = "sha256:f21a1143776f8656d7f364bd264a9d60f01b7f52243fbe90e7670c0dfe0cf65d"}, + {file = "ruff-0.8.4-py3-none-win_arm64.whl", hash = "sha256:9183dd615d8df50defa8b1d9a074053891ba39025cf5ae88e8bcb52edcc4bf08"}, + {file = "ruff-0.8.4.tar.gz", hash = "sha256:0d5f89f254836799af1615798caa5f80b7f935d7a670fad66c5007928e57ace8"}, +] + +[[package]] +name = "urllib3" +version = "2.3.0" +description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.9" files = [ @@ -547,4 +659,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "40dc940931c0b74f835f2700d4b3f15266e2558c59994795ae14dbc662de6881" +content-hash = "ab53c61dfd9b15ea8f62b1a213f6dfa351bcf25b585abe490908f1d822a5ae6d" diff --git a/pyproject.toml b/pyproject.toml index 233c525..ba61a4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,39 @@ packages = [ python = "^3.12" aiohttp = "^3.11.11" +[tool.poetry.group.test.dependencies] +pytest = "^8.3.4" +pytest-mock = "^3.14.0" + +[tool.poetry.group.dev.dependencies] +ruff = "^0.8.4" + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = [ + "C", # complexity + "D", # docstrings + "E", # pycodestyle + "F", # pyflakes + "G", # flake8-logging-format + "I", # isort + "W", # pycodestyle warnings +] +ignore = [ + "E501", # line too long + "W191", # indentation contains tabs +] + +[tool.ruff.lint.isort] +force-sort-within-sections = true +combine-as-imports = true +split-on-trailing-comma = false + +[tool.ruff.lint.pydocstyle] +property-decorators = ["propcache.cached_property"] + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/pyrail/__init__.py b/pyrail/__init__.py index 49a5cb0..61531d3 100644 --- a/pyrail/__init__.py +++ b/pyrail/__init__.py @@ -1 +1,4 @@ -from .irail import iRail \ No newline at end of file +"""Package initialization for pyrail.""" +from .irail import iRail + +__all__ = ["iRail"] diff --git a/pyrail/irail.py b/pyrail/irail.py index be46afb..8837c42 100644 --- a/pyrail/irail.py +++ b/pyrail/irail.py @@ -1,81 +1,263 @@ -from .api_methods import methods +"""Module providing the iRail class for interacting with the iRail API.""" + +import logging +from threading import Lock +import time +from typing import Any, Dict, Optional + import aiohttp import asyncio -base_url = 'https://api.irail.be/{}/' +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +base_url: str = "https://api.irail.be/{}/" headers = {'user-agent': 'pyRail (tielemans.jorim@gmail.com)'} class iRail: + """A Python wrapper for the iRail API, handling rate limiting and endpoint requests. + + Attributes: + format (str): The data format for API responses ('json', 'xml', 'jsonp'). + lang (str): The language for API responses ('nl', 'fr', 'en', 'de'). + + Endpoints: + stations: Retrieve all stations. + liveboard: Retrieve a liveboard for a station or ID. + connections: Retrieve connections between two stations. + vehicle: Retrieve information about a specific train. + composition: Retrieve the composition of a train. + disturbances: Retrieve information about current disturbances on the rail network. + + """ + + # Allowed endpoints and their expected parameters + endpoints: Dict[str, Dict[str, Any]] = { + "stations": {}, + "liveboard": {"xor": ["station", "id"], "optional": ["date", "time", "arrdep", "alerts"]}, + "connections": { + "required": ["from", "to"], + "optional": ["date", "time", "timesel", "typeOfTransport", "alerts", "results"], + }, + "vehicle": {"required": ["id"], "optional": ["date", "alerts"]}, + "composition": {"required": ["id"], "optional": ["data"]}, + "disturbances": {"optional": ["lineBreakCharacter"]}, + } + + def __init__(self, format: str = "json", lang: str = "en") -> None: + """Initialize the iRail API client. - def __init__(self, format=None, lang=None): - if format is None: - format = 'json' - self.format = format - if lang is None: - lang = 'en' - self.lang = lang + Args: + format (str): The format of the API responses. Default is 'json'. + lang (str): The language for API responses. Default is 'en'. + + """ + self.format: str = format + self.lang: str = lang + self.tokens: int = 3 + self.burst_tokens: int = 5 + self.last_request_time: float = time.time() + self.lock: Lock = Lock() + session.headers.update({"User-Agent": "pyRail (https://github.com/tjorim/pyrail; tielemans.jorim@gmail.com)"}) + self.etag_cache: Dict[str, str] = {} + logger.info("iRail instance created") @property - def format(self): + def format(self) -> str: return self.__format @format.setter - def format(self, value): - if value in ['xml', 'json', 'jsonp']: + def format(self, value: str) -> None: + if value in ["xml", "json", "jsonp"]: self.__format = value else: - self.__format = 'json' + self.__format = "json" @property - def lang(self): + def lang(self) -> str: return self.__lang @lang.setter - def lang(self, value): - if value in ['nl', 'fr', 'en', 'de']: + def lang(self, value: str) -> None: + if value in ["nl", "fr", "en", "de"]: self.__lang = value else: - self.__lang = 'en' - - async def do_request(self, method, args=None): - if method in methods: - url = base_url.format(method) - params = {'format': self.format, 'lang': self.lang} - if args: - params.update(args) - - async with aiohttp.ClientSession(headers=headers) as session: - try: - async with session.get(url, params=params) as response: - if response.status == 200: - try: - return await response.json() - except aiohttp.ContentTypeError: - return -1 - else: - print(f"HTTP error: {response.status}") - return -1 - except aiohttp.ClientError as e: - print(f"Request failed: {e}") - return -1 + self.__lang = "en" + + def _refill_tokens(self) -> None: + """Refill rate limit tokens based on elapsed time.""" + logger.debug("Refilling tokens") + current_time: float = time.time() + elapsed: float = current_time - self.last_request_time + self.last_request_time = current_time + + # Refill tokens, 3 tokens per second, cap tokens at 3 + self.tokens = min(3, self.tokens + int(elapsed * 3)) + # Refill burst tokens, 3 burst tokens per second, cap burst tokens at 5 + self.burst_tokens = min(5, self.burst_tokens + int(elapsed * 3)) + + def _handle_rate_limit(self) -> None: + """Handle rate limiting by refilling tokens or waiting.""" + logger.debug("Handling rate limit") + self._refill_tokens() + if self.tokens < 1: + if self.burst_tokens >= 1: + self.burst_tokens -= 1 + else: + logger.warning("Rate limiting active, waiting for tokens") + time.sleep(1 - (time.time() - self.last_request_time)) + self._refill_tokens() + self.tokens -= 1 + else: + self.tokens -= 1 + + def _add_etag_header(self, method: str) -> Dict[str, str]: + """Add ETag header if a cached ETag exists.""" + headers: Dict[str, str] = {} + if method in self.etag_cache: + logger.debug("Adding If-None-Match header with value: %s", self.etag_cache[method]) + headers["If-None-Match"] = self.etag_cache[method] + return headers + + def validate_params(self, method: str, params: Optional[Dict[str, Any]] = None) -> bool: + """Validate parameters and XOR conditions for a given endpoint.""" + if method not in self.endpoints: + logger.error("Unknown API endpoint: %s", method) + return False - + endpoint = self.endpoints[method] + required = endpoint.get("required", []) + xor = endpoint.get("xor", []) + optional = endpoint.get("optional", []) - async def get_stations(self): + params = params or {} + + # Ensure all required parameters are present + for param in required: + if param not in params or params[param] is None: + logger.error("Missing required parameter: %s for endpoint: %s", param, method) + return False + + # Check XOR logic (only one of XOR parameters can be set) + if xor: + xor_values = [params.get(param) is not None for param in xor] + if sum(xor_values) != 1: + logger.error("Exactly one of the XOR parameters %s must be provided for endpoint: %s", xor, method) + return False + + # Ensure no unexpected parameters are included + all_params = required + xor + optional + for param in params.keys(): + if param not in all_params: + logger.error("Unexpected parameter: %s for endpoint: %s", param, method) + return False + + return True + + async def do_request(self, method: str, args: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: + """Send a request to the specified iRail API endpoint.""" + logger.info("Starting request to endpoint: %s", method) + if not self.validate_params(method, args or {}): + logger.error("Validation failed for method: %s with args: %s", method, args) + return None + + with self.lock: + self._handle_rate_limit() + + # Construct the request URL and parameters + url: str = base_url.format(method) + params = {"format": self.format, "lang": self.lang} + if args: + params.update(args) + + request_headers: Dict[str, str] = self._add_etag_header(method) + + async with aiohttp.ClientSession(headers=request_headers) as session: + try: + async with session.get(url, params=params) as response: + if response.status == 429: + retry_after: int = int(response.headers.get("Retry-After", 1)) + logger.warning("Rate limited, retrying after %d seconds", retry_after) + time.sleep(retry_after) + return self.do_request(method, args) + if response.status == 200: + # Cache the ETag from the response + if "Etag" in response.headers: + self.etag_cache[method] = response.headers["Etag"] + try: + return await response.json() + except aiohttp.ContentTypeError: + logger.error("Failed to parse JSON response") + return None + elif response.status == 304: + logger.info("Data not modified, using cached data") + return None + else: + logger.error("Request failed with status code: %s, response: %s", response.status, response.text) + return None + except aiohttp.ClientError as e: + logger.error("Request failed due to an exception: %s", e) + return None + + async def get_stations(self) -> Optional[Dict[str, Any]]: """Retrieve a list of all stations.""" - return await self.do_request('stations') - - async def get_liveboard(self, station=None, id=None): - if bool(station) ^ bool(id): - extra_params = {'station': station, 'id': id} - return await self.do_request('liveboard', extra_params) - - async def get_connections(self, from_station=None, to_station=None): - if from_station and to_station: - extra_params = {'from': from_station, 'to': to_station} - return await self.do_request('connections', extra_params) - - async def get_vehicle(self, id=None): - if id: - extra_params = {'id': id} - return await self.do_request('vehicle', extra_params) + return await self.do_request("stations") + + async def get_liveboard( + self, + station: Optional[str] = None, + id: Optional[str] = None, + date: Optional[str] = None, + time: Optional[str] = None, + arrdep: str = "departure", + alerts: bool = False, + ) -> Optional[Dict[str, Any]]: + """Retrieve a liveboard for a station or station ID.""" + extra_params = { + "station": station, + "id": id, + "date": date, + "time": time, + "arrdep": arrdep, + "alerts": "true" if alerts else "false", + } + return await self.do_request("liveboard", {k: v for k, v in extra_params.items() if v is not None}) + + async def get_connections( + self, + from_station: str, + to_station: str, + date: Optional[str] = None, + time: Optional[str] = None, + timesel: str = "departure", + type_of_transport: str = "automatic", + alerts: bool = False, + results: Optional[int] = None, + ) -> Optional[Dict[str, Any]]: + """Retrieve connections between two stations.""" + extra_params = { + "from": from_station, + "to": to_station, + "date": date, + "time": time, + "timesel": timesel, + "typeOfTransport": type_of_transport, + "alerts": "true" if alerts else "false", + "results": results, + } + return await self.do_request("connections", {k: v for k, v in extra_params.items() if v is not None}) + + async def get_vehicle(self, id: str, date: Optional[str] = None, alerts: bool = False) -> Optional[Dict[str, Any]]: + """Retrieve information about a vehicle (train).""" + extra_params = {"id": id, "date": date, "alerts": "true" if alerts else "false"} + return await self.do_request("vehicle", {k: v for k, v in extra_params.items() if v is not None}) + + async def get_composition(self, id: str, data: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Retrieve the composition of a train.""" + extra_params = {"id": id, "data": data} + return await self.do_request("composition", {k: v for k, v in extra_params.items() if v is not None}) + + async def get_disturbances(self, line_break_character: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Retrieve information about current disturbances on the rail network.""" + extra_params = {"lineBreakCharacter": line_break_character} + return await self.do_request("disturbances", {k: v for k, v in extra_params.items() if v is not None}) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_irail.py b/tests/test_irail.py index bab1a76..bbc93ca 100644 --- a/tests/test_irail.py +++ b/tests/test_irail.py @@ -1,7 +1,41 @@ -from pyrail import iRail +from unittest.mock import MagicMock, patch -def test_irail_station(): +from pyrail.irail import iRail - irail_instance = iRail() - response = irail_instance.get_stations() - print(response) +""" +Unit tests for the iRail API wrapper. +""" + +@patch('requests.Session.get') +def test_successful_request(mock_get): + """Test a successful API request by mocking the iRail response.""" + # Mock the response to simulate a successful request + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'data': 'some_data'} + mock_get.return_value = mock_response + + api = iRail() + response = api.do_request('stations') + + # Check that the request was successful + assert mock_get.call_count == 1, "Expected one call to the requests.Session.get method" + assert response == {'data': 'some_data'}, "Expected response data to match the mocked response" + +def test_get_stations(): + api = iRail() + stations = api.get_stations() + + # Ensure the response is not None + assert stations is not None, "The response should not be None" + + # Validate that the response is a dictionary + assert isinstance(stations, dict), "Expected response to be a dictionary" + + # Validate the presence of key fields + assert 'station' in stations, "Expected the response to contain a 'station' key" + + # Validate the structure of station data + station_list = stations.get('station', []) + assert isinstance(station_list, list), "Expected 'station' to be a list" + assert len(station_list) > 0, "Expected at least one station in the response"