diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..0d064c9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,67 @@ +name: Bug 反馈 +title: "Bug: 出现异常" +description: 提交 Bug 反馈以帮助我们改进代码 +labels: ["bug"] +body: + - type: dropdown + id: env-os + attributes: + label: 操作系统 + description: 选择运行 aioarxiv 的系统 + options: + - Windows + - MacOS + - Linux + - Other + validations: + required: true + + - type: input + id: env-python-ver + attributes: + label: Python 版本 + description: 填写运行 aioarxiv 的 Python 版本 + placeholder: e.g. 3.11.0 + validations: + required: true + + - type: input + id: env-nb-ver + attributes: + label: aioarxiv 版本 + description: 填写 aioarxiv 版本 + placeholder: e.g. 0.1.0 + validations: + required: true + + - type: textarea + id: describe + attributes: + label: 描述问题 + description: 清晰简洁地说明问题是什么 + validations: + required: true + + - type: textarea + id: reproduction + attributes: + label: 复现步骤 + description: 提供能复现此问题的详细操作步骤 + placeholder: | + 1. 首先…… + 2. 然后…… + 3. 发生…… + validations: + required: true + + - type: textarea + id: expected + attributes: + label: 期望的结果 + description: 清晰简洁地描述你期望发生的事情 + + - type: textarea + id: logs + attributes: + label: 截图或日志 + description: 提供有助于诊断问题的任何日志和截图 diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..ec4bb38 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/document.yml b/.github/ISSUE_TEMPLATE/document.yml new file mode 100644 index 0000000..352a13b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/document.yml @@ -0,0 +1,18 @@ +name: 文档改进 +title: "Docs: 描述" +description: 文档错误及改进意见反馈 +labels: ["documentation"] +body: + - type: textarea + id: problem + attributes: + label: 描述问题或主题 + validations: + required: true + + - type: textarea + id: improve + attributes: + label: 需做出的修改 + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..4f2e79f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,20 @@ +name: 功能建议 +title: "Feature: 功能描述" +description: 提出关于项目新功能的想法 +labels: ["enhancement"] +body: + - type: textarea + id: problem + attributes: + label: 希望能解决的问题 + description: 在使用中遇到什么问题而需要新的功能? 
+ validations: + required: true + + - type: textarea + id: feature + attributes: + label: 描述所需要的功能 + description: 请说明需要的功能或解决方法 + validations: + required: true diff --git a/.github/actions/build-api-doc/action.yml b/.github/actions/build-api-doc/action.yml new file mode 100644 index 0000000..efe4256 --- /dev/null +++ b/.github/actions/build-api-doc/action.yml @@ -0,0 +1,36 @@ +name: Build documentation +description: Build documentation + +on: + push: + branches: ["main"] + workflow_dispatch: + +env: + INSTANCE: 'Writerside/hi' + ARTIFACT: 'webHelpHI2-all.zip' + DOCKER_VERSION: '243.21565' + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Build docs using Writerside Docker builder + uses: JetBrains/writerside-github-action@v4 + with: + instance: ${{ env.INSTANCE }} + artifact: ${{ env.ARTIFACT }} + docker-version: ${{ env.DOCKER_VERSION }} + + - name: Save artifact with build results + uses: actions/upload-artifact@v4 + with: + name: docs + path: | + artifacts/${{ env.ARTIFACT }} + retention-days: 7 \ No newline at end of file diff --git a/.github/actions/setup-node/action.yml b/.github/actions/setup-node/action.yml new file mode 100644 index 0000000..24ad58a --- /dev/null +++ b/.github/actions/setup-node/action.yml @@ -0,0 +1,13 @@ +name: Setup Node +description: Setup Node + +runs: + using: "composite" + steps: + - uses: actions/setup-node@v4 + with: + node-version: "18" + cache: "yarn" + + - run: yarn install --frozen-lockfile + shell: bash diff --git a/.github/actions/setup-python/action.yml b/.github/actions/setup-python/action.yml new file mode 100644 index 0000000..e191b53 --- /dev/null +++ b/.github/actions/setup-python/action.yml @@ -0,0 +1,37 @@ +name: Setup Python +description: Setup Python + +inputs: + python-version: + description: Python version + required: false + default: "3.10" + env-dir: + description: Environment directory + required: false + default: "." 
+  no-root:
+    description: Do not install package in the environment
+    required: false
+    default: "false"
+
+runs:
+  using: "composite"
+  steps:
+    - name: Install pdm
+      run: pipx install pdm
+      shell: bash
+
+    - uses: actions/setup-python@v5
+      with:
+        python-version: ${{ inputs.python-version }}
+        cache: "pdm"
+        cache-dependency-path: |
+          ./pdm.lock
+          ${{ inputs.env-dir }}/pdm.lock
+
+    - run: |
+        cd ${{ inputs.env-dir }}
+        # Install all dependency groups; skip installing the project itself when no-root is true.
+        if [ "${{ inputs.no-root }}" == "true" ]; then
+          pdm install -G:all --no-self
+        else
+          pdm install -G:all
+        fi
+      shell: bash
diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml
new file mode 100644
index 0000000..5539984
--- /dev/null
+++ b/.github/workflows/codecov.yml
@@ -0,0 +1,54 @@
+name: Code Coverage
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    paths:
+      - "envs/**"
+      - "src/**"
+      - "tests/**"
+      - ".github/workflows/codecov.yml"
+      - "pyproject.toml"
+      - "pdm.lock"
+
+jobs:
+  test:
+    name: Test Coverage
+    runs-on: ${{ matrix.os }}
+    concurrency:
+      group: test-coverage-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.env }}
+      cancel-in-progress: true
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        os: [ubuntu-latest, windows-latest, macos-latest]
+    env:
+      OS: ${{ matrix.os }}
+      PYTHON_VERSION: ${{ matrix.python-version }}
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python environment
+        uses: ./.github/actions/setup-python
+        with:
+          python-version: ${{ matrix.python-version }}
+          env-dir: ./envs/${{ matrix.env }}
+          no-root: true
+
+      - name: Run Pytest
+        run: |
+          cd ./envs/${{ matrix.env }}
+          pdm run bash "../../scripts/run-tests.sh"
+
+      - name: Upload coverage report
+        uses: codecov/codecov-action@v4
+        with:
+          env_vars: OS,PYTHON_VERSION
+          files: ./tests/coverage.xml
+          flags: unittests
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml
new file mode 100644
index 0000000..a0dafd5
--- /dev/null
+++ b/.github/workflows/pyright.yml
@@ -0,0 +1,35 @@
+name: Pyright Lint
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    paths:
+      - "envs/**"
+      - "src/**"
+      - "tests/**"
+      - ".github/actions/setup-python/**"
+      - ".github/workflows/pyright.yml"
+      - "pyproject.toml"
+      - "pdm.lock"
+
+jobs:
+  pyright:
+    name: Pyright Lint
+    runs-on: ubuntu-latest
+    concurrency:
+      group: pyright-${{ github.ref }}-${{ matrix.env }}
+      cancel-in-progress: true
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python environment
+        uses: ./.github/actions/setup-python
+        with:
+          env-dir: ./envs/${{ matrix.env }}
+          no-root: true
+
+      - name: Run Pyright
+        uses: jakebailey/pyright-action@v2
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..1c305df
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,29 @@
+name: Ruff Lint
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    paths:
+      - "envs/**"
+      - "src/**"
+      - "tests/**"
+      - ".github/actions/setup-python/**"
+      - ".github/workflows/ruff.yml"
+      - "pyproject.toml"
+      - "pdm.lock"
+
+jobs:
+  ruff:
+    name: Ruff Lint
+    runs-on: ubuntu-latest
+    concurrency:
+      group: ruff-${{ github.ref }}
+      cancel-in-progress: true
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Run Ruff Lint
+        uses: chartboost/ruff-action@v1
diff --git a/.github/workflows/website-deploy.yml b/.github/workflows/website-deploy.yml
new file mode 100644
index 0000000..6faf10e
--- /dev/null
+++ b/.github/workflows/website-deploy.yml
@@ -0,0 +1,46 @@
+#name: Site Deploy
+#
+#on:
+#  push:
branches: +# - master +# +#jobs: +# publish: +# runs-on: ubuntu-latest +# concurrency: +# group: website-deploy-${{ github.ref }} +# cancel-in-progress: true +# +# steps: +# - uses: actions/checkout@v4 +# with: +# fetch-depth: 0 +# +# - name: Setup Python Environment +# uses: ./.github/actions/setup-python +# +# - name: Setup Node Environment +# uses: ./.github/actions/setup-node +# +# - name: Build API Doc +# uses: ./.github/actions/build-api-doc +# +# - name: Build Doc +# run: yarn build +# +# - name: Get Branch Name +# run: echo "BRANCH_NAME=$(echo ${GITHUB_REF#refs/heads/})" >> $GITHUB_ENV +# +# - name: Deploy to Netlify +# uses: nwtgck/actions-netlify@v3 +# with: +# publish-dir: "./website/build" +# production-deploy: true +# github-token: ${{ secrets.GITHUB_TOKEN }} +# deploy-message: "Deploy ${{ env.BRANCH_NAME }}@${{ github.sha }}" +# enable-commit-comment: false +# alias: ${{ env.BRANCH_NAME }} +# env: +# NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} +# NETLIFY_SITE_ID: ${{ secrets.SITE_ID }} diff --git a/.github/workflows/website-preview.yml b/.github/workflows/website-preview.yml new file mode 100644 index 0000000..3657b8b --- /dev/null +++ b/.github/workflows/website-preview.yml @@ -0,0 +1,46 @@ +#name: Site Deploy(Preview) +# +#on: +# pull_request_target: +# +#jobs: +# preview: +# runs-on: ubuntu-latest +# concurrency: +# group: pull-request-preview-${{ github.event.number }} +# cancel-in-progress: true +# +# steps: +# - uses: actions/checkout@v4 +# with: +# ref: ${{ github.event.pull_request.head.sha }} +# fetch-depth: 0 +# +# - name: Setup Python Environment +# uses: ./.github/actions/setup-python +# +# - name: Setup Node Environment +# uses: ./.github/actions/setup-node +# +# - name: Build API Doc +# uses: ./.github/actions/build-api-doc +# +# - name: Build Doc +# run: yarn build +# +# - name: Get Deploy Name +# run: | +# echo "DEPLOY_NAME=deploy-preview-${{ github.event.number }}" >> $GITHUB_ENV +# +# - name: Deploy to Netlify +# uses: nwtgck/actions-netlify@v3 +# with: +# publish-dir: "./website/build" +# production-deploy: false +# github-token: ${{ secrets.GITHUB_TOKEN }} +# deploy-message: "Deploy ${{ env.DEPLOY_NAME }}@${{ github.sha }}" +# enable-commit-comment: false +# alias: ${{ env.DEPLOY_NAME }} +# env: +# NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} +# NETLIFY_SITE_ID: ${{ secrets.SITE_ID }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..afd8650 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ +test.py + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm-project.org/#use-with-ide +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ + +# Visual Studio Code +.vscode/ \ No newline at end of file diff --git a/.markdownlint.yaml b/.markdownlint.yaml new file mode 100644 index 0000000..c1b97a3 --- /dev/null +++ b/.markdownlint.yaml @@ -0,0 +1,13 @@ +MD013: false +MD024: + siblings_only: true +MD033: false + +code-block: + ignore: true + +markdown-code-block: + format: false + +fenced-code-blocks: + validate: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..262c6a4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +ci: + autofix_commit_msg: ":rotating_light: auto fix by pre-commit hooks" + autofix_prs: true + autoupdate_branch: main + autoupdate_schedule: quarterly + autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks" +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.2 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + stages: [pre-commit] + - id: ruff-format + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + stages: [pre-commit] + + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + stages: [pre-commit] + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + types_or: [javascript, jsx, ts, tsx, markdown, yaml, json] + stages: [pre-commit] diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..530fe91 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.9 \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..bab4669 --- /dev/null +++ b/README.md @@ -0,0 +1,73 @@ +# Aioarxiv + +An async Python client for the arXiv API with enhanced performance and flexible configuration options. + +> ⚠️ Warning: This project is currently in beta. Not recommended for production use. 
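+
+Because aioarxiv is still in beta, its public API may change between 0.x releases. If you depend on it today, you may want to pin an exact version (a minimal example using the current 0.1.0 release; adjust the version as needed):
+
+```bash
+pip install aioarxiv==0.1.0
+```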
+ +## Features + +- Asynchronous API calls for better performance +- Flexible search and download capabilities +- Customizable rate limiting and concurrent requests +- Simple error handling + +## Installation + +```bash +pip install aioarxiv +``` + +## Quick Start + +```python +import asyncio +from aioarxiv import ArxivClient + +async def main(): + async with ArxivClient() as client: + async for paper in client.search("quantum computing", max_results=1): + print(f"Title: {paper.title}") + print(f"Authors: {', '.join(a.name for a in paper.authors)}") + print(f"Summary: {paper.summary[:200]}...") + + # Download PDF + file_path = await client.download_paper(paper) + print(f"Downloaded to: {file_path}") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Configuration + +```python +from aioarxiv import ArxivConfig, ArxivClient + +config = ArxivConfig( + rate_limit_calls=3, # Rate limit per window + rate_limit_period=1.0, # Window period in seconds + max_concurrent_requests=3 # Max concurrent requests +) + +client = ArxivClient(config=config) +``` + +## Error Handling + +```python +try: + async for paper in client.search("quantum computing"): + print(paper.title) +except SearchCompleteException: + print("Search complete") +``` + +## Requirements +* Python 3.9 or higher + +## License +[MIT License (c) 2024 BalconyJH ](LICENSE) + +## Links +* Documentation for aioarxiv is WIP +* [ArXiv API](https://info.arxiv.org/help/api/index.html) \ No newline at end of file diff --git a/Writerside/a.tree b/Writerside/a.tree new file mode 100644 index 0000000..1458c58 --- /dev/null +++ b/Writerside/a.tree @@ -0,0 +1,12 @@ + + + + + + + + + \ No newline at end of file diff --git a/Writerside/c.list b/Writerside/c.list new file mode 100644 index 0000000..fadecfa --- /dev/null +++ b/Writerside/c.list @@ -0,0 +1,8 @@ + + + + + + + \ No newline at end of file diff --git a/Writerside/cfg/buildprofiles.xml b/Writerside/cfg/buildprofiles.xml new file mode 100644 index 0000000..e0663a7 --- /dev/null +++ b/Writerside/cfg/buildprofiles.xml @@ -0,0 +1,12 @@ + + + + + + + true + + + + diff --git a/Writerside/cfg/glossary.xml b/Writerside/cfg/glossary.xml new file mode 100644 index 0000000..bb59ee0 --- /dev/null +++ b/Writerside/cfg/glossary.xml @@ -0,0 +1,43 @@ + + + + + 异步arXiv API客户端的主类,用于执行论文搜索和获取操作。 + + + + 表示单篇arXiv论文的数据模型,包含标题、作者、摘要等信息。 + + + + 搜索参数配置类,用于定义查询条件和结果限制。 + + + + 客户端配置类,用于设置API访问参数如并发限制、请求频率等。 + + + + 批量处理上下文类,用于跟踪批量获取论文的状态和进度。 + + + + 论文分类信息模型,包含主分类和次分类。 + + + + 会话管理器,负责维护HTTP会话和实现访问频率限制。 + + + + XML解析器,用于解析arXiv API返回的Atom feed格式数据。 + + + + 搜索完成异常,用于标识搜索达到结果限制或无结果时的情况。 + + + + 解析异常,在XML数据解析失败时抛出。 + + \ No newline at end of file diff --git a/Writerside/topics/Installation-guide.md b/Writerside/topics/Installation-guide.md new file mode 100644 index 0000000..880a156 --- /dev/null +++ b/Writerside/topics/Installation-guide.md @@ -0,0 +1,108 @@ +# Installation Guide + +## Introduction + +本安装指南旨在帮助用户成功安装并配置 **aioarxiv** 或作为项目依赖。 + + +警告:该项目目前仍处于测试阶段,尚未进入稳定版本。请勿用于生产环境。 + + +--- + +## Installation types + +我们强烈建议使用如 [PDM](https://pdm-project.org/) , [Poetry](https://python-poetry.org/) +等包管理器管理项目依赖。 + +| **Type** | **Description** | **More information** | +|-----------|-----------------------------|--------------------------------| +| PDM 安装 | 添加 `aioarxiv` 到项目依赖中。 | [跳转到安装步骤](#installation-steps) | +| Poetry 安装 | 添加 `aioarxiv` 到项目依赖中。 | [跳转到安装步骤](#installation-steps) | +| PyPI 安装 | 使用 `pip` 在支持的 Python 版本中安装。 | [跳转到安装步骤](#installation-steps) | +| 源代码安装 | 从 GitHub 克隆代码并手动安装。 | 
[跳转到安装步骤](#installation-steps) |
+
+---
+
+## Overview
+
+### 版本信息
+
+以下是此库的可用版本。我们暂不推荐在生产环境中依赖该库:
+
+| **Version** | **Build**    | **Release Date** | **Status** | **Notes**             |
+|-------------|--------------|------------------|------------|-----------------------|
+| 0.1.0       | PyPI Release | 15/11/2024       | Latest     | Initial testing phase |
+
+## System requirements
+
+### 支持的操作系统
+
+- **Windows**: Windows 10 或更高版本。
+- **MacOS**: macOS 10.15 或更高版本。
+- **Linux**: Ubuntu 18.04 或更高版本。
+
+### 支持的 Python 版本
+
+- Python 3.9 或更高版本。
+
+## Before you begin
+
+在安装此库之前,请确保以下事项:
+
+- **已安装 Python**:
+    - 确保您的系统中已安装支持的 Python 版本。
+    - 使用 `python --version` 确认。
+    - 如果未安装 Python,请前往 [Python 官网](https://www.python.org/downloads/release/python-3920/) 下载并安装。
+- **已安装 pip**:
+    - 确保您的系统中已包含 `pip`,并建议更新。
+    - 使用 `pip --version` 确认版本号,并使用 `pip install --upgrade pip` 更新。
+- **已安装 Git**:
+    - 如果您选择从源代码安装,请确保已安装 Git。
+    - 使用 `git --version` 确认版本号。
+- **已安装包管理器**:
+    - 如果您选择使用 PDM 或 Poetry,请确保已正确安装。
+    - 检查命令:
+      ```bash
+      pdm --version     # 检查 PDM
+      poetry --version  # 检查 Poetry
+      ```
+
+## Installation steps
+
+### PDM 安装
+
+1. 打开命令行工具。
+2. 执行以下命令以安装库:
+   ```bash
+   pdm add aioarxiv
+   ```
+
+### Poetry 安装
+
+1. 打开命令行工具。
+2. 执行以下命令以安装库:
+   ```bash
+   poetry add aioarxiv
+   ```
+
+### 使用 pip 从 PyPI 安装
+
+1. 打开命令行工具。
+2. 执行以下命令以安装库:
+   ```bash
+   pip install aioarxiv
+   ```
+
+### 源代码安装
+
+1. 打开命令行工具。
+2. 执行以下命令以克隆代码:
+   ```bash
+   git clone https://github.com/BalconyJH/aioarxiv.git
+   cd aioarxiv
+   pip install .
+   ```
\ No newline at end of file
diff --git a/Writerside/topics/starter.md b/Writerside/topics/starter.md
new file mode 100644
index 0000000..4924928
--- /dev/null
+++ b/Writerside/topics/starter.md
@@ -0,0 +1,99 @@
+# Starter
+
+## 安装
+
+警告:该项目目前仍处于测试阶段,尚未进入稳定版本。请勿用于生产环境。
+
+使用 `pip` 安装: `pip install aioarxiv`
+
+## 基本概念
+
+`aioarxiv` 是一个专为 arXiv API 构建的异步 Python 客户端,旨在简化高效的论文检索和下载操作。其主要功能包括:
+- 异步 API 调用,提高高频请求的性能。
+- 灵活的查询和下载功能。
+- 自定义配置以满足特定需求。
+
+## 主要类
+
+* `ArxivClient`: 主要客户端类,用于执行搜索
+* `Paper`: 论文数据模型
+* `SearchParams`: 搜索参数配置模型
+
+## 使用示例
+
+```python
+import asyncio
+
+from aioarxiv.client.arxiv_client import ArxivClient
+from aioarxiv.config import ArxivConfig
+
+async def search():
+    config = ArxivConfig(
+        log_level="DEBUG",
+    )
+    client = ArxivClient(config=config)
+
+    query = "quantum computing"
+    count = 0
+
+    print(f"Searching for: {query}")
+    print("-" * 50)
+
+    async for paper in client.search(query=query, max_results=1):
+        count += 1
+        print(f"\nPaper {count}:")
+        print(f"Title: {paper.title}")
+        print(f"Authors: {paper.authors}")
+        print(f"Summary: {paper.summary[:200]}...")
+        print("-" * 50)
+        file_path = await client.download_paper(paper)
+        print(f"Downloaded to: {file_path}")
+
+if __name__ == "__main__":
+    asyncio.run(search())
+```
+
+## 自定义配置
+
+```python
+from aioarxiv.client.arxiv_client import ArxivClient
+from aioarxiv.config import ArxivConfig
+
+config = ArxivConfig(
+    rate_limit_calls=3,         # 请求频率限制
+    max_concurrent_requests=3,  # 最大并发数
+)
+client = ArxivClient(config=config)
+```
+
+## 错误处理
+
+`ArxivClient.search()` 方法返回一个异步生成器,依赖 `SearchCompleteException` 异常来判断搜索是否完成。
+
+ +因此你应该使用以下结构来处理搜索: + + +from aioarxiv import ArxivClient, SearchCompleteException + +async def search_and_handle_error(): + try: + query = "quantum computing" + async for paper in client.search(query=query, max_results=1): + print(paper) + except SearchCompleteException: + print("Search complete.") + + +## 反馈 +功能请求和错误报告请提交到 [GitHub Issue](https://github.com/BalconyJH/aioarxiv/issues/new)。 + + + + arXiv API + + + aioarxiv + + diff --git a/Writerside/v.list b/Writerside/v.list new file mode 100644 index 0000000..2d12cb3 --- /dev/null +++ b/Writerside/v.list @@ -0,0 +1,5 @@ + + + + + diff --git a/Writerside/writerside.cfg b/Writerside/writerside.cfg new file mode 100644 index 0000000..d84ff0c --- /dev/null +++ b/Writerside/writerside.cfg @@ -0,0 +1,11 @@ + + + + + + + + + + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d8f87e5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,108 @@ +[project] +name = "aioarxiv" +version = "0.1.0" +description = "arxiv Parse library" +authors = [ + { name = "BalconyJH", email = "balconyjh@gmail.com" }, +] +dependencies = [ + "feedparser~=6.0", + "loguru~=0.7", + "pydantic>=2.9.2", + "aiohttp>=3.10.10", + "tenacity>=9.0.0", + "aiofiles>=24.1.0", +] +requires-python = ">=3.9" +readme = "README.md" +license = { text = "MIT" } + +[tool.pdm] +distribution = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" +addopts = "--cov-report=term-missing --ignore=.venv/" +asyncio_default_fixture_loop_scope = "session" + +[tool.black] +line-length = 88 +target-version = ["py39", "py310", "py311", "py312"] +include = '\.pyi?$' +extend-exclude = ''' +''' + +[tool.isort] +profile = "black" +line_length = 88 +length_sort = true +skip_gitignore = true +force_sort_within_sections = true +src_paths = ["src", "tests"] +extra_standard_library = ["typing_extensions"] + +[tool.ruff] +line-length = 88 +target-version = "py39" + +[tool.ruff.lint] +select = [ + "F", # Pyflakes + "W", # pycodestyle warnings + "E", # pycodestyle errors + "UP", # pyupgrade + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "T10", # flake8-debugger + "T20", # flake8-print + "PYI", # flake8-pyi + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RUF", # Ruff-specific rules + "I", # isort +] +ignore = [ + "E402", # module-import-not-at-top-of-file + "E501", # line-too-long + "UP037", # quoted-annotation + "RUF001", # ambiguous-unicode-character-string + "RUF002", # ambiguous-unicode-character-docstring + "RUF003", # ambiguous-unicode-character-comment + "ASYNC230", # blocking-open-call-in-async-function +] + +[tool.ruff.lint.flake8-pytest-style] +fixture-parentheses = false +mark-parentheses = false + +[tool.pyright] +pythonVersion = "3.9" +pythonPlatform = "All" +executionEnvironments = [ + { root = "./tests", extraPaths = [ + "./", + ] }, + { root = "./src" }, +] +typeCheckingMode = "standard" +disableBytesTypePromotions = true + +[dependency-groups] +test = [ + "pytest-cov~=5.0", + "pytest-xdist~=3.6", + "pytest-asyncio~=0.23", + "pytest-mock>=3.14.0", +] +dev = [ + "black~=24.4", + "ruff~=0.4", + "isort~=5.13", + "pre-commit~=4.0", + "bump-my-version>=0.28.1", +] + +[tool.coverage.run] +omit = [ + ".venv/*", +] \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/client/__init__.py b/src/client/__init__.py new file mode 100644 index 0000000..3935534 --- /dev/null +++ b/src/client/__init__.py @@ -0,0 +1,348 @@ +# import sys +# from collections.abc 
import AsyncGenerator +# from dataclasses import dataclass +# from types import TracebackType +# from typing import Optional +# +# from aiohttp import ClientError, ClientResponse +# from anyio import create_memory_object_stream, create_task_group, sleep +# from tenacity import ( +# retry, +# retry_if_exception_type, +# stop_after_attempt, +# wait_exponential, +# ) +# +# from ..config import ArxivConfig, default_config +# from ..exception import ( +# HTTPException, +# ParseErrorContext, +# ParserException, +# QueryBuildError, +# QueryContext, SearchCompleteException, +# ) +# from ..models import Paper, SearchParams +# from ..utils.log import logger +# from ..utils.parser import ArxivParser +# from ..utils.session import SessionManager +# +# +# class ArxivClient: +# """ +# arXiv API异步客户端 +# +# 用于执行arXiv API搜索请求。 +# +# Attributes: +# _config: arXiv API配置 +# _session_manager: 会话管理器 +# """ +# def __init__( +# self, +# config: Optional[ArxivConfig] = None, +# session_manager: Optional[SessionManager] = None, +# ): +# self._config = config or default_config +# self._session_manager = session_manager or SessionManager() +# +# async def search( +# self, +# query: str, +# max_results: Optional[int] = None, +# ) -> AsyncGenerator[Paper, None]: +# """ +# 执行搜索 +# +# Args: +# query: 搜索查询 +# max_results: 最大结果数 +# +# Yields: +# Paper: 论文对象 +# """ +# params = SearchParams(query=query, max_results=max_results) +# +# async with self._session_manager: +# async for paper in self._iter_papers(params): +# yield paper +# +# def _calculate_page_size(self, start: int, max_results: Optional[int]) -> int: +# """计算页面大小""" +# if max_results is None: +# return self._config.page_size +# return min(self._config.page_size, max_results - start) +# +# def _has_more_results( +# self, start: int, total: float, max_results: Optional[int] +# ) -> bool: +# """检查是否有更多结果""" +# max_total = int(total) if total != float("inf") else sys.maxsize +# max_allowed = max_results or sys.maxsize +# return start < min(max_total, max_allowed) +# +# async def _iter_papers(self, params: SearchParams) -> AsyncGenerator[Paper, None]: +# """迭代获取论文 +# +# Args: +# params: 搜索参数 +# +# Yields: +# Paper: 论文对象 +# """ +# start = 0 +# total_results = float("inf") +# concurrent_limit = min( +# self._config.max_concurrent_requests, +# self._config.rate_limit_calls +# ) +# results_count = 0 +# first_batch = True +# +# logger.info("开始搜索", extra={ +# "query": params.query, +# "max_results": params.max_results, +# "concurrent_limit": concurrent_limit +# }) +# +# send_stream, receive_stream = create_memory_object_stream[ +# tuple[list[Paper], int] +# ](max_buffer_size=concurrent_limit) +# +# @retry( +# retry=retry_if_exception_type((ClientError, TimeoutError, OSError)), +# stop=stop_after_attempt(self._config.max_retries), +# wait=wait_exponential( +# multiplier=self._config.rate_limit_period, +# min=self._config.min_wait, +# max=self._config.timeout +# ), +# ) +# async def fetch_page(page_start: int): +# """获取单页并发送结果""" +# try: +# async with self._session_manager.rate_limited_context(): +# response = await self._fetch_page(params, page_start) +# result = await self.parse_response(response) +# await send_stream.send(result) +# except Exception as e: +# logger.error(f"获取页面失败: {e!s}", extra={"page": page_start}) +# raise +# +# async def fetch_batch(batch_start: int, size: int): +# """批量获取论文""" +# try: +# async with create_task_group() as batch_tg: +# # 计算这一批次应获取的数量 +# batch_size = min( +# size, +# self._config.page_size, +# params.max_results - results_count if 
params.max_results else size +# ) +# +# # 计算结束位置 +# end = min( +# batch_start + batch_size, +# int(total_results) if total_results != float( +# "inf") else sys.maxsize +# ) +# +# # 启动单页获取任务 +# for offset in range(batch_start, end): +# if params.max_results and offset >= params.max_results: +# break +# batch_tg.start_soon(fetch_page, offset) +# except Exception as e: +# logger.error( +# "批量获取失败", +# extra={ +# "start": batch_start, +# "size": size, +# "error": str(e) +# } +# ) +# raise +# +# try: +# async with create_task_group() as tg: +# tg.start_soon( +# fetch_batch, +# start, +# concurrent_limit * self._config.page_size +# ) +# +# async for papers, total in receive_stream: +# total_results = total +# +# if first_batch and total == 0: +# logger.info("未找到结果", extra={"query": params.query}) +# raise SearchCompleteException(0) +# first_batch = False +# +# for paper in papers: +# yield paper +# results_count += 1 +# if params.max_results and results_count >= params.max_results: +# raise SearchCompleteException(results_count) +# +# start += len(papers) +# if start < min( +# int(total_results) if total_results != float( +# "inf") else sys.maxsize, +# params.max_results or sys.maxsize +# ): +# await sleep(self._config.rate_limit_period) +# tg.start_soon( +# fetch_batch, +# start, +# concurrent_limit * self._config.page_size +# ) +# finally: +# await receive_stream.aclose() +# logger.info("搜索结束", extra={"total_results": results_count}) +# +# async def _fetch_page(self, params: SearchParams, start: int) -> ClientResponse: +# """ +# 获取单页结果 +# +# Args: +# params: 搜索参数 +# start: 起始位置 +# +# Returns: +# ClientResponse: API响应 +# +# Raises: +# QueryBuildError: 如果构建查询参数失败 +# HTTPException: 如果请求失败 +# """ +# try: +# # 构建查询参数 +# query_params = self._build_query_params(params, start) +# +# # 发送请求 +# response = await self._session_manager.request( +# "GET", str(self._config.base_url), params=query_params +# ) +# +# if response.status != 200: +# logger.error( +# "搜索请求失败", extra={"status_code": response.status} +# ) +# raise HTTPException(response.status) +# +# return response +# +# except QueryBuildError as e: +# logger.error("查询参数构建失败", extra={"error_context": e.context}) +# raise +# except Exception as e: +# logger.error("未预期的错误", exc_info=True) +# raise QueryBuildError( +# message="构建查询参数失败", +# context=QueryContext( +# params={"query": params.query, "start": start}, +# field_name="query_params", +# ), +# original_error=e, +# ) +# +# def _build_query_params(self, params: SearchParams, start: int) -> dict: +# """ +# 构建查询参数 +# +# Args: +# params: 搜索参数 +# start: 起始位置 +# +# Returns: +# dict: 查询参数 +# +# Raises: +# QueryBuildError: 如果构建查询参数失败 +# """ +# if not params.query: +# raise QueryBuildError( +# message="搜索查询不能为空", +# context=QueryContext( +# params={"query": None}, field_name="query", constraint="required" +# ), +# ) +# +# if start < 0: +# raise QueryBuildError( +# message="起始位置不能为负数", +# context=QueryContext( +# params={"start": start}, +# field_name="start", +# constraint="non_negative", +# ), +# ) +# +# try: +# page_size = min( +# self._config.page_size, +# params.max_results - start if params.max_results else float("inf"), +# ) +# except Exception as e: +# raise QueryBuildError( +# message="计算页面大小失败", +# context=QueryContext( +# params={ +# "page_size": self._config.page_size, +# "max_results": params.max_results, +# "start": start, +# }, +# field_name="page_size", +# ), +# original_error=e, +# ) +# +# query_params = { +# "search_query": params.query, +# "start": start, +# "max_results": page_size, +# 
} +# +# if params.sort_by: +# query_params["sortBy"] = params.sort_by.value +# +# if params.sort_order: +# query_params["sortOrder"] = params.sort_order.value +# +# return query_params +# +# async def parse_response( +# self, +# response: ClientResponse +# ) -> tuple[list[Paper], int]: +# """解析API响应""" +# content = await response.text() +# try: +# return await ArxivParser.parse_feed( +# content=content, +# url=str(response.url) +# ) +# except ParserException: +# raise +# except Exception as e: +# raise ParserException( +# url=str(response.url), +# message="解析响应失败", +# context=ParseErrorContext(raw_content=content), +# original_error=e, +# ) +# +# async def __aenter__(self) -> "ArxivClient": +# return self +# +# async def __aexit__( +# self, +# exc_type: Optional[type[BaseException]], +# exc_val: Optional[BaseException], +# exc_tb: Optional[TracebackType] +# ) -> None: +# await self._session_manager.close() +# +# async def close(self) -> None: +# """关闭客户端并清理资源""" +# await self._session_manager.close() diff --git a/src/client/arxiv_client.py b/src/client/arxiv_client.py new file mode 100644 index 0000000..a6fb652 --- /dev/null +++ b/src/client/arxiv_client.py @@ -0,0 +1,217 @@ +from collections.abc import AsyncGenerator +from pathlib import Path +from types import TracebackType +from typing import Optional + +from aiohttp import ClientResponse + +from ..config import ArxivConfig, default_config +from ..exception import ( + HTTPException, + ParseErrorContext, + ParserException, + QueryBuildError, + QueryContext, + SearchCompleteException, +) +from ..models import Paper, SearchParams +from ..utils import logger +from ..utils.parser import ArxivParser +from ..utils.session import SessionManager +from .base import BaseSearchManager +from .downloader import ArxivDownloader +from .search import ArxivSearchManager + + +class ArxivClient: + def __init__( + self, + config: Optional[ArxivConfig] = None, + session_manager: Optional[SessionManager] = None, + *, + search_manager_class: type[BaseSearchManager] = ArxivSearchManager, + ): + self._config = config or default_config + self._session_manager = session_manager or SessionManager(config=self._config) + self._search_manager = search_manager_class(self) + self._downloader = ArxivDownloader(self._session_manager) + + async def search( + self, + query: str, + max_results: Optional[int] = None, + ) -> AsyncGenerator[Paper, None]: + """执行搜索""" + params = SearchParams(query=query, max_results=max_results) + + try: + async with self._session_manager: + async for paper in self._search_manager.execute_search(params): + yield paper + except SearchCompleteException: + return + + async def _fetch_page(self, params: SearchParams, start: int) -> ClientResponse: + """ + 获取单页结果 + + Args: + params: 搜索参数 + start: 起始位置 + + Returns: + 响应对象 + + Raises: + QueryBuildError: 如果构建查询参数失败 + """ + try: + query_params = self._build_query_params(params, start) + response = await self._session_manager.request( + "GET", str(self._config.base_url), params=query_params + ) + + if response.status != 200: + logger.error("搜索请求失败", extra={"status_code": response.status}) + raise HTTPException(response.status) + + return response + + except QueryBuildError: + raise + except Exception as e: + logger.error("未预期的错误", exc_info=True) + raise QueryBuildError( + message="构建查询参数失败", + context=QueryContext( + params={"query": params.query, "start": start}, + field_name="query_params", + ), + original_error=e, + ) + + def _build_query_params(self, params: SearchParams, start: int) -> dict: + """ + 
构建查询参数 + + Args: + params: 搜索参数 + start: 起始位置 + + Returns: + dict: 查询参数 + + Raises: + QueryBuildError: 如果构建查询参数失败 + """ + self._validate_params(params, start) + + try: + page_size = min( + self._config.page_size, + params.max_results - start if params.max_results else float("inf"), + ) + + query_params = { + "search_query": params.query, + "start": start, + "max_results": page_size, + } + + if params.sort_by: + query_params["sortBy"] = params.sort_by.value + + if params.sort_order: + query_params["sortOrder"] = params.sort_order.value + + return query_params + + except Exception as e: + raise QueryBuildError( + message="计算页面大小失败", + context=QueryContext( + params={ + "page_size": self._config.page_size, + "max_results": params.max_results, + "start": start, + }, + field_name="page_size", + ), + original_error=e, + ) + + def _validate_params(self, params: SearchParams, start: int) -> None: + """验证查询参数""" + if not params.query: + raise QueryBuildError( + message="搜索查询不能为空", + context=QueryContext( + params={"query": None}, field_name="query", constraint="required" + ), + ) + + if start < 0: + raise QueryBuildError( + message="起始位置不能为负数", + context=QueryContext( + params={"start": start}, + field_name="start", + constraint="non_negative", + ), + ) + + async def parse_response(self, response: ClientResponse) -> tuple[list[Paper], int]: + """ + 解析API响应 + + Args: + response: 响应对象 + + Returns: + 解析后的论文列表和总结果数 + + Raises: + ParserException: 如果解析失败 + """ + content = await response.text() + try: + return await ArxivParser.parse_feed(content=content, url=str(response.url)) + except ParserException: + raise + except Exception as e: + raise ParserException( + url=str(response.url), + message="解析响应失败", + context=ParseErrorContext(raw_content=content), + original_error=e, + ) + + async def download_paper( + self, paper: Paper, filename: Optional[str] = None + ) -> Path: + """ + 下载论文 + + Args: + paper: 论文对象 + filename: 文件名 + + Returns: + 下载文件的存放路径 + """ + return await self._downloader.download_paper(str(paper.pdf_url), filename) + + async def close(self) -> None: + """关闭客户端""" + await self._session_manager.close() + + async def __aenter__(self) -> "ArxivClient": + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() diff --git a/src/client/base.py b/src/client/base.py new file mode 100644 index 0000000..3c7c74a --- /dev/null +++ b/src/client/base.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from typing import Protocol + +from aiohttp import ClientResponse + +from ..config import ArxivConfig +from ..models import Paper, SearchParams + + +class ClientProtocol(Protocol): + """客户端协议""" + async def _fetch_page(self, params: SearchParams, start: int) -> ClientResponse: + ... + + async def parse_response( + self, + response: ClientResponse + ) -> tuple[list[Paper], int]: + ... + + @property + def _config(self) -> "ArxivConfig": + ... 
+ +class BaseSearchManager(ABC): + """搜索管理器基类""" + @abstractmethod + def __init__(self, client: ClientProtocol): + pass + + @abstractmethod + def execute_search(self, params: SearchParams) -> AsyncGenerator[Paper, None]: + pass diff --git a/src/client/downloader.py b/src/client/downloader.py new file mode 100644 index 0000000..cda4469 --- /dev/null +++ b/src/client/downloader.py @@ -0,0 +1,100 @@ +from pathlib import Path +from types import TracebackType +from typing import Optional + +import aiofiles + +from ..utils import get_project_root +from ..utils.log import logger +from ..utils.session import SessionManager + + +class ArxivDownloader: + """arXiv论文下载器""" + def __init__( + self, + session_manager: Optional[SessionManager] = None, + download_dir: Optional[str] = None + ): + """ + 初始化下载器 + + Args: + session_manager: 会话管理器,可选 + download_dir: 下载目录,可选 + """ + project_root = get_project_root() + self._session_manager = session_manager + self._own_session = False + self.download_dir = Path(download_dir) if download_dir else project_root / "downloads" + self.download_dir.mkdir(parents=True, exist_ok=True) + + @property + def session_manager(self) -> SessionManager: + """懒加载会话管理器""" + if self._session_manager is None: + self._session_manager = SessionManager() + self._own_session = True + return self._session_manager + + async def download_paper( + self, + pdf_url: str, + filename: Optional[str] = None + ) -> Path: + """下载论文PDF文件""" + if not filename: + filename = pdf_url.split("/")[-1] + if not filename.endswith(".pdf"): + filename += ".pdf" + + file_path = self.download_dir / filename + temp_path = file_path.with_suffix(file_path.suffix + ".tmp") + + logger.info(f"开始下载论文: {pdf_url}") + try: + async with self.session_manager.rate_limited_context(): + # 先获取响应 + response = await self.session_manager.request("GET", pdf_url) + # 然后使用 async with 管理响应生命周期 + async with response: + response.raise_for_status() + + async with aiofiles.open(temp_path, "wb") as f: + total_size = int(response.headers.get("content-length", 0)) + downloaded_size = 0 + + async for chunk, _ in response.content.iter_chunks(): + if chunk: + await f.write(chunk) + downloaded_size += len(chunk) + + if 0 < total_size != downloaded_size: + raise RuntimeError( + f"下载不完整: 预期 {total_size} 字节, 实际下载 {downloaded_size} 字节") + + temp_path.rename(file_path) + + logger.info(f"下载完成: {file_path}") + return file_path + + except Exception as e: + logger.error(f"下载失败: {e!s}") + # 清理临时文件和目标文件 + if temp_path.exists(): + temp_path.unlink() + if file_path.exists(): + file_path.unlink() + raise + + async def __aenter__(self) -> "ArxivDownloader": + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType] + ) -> None: + if self._own_session and self._session_manager: + await self._session_manager.close() diff --git a/src/client/processor.py b/src/client/processor.py new file mode 100644 index 0000000..e13a03e --- /dev/null +++ b/src/client/processor.py @@ -0,0 +1,136 @@ +import sys +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import Optional + +from anyio import create_memory_object_stream, create_task_group, sleep + +from ..config import ArxivConfig +from ..exception import SearchCompleteException +from ..models import Paper, SearchParams +from ..utils.log import logger + + +@dataclass +class BatchContext: + """批处理上下文""" + start: int + size: int + total_results: float + results_count: int + max_results: 
Optional[int] + +class ResultProcessor: + """处理搜索结果的处理器类""" + def __init__(self, concurrent_limit: int, config: ArxivConfig): + self.concurrent_limit = concurrent_limit + self.config = config + + async def create_streams(self): + """创建内存流""" + return create_memory_object_stream[tuple[list[Paper], int]]( + max_buffer_size=self.concurrent_limit + ) + + def calculate_batch_size(self, ctx: BatchContext) -> int: + """计算批次大小""" + return min( + ctx.size, + self.config.page_size, + ctx.max_results - ctx.results_count if ctx.max_results else ctx.size + ) + + def calculate_batch_end(self, ctx: BatchContext, batch_size: int) -> int: + """计算批次结束位置""" + return min( + ctx.start + batch_size, + int(ctx.total_results) if ctx.total_results != float("inf") else sys.maxsize + ) + +async def _iter_papers(self, params: SearchParams) -> AsyncGenerator[Paper, None]: + """迭代获取论文""" + # 初始化上下文 + concurrent_limit = min( + self._config.max_concurrent_requests, + self._config.rate_limit_calls + ) + processor = ResultProcessor(concurrent_limit, self._config) + ctx = BatchContext( + start=0, + size=concurrent_limit * self._config.page_size, + total_results=float("inf"), + results_count=0, + max_results=params.max_results + ) + + logger.info("开始搜索", extra={ + "query": params.query, + "max_results": params.max_results, + "concurrent_limit": concurrent_limit + }) + + send_stream, receive_stream = await processor.create_streams() + + async def fetch_page(page_start: int): + """获取单页数据""" + try: + async with self._session_manager.rate_limited_context(): + response = await self._fetch_page(params, page_start) + result = await self.parse_response(response) + await send_stream.send(result) + except Exception as e: + logger.error(f"获取页面失败: {e!s}", extra={"page": page_start}) + raise + + async def process_batch(batch_ctx: BatchContext): + """处理单个批次""" + try: + async with create_task_group() as batch_tg: + batch_size = processor.calculate_batch_size(batch_ctx) + end = processor.calculate_batch_end(batch_ctx, batch_size) + + for offset in range(batch_ctx.start, end): + if batch_ctx.max_results and offset >= batch_ctx.max_results: + break + batch_tg.start_soon(fetch_page, offset) + except Exception as e: + logger.error("批量获取失败", extra={ + "start": batch_ctx.start, + "size": batch_ctx.size, + "error": str(e) + }) + raise + + try: + async with create_task_group() as tg: + tg.start_soon(process_batch, ctx) + + async for papers, total in receive_stream: + # 更新总结果数 + ctx.total_results = total + + # 检查首次批次是否有结果 + if ctx.results_count == 0 and total == 0: + logger.info("未找到结果", extra={"query": params.query}) + raise SearchCompleteException(0) + + # 处理论文 + for paper in papers: + yield paper + ctx.results_count += 1 + if ctx.max_results and ctx.results_count >= ctx.max_results: + raise SearchCompleteException(ctx.results_count) + + # 准备下一批次 + ctx.start += len(papers) + if ctx.start < min( + int(ctx.total_results) if ctx.total_results != float("inf") + else sys.maxsize, + ctx.max_results or sys.maxsize + ): + await sleep(self._config.rate_limit_period) + tg.start_soon(process_batch, ctx) + + finally: + await receive_stream.aclose() + logger.info("搜索结束", extra={"total_results": ctx.results_count}) diff --git a/src/client/search.py b/src/client/search.py new file mode 100644 index 0000000..3e1de5b --- /dev/null +++ b/src/client/search.py @@ -0,0 +1,169 @@ +import sys +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import Any + +from anyio import create_memory_object_stream, create_task_group, sleep + 
+from ..exception import SearchCompleteException +from ..models import Paper, SearchParams +from ..utils.log import logger +from .base import BaseSearchManager, ClientProtocol + + +@dataclass +class SearchContext: + """搜索上下文""" + + params: SearchParams + start: int = 0 + total_results: float = float("inf") + results_count: int = 0 + first_batch: bool = True + + def update_with_papers(self, papers: list[Paper]) -> None: + """更新处理进度""" + self.start += len(papers) + self.results_count += len(papers) + + def should_continue(self) -> bool: + """检查是否应继续处理""" + max_total = ( + int(self.total_results) + if self.total_results != float("inf") + else sys.maxsize + ) + max_allowed = self.params.max_results or sys.maxsize + return self.start < min(max_total, max_allowed) + + def reached_limit(self) -> bool: + """检查是否达到最大结果限制""" + return ( + self.params.max_results is not None + and self.results_count >= self.params.max_results + ) + + +class ArxivSearchManager(BaseSearchManager): + def __init__(self, client: ClientProtocol): + self.client = client + self.config = client._config + self.concurrent_limit = min( + self.config.max_concurrent_requests, self.config.rate_limit_calls + ) + + def execute_search(self, params: SearchParams) -> AsyncGenerator[Paper, None]: + """执行搜索""" + + async def search_implementation(): + ctx = SearchContext(params=params) + send_stream, receive_stream = create_memory_object_stream[ + tuple[list[Paper], int] + ](max_buffer_size=self.concurrent_limit) + + try: + async with create_task_group() as tg: + await self._start_batch(tg, ctx, send_stream) + try: + async for paper in self._process_results( + ctx, tg, receive_stream + ): + yield paper + except SearchCompleteException: + # 正常的搜索完成,记录日志后继续传播 + logger.info(f"搜索完成,共获取{ctx.results_count}条结果") + raise + finally: + # 确保资源清理 + await send_stream.aclose() + await receive_stream.aclose() + + # 返回异步生成器 + return search_implementation() + + async def _start_batch(self, tg: Any, ctx: SearchContext, send_stream: Any) -> None: + """启动批量获取""" + tg.start_soon( + self._fetch_batch, + ctx, + self.concurrent_limit * self.config.page_size, + send_stream, + ) + + async def _fetch_batch( + self, ctx: SearchContext, size: int, send_stream: Any + ) -> None: + """获取一批论文""" + try: + async with create_task_group() as batch_tg: + batch_size = self._calculate_batch_size(ctx, size) + end = self._calculate_batch_end(ctx, batch_size) + + for offset in range(ctx.start, end): + if ctx.params.max_results and offset >= ctx.params.max_results: + break + batch_tg.start_soon(self._fetch_page, offset, ctx, send_stream) + except Exception as e: + logger.error( + "批量获取失败", + extra={"start": ctx.start, "size": size, "error": str(e)}, + ) + raise + + def _calculate_batch_size(self, ctx: SearchContext, size: int) -> int: + """计算批次大小""" + return min( + size, + self.config.page_size, + ctx.params.max_results - ctx.results_count + if ctx.params.max_results + else size, + ) + + def _calculate_batch_end(self, ctx: SearchContext, batch_size: int) -> int: + """计算批次结束位置""" + return min( + ctx.start + batch_size, + int(ctx.total_results) + if ctx.total_results != float("inf") + else sys.maxsize, + ) + + async def _fetch_page( + self, offset: int, ctx: SearchContext, send_stream: Any + ) -> None: + """获取并发送单页数据""" + try: + response = await self.client._fetch_page(ctx.params, offset) + result = await self.client.parse_response(response) + await send_stream.send(result) + except Exception as e: + logger.error(f"获取页面失败: {e!s}", extra={"page": offset}) + raise + + async def 
_process_results( + self, ctx: SearchContext, tg: Any, receive_stream: Any + ) -> AsyncGenerator[Paper, None]: + """处理搜索结果""" + async for papers, total in receive_stream: + # 更新总结果数 + ctx.total_results = total + + # 检查首次批次是否有结果 + if ctx.first_batch and total == 0: + logger.info("未找到结果", extra={"query": ctx.params.query}) + raise SearchCompleteException(0) + ctx.first_batch = False + + # 处理论文 + for paper in papers: + yield paper + ctx.results_count += 1 + if ctx.reached_limit(): + raise SearchCompleteException(ctx.results_count) + + # 更新进度并启动下一批次 + ctx.start += len(papers) + if ctx.should_continue(): + await sleep(self.config.rate_limit_period) + await self._start_batch(tg, ctx, receive_stream) diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..630047b --- /dev/null +++ b/src/config.py @@ -0,0 +1,25 @@ +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field, HttpUrl + + +class ArxivConfig(BaseModel): + """arXiv API 配置类""" + + base_url: HttpUrl = Field( + default="http://export.arxiv.org/api/query", description="arXiv API 基础URL" + ) + timeout: float = Field(default=30.0, description="请求超时时间(秒)", gt=0) + max_retries: int = Field(default=3, description="最大重试次数", ge=0) + rate_limit_calls: int = Field(default=5, description="速率限制窗口内的最大请求数") + rate_limit_period: float = Field(default=1.0, description="速率限制窗口期(秒)") + max_concurrent_requests: int = Field(default=5, description="最大并发请求数") + proxy: Optional[str] = Field(default=None, description="HTTP/HTTPS代理URL") + log_level: str = Field(default="INFO", description="日志等级") + page_size: int = Field(default=10, description="每页结果数") + min_wait: float = Field(default=1.0, description="最小重试等待时间(秒)", gt=0) + + model_config = ConfigDict(extra="allow") + + +default_config = ArxivConfig() diff --git a/src/exception.py b/src/exception.py new file mode 100644 index 0000000..086604b --- /dev/null +++ b/src/exception.py @@ -0,0 +1,244 @@ +from dataclasses import dataclass +from http import HTTPStatus +from typing import Any, Optional + +from pydantic import BaseModel, HttpUrl + + +class ArxivException(Exception): + """基础异常类""" + + def __str__(self) -> str: + return super().__repr__() + + +class HTTPException(ArxivException): + """HTTP 请求相关异常""" + + def __init__(self, status_code: int, message: Optional[str] = None): + self.status_code = status_code + self.message = message or HTTPStatus(status_code).description + super().__init__(self.message) + + +class RateLimitException(HTTPException): + """达到 API 速率限制时的异常""" + + def __init__(self, retry_after: Optional[int] = None): + self.retry_after = retry_after + super().__init__(429, "Too Many Requests") + + +class ValidationException(ArxivException): + """数据验证异常""" + + def __init__( + self, + message: str, + field_name: str, + input_value: Any, + expected_type: type, + model: Optional[type[BaseModel]] = None, + validation_errors: Optional[dict] = None, + ): + self.field_name = field_name + self.input_value = input_value + self.expected_type = expected_type + self.model = model + self.validation_errors = validation_errors + super().__init__(message) + + def __str__(self) -> str: + error_msg = [ + f"Validation error for field '{self.field_name}':", + f"Input value: {self.input_value!r}", + f"Expected type: {self.expected_type.__name__}", + ] + + if self.model: + error_msg.append(f"Model: {self.model.__name__}") + + if self.validation_errors: + error_msg.append("Detailed errors:") + for key, err in self.validation_errors.items(): + error_msg.append(f" - {key}: {err}") 
+ + return "\n".join(error_msg) + + +class TimeoutException(ArxivException): + """请求超时异常""" + + def __init__( + self, + timeout: float, + message: Optional[str] = None, + proxy: Optional[HttpUrl] = None, + link: Optional[HttpUrl] = None, + ): + self.timeout = timeout + self.proxy = proxy + self.link = link + self.message = message or f"Request timed out after {timeout} seconds" + super().__init__(message) + + def __str__(self) -> str: + error_msg = [ + f"Request timed out after {self.timeout} seconds", + self.message, + ] + + if self.proxy: + error_msg.append(f"Proxy: {self.proxy}") + + if self.link: + error_msg.append(f"Link: {self.link}") + + return "\n".join(error_msg) + + +class RetryError(ArxivException): + """重试次数用尽异常""" + + def __init__(self, attempts: int, last_error: Exception): + self.attempts = attempts + self.last_error = last_error + super().__init__(f"Failed after {attempts} attempts. Last error: {last_error}") + + +@dataclass +class ConfigError: + """配置错误详情""" + + property_name: str + input_value: Any + expected_type: type + message: str + + +class ConfigurationError(ArxivException): + """配置错误异常""" + + def __init__( + self, + message: str, + property_name: str, + input_value: Any, + expected_type: type, + config_class: Optional[type] = None, + ): + self.property_name = property_name + self.input_value = input_value + self.expected_type = expected_type + self.config_class = config_class + self.message = message + super().__init__(message) + + def __str__(self) -> str: + error_parts = [ + f"Configuration error for '{self.property_name}':", + f"Input value: {self.input_value!r}", + f"Expected type: {self.expected_type.__name__}", + f"Message: {self.message}", + ] + + if self.config_class: + error_parts.append(f"Config class: {self.config_class.__name__}") + + return "\n".join(error_parts) + + +@dataclass +class QueryContext: + """查询构建上下文""" + + params: dict[str, Any] # 查询参数 + field_name: Optional[str] = None # 出错的字段 + value: Optional[Any] = None # 问题值 + constraint: Optional[str] = None # 违反的约束 + + +class QueryBuildError(ArxivException): + """查询构建错误""" + + def __init__( + self, + message: str, + context: Optional[QueryContext] = None, + original_error: Optional[Exception] = None, + ): + self.message = message + self.context = context + self.original_error = original_error + super().__init__(message) + + def __str__(self) -> str: + parts = [f"查询构建错误: {self.message}"] + + if self.context: + parts.append("查询参数:") + for k, v in self.context.params.items(): + parts.append(f" {k}: {v}") + + if self.context.field_name: + parts.append(f"问题字段: {self.context.field_name}") + if self.context.value is not None: + parts.append(f"问题值: {self.context.value}") + if self.context.constraint: + parts.append(f"违反约束: {self.context.constraint}") + + if self.original_error: + parts.append(f"原始错误: {self.original_error!s}") + + return "\n".join(parts) + + +@dataclass +class ParseErrorContext: + """解析错误上下文""" + + raw_content: Optional[str] = None + position: Optional[int] = None + element_name: Optional[str] = None + namespace: Optional[str] = None + + +class ParserException(Exception): + """XML解析异常""" + + def __init__( + self, + url: str, + message: str, + context: Optional[ParseErrorContext] = None, + original_error: Optional[Exception] = None, + ): + self.url = url + self.message = message + self.context = context + self.original_error = original_error + super().__init__(self.message) + + def __str__(self) -> str: + parts = [f"解析错误: {self.message}", f"URL: {self.url}"] + + if self.context: + if 
self.context.element_name: + parts.append(f"元素: {self.context.element_name}") + if self.context.namespace: + parts.append(f"命名空间: {self.context.namespace}") + if self.context.position is not None: + parts.append(f"位置: {self.context.position}") + if self.context.raw_content: + parts.append(f"原始内容: \n{self.context.raw_content[:200]}...") + + if self.original_error: + parts.append(f"原始错误: {self.original_error!s}") + + return "\n".join(parts) + +class SearchCompleteException(ArxivException): + """搜索完成异常""" + def __init__(self, total_results: int): + self.total_results = total_results + super().__init__(f"搜索完成,共获取{total_results}条结果") diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..353bea1 --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,61 @@ +from datetime import datetime +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field, HttpUrl + + +class SortCriterion(str, Enum): + """排序标准""" + + RELEVANCE = "relevance" + LAST_UPDATED = "lastUpdatedDate" + SUBMITTED = "submittedDate" + + +class SortOrder(str, Enum): + """排序方向""" + + ASCENDING = "ascending" + DESCENDING = "descending" + + +class Author(BaseModel): + """作者模型""" + + name: str + affiliation: Optional[str] = None + + +class Category(BaseModel): + """分类模型""" + + primary: str + secondary: list[str] = Field(default_factory=list) + + +class Paper(BaseModel): + """论文模型""" + + id: str = Field(description="arXiv ID") + title: str + summary: str + authors: list[Author] + categories: Category + doi: Optional[str] = None + journal_ref: Optional[str] = None + pdf_url: Optional[HttpUrl] = None + published: datetime + updated: datetime + comment: Optional[str] = None + + +class SearchParams(BaseModel): + """搜索参数""" + + query: str + id_list: list[str] = Field(default_factory=list) + max_results: Optional[int] = Field(default=None, gt=0) + start: int = Field(default=0, ge=0) + sort_by: Optional[SortCriterion] = None + sort_order: Optional[SortOrder] = None diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..90f2a64 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,44 @@ +from pathlib import Path +from time import monotonic +from types import SimpleNamespace + +import aiohttp + +from .log import logger + + +def create_trace_config() -> aiohttp.TraceConfig: + """ + 创建请求追踪配置。 + + Returns: + aiohttp.TraceConfig: 请求追踪配置 + """ + + async def _on_request_start( + session: aiohttp.ClientSession, + trace_config_ctx: SimpleNamespace, + params: aiohttp.TraceRequestStartParams, + ) -> None: + logger.debug(f"Starting request: {params.method} {params.url}") + trace_config_ctx.start_time = monotonic() + + async def _on_request_end( + session: aiohttp.ClientSession, + trace_config_ctx: SimpleNamespace, + params: aiohttp.TraceRequestEndParams, + ) -> None: + elapsed_time = monotonic() - trace_config_ctx.start_time + logger.debug( + f"Ending request: {params.response.status} {params.url} - Time elapsed: " + f"{elapsed_time:.2f} seconds" + ) + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_start.append(_on_request_start) + trace_config.on_request_end.append(_on_request_end) + return trace_config + +def get_project_root() -> Path: + """获取项目根目录""" + return Path(__file__).parent.parent.parent diff --git a/src/utils/log.py b/src/utils/log.py new file mode 100644 index 0000000..c983e53 --- /dev/null +++ b/src/utils/log.py @@ -0,0 +1,75 @@ +import inspect +import logging +import sys +from typing import TYPE_CHECKING + +import 
loguru
+
+if TYPE_CHECKING:
+    # avoid sphinx autodoc resolve annotation failed
+    # because loguru module do not have `Logger` class actually
+    from loguru import Logger, Record
+
+logger: "Logger" = loguru.logger
+"""日志记录器对象。
+
+default:
+
+- 格式: `{time:MM-DD HH:mm:ss} [{level}] {name} | {message}`(即下方的 `default_format`)
+- 等级: `INFO`,根据 `config.log_level` 配置(通过 `extra` 中的 `log_level`)改变
+- 输出: 输出至 stdout
+
+usage:
+    ```python
+    from src.utils.log import logger
+    ```
+"""
+
+
+# https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
+class LoguruHandler(logging.Handler):  # pragma: no cover
+    """logging 与 loguru 之间的桥梁,将 logging 的日志转发到 loguru。"""
+
+    def emit(self, record: logging.LogRecord):
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+
+        frame, depth = inspect.currentframe(), 0
+        while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
+            frame = frame.f_back
+            depth += 1
+
+        logger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+
+def default_filter(record: "Record"):
+    """默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。"""
+    log_level = record["extra"].get("log_level", "INFO")
+    levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
+    return record["level"].no >= levelno
+
+
+default_format: str = (
+    "{time:MM-DD HH:mm:ss} "
+    "[{level}] "
+    "{name} | "
+    # "{function}:{line}| "
+    "{message}"
+)
+"""默认日志格式"""
+
+logger.remove()
+logger_id = logger.add(
+    sys.stdout,
+    level=0,
+    diagnose=False,
+    filter=default_filter,
+    format=default_format,
+)
+"""默认日志处理器 id"""
+
+__autodoc__ = {"logger_id": False}
diff --git a/src/utils/parser.py b/src/utils/parser.py
new file mode 100644
index 0000000..a2acbe2
--- /dev/null
+++ b/src/utils/parser.py
@@ -0,0 +1,339 @@
+import xml.etree.ElementTree as ET
+from datetime import datetime
+from typing import ClassVar, Optional, cast
+
+from pydantic import HttpUrl
+
+from ..exception import ParseErrorContext, ParserException
+from ..models import Author, Category, Paper
+from .log import logger
+
+
+class ArxivParser:
+    """
+    arXiv API响应解析器
+
+    Attributes:
+        NS (ClassVar[dict[str, str]]): XML命名空间
+
+    Args:
+        entry (ET.Element): 单个 entry 元素
+
+    Raises:
+        ParserException: 如果解析失败
+    """
+
+    NS: ClassVar[dict[str, str]] = {
+        "atom": "http://www.w3.org/2005/Atom",
+        "opensearch": "http://a9.com/-/spec/opensearch/1.1/",
+        "arxiv": "http://arxiv.org/schemas/atom",
+    }
+
+    def __init__(self, entry: ET.Element):
+        self.entry = entry
+
+    def _create_parser_exception(
+        self, message: str, url: str = "", error: Optional[Exception] = None
+    ) -> ParserException:
+        """创建解析异常"""
+        return ParserException(
+            url=url,
+            message=message,
+            context=ParseErrorContext(
+                raw_content=ET.tostring(self.entry, encoding="unicode"),
+                element_name=self.entry.tag,
+            ),
+            original_error=error,
+        )
+
+    def parse_authors(self) -> list[Author]:
+        """
+        解析作者信息
+
+        Returns:
+            list[Author]: 作者列表
+
+        Raises:
+            ParserException: 如果作者信息无效
+        """
+        logger.debug("开始解析作者信息")
+        authors = []
+        for author_elem in self.entry.findall("atom:author", self.NS):
+            name = author_elem.find("atom:name", self.NS)
+            if name is not None and name.text:
+                authors.append(Author(name=name.text))
+            else:
+                logger.warning("发现作者信息不完整")
+        return authors
+
+    def parse_entry_id(self) -> str:
+        """解析论文ID
+
+        Returns:
+            str: 论文ID
+
+        Raises:
+            ParserException: 如果ID元素缺失或无效
+        """
+        id_elem = self.entry.find("atom:id", self.NS)
+        if id_elem is None or id_elem.text is None:
+            raise ParserException(
+                url="",
+                message="缺少论文ID",
+ context=ParseErrorContext( + raw_content=ET.tostring(self.entry, encoding="unicode"), + element_name="id", + namespace=self.NS["atom"], + ), + ) + + return id_elem.text + + def parse_categories(self) -> Category: + """ + 解析分类信息 + + Returns: + Category: 分类信息 + + Raises: + ParserException: 如果分类信息无效 + """ + logger.debug("开始解析分类信息") + primary = self.entry.find("arxiv:primary_category", self.NS) + categories = self.entry.findall("atom:category", self.NS) + + if primary is None or "term" not in primary.attrib: + logger.warning("未找到主分类信息,使用默认分类") + primary_category = "unknown" + else: + primary_category = primary.attrib["term"] + + return Category( + primary=primary_category, + secondary=[c.attrib["term"] for c in categories if "term" in c.attrib], + ) + + def parse_required_fields(self) -> dict: + """ + 解析必要字段 + + Returns: + dict: 必要字段字典 + + Raises: + ParserException: 如果字段缺失 + """ + fields = { + "title": self.entry.find("atom:title", self.NS), + "summary": self.entry.find("atom:summary", self.NS), + "published": self.entry.find("atom:published", self.NS), + "updated": self.entry.find("atom:updated", self.NS), + } + + missing = [k for k, v in fields.items() if v is None or v.text is None] + if missing: + raise self._create_parser_exception( + f"缺少必要字段: {', '.join(missing)}" + ) + + return { + k: v.text for k, v in fields.items() if v is not None and v.text is not None + } + + def _parse_pdf_url(self) -> Optional[str]: + """ + 解析PDF链接 + + Returns: + Optional[str]: PDF链接或None + + Raises: + ParserException: 如果PDF链接无效 + """ + try: + links = self.entry.findall("atom:link", self.NS) + if not links: + logger.warning("未找到任何链接") + return None + + pdf_url = next( + ( + link.attrib["href"] + for link in links + if link.attrib.get("type") == "application/pdf" + ), + None, + ) + + if pdf_url is None: + logger.warning("未找到PDF链接") + + return pdf_url + + except (KeyError, AttributeError) as e: + logger.error("解析PDF链接失败", exc_info=True) + raise ParserException( + url="", + message="解析PDF链接失败", + context=ParseErrorContext( + raw_content=ET.tostring(self.entry, encoding="unicode"), + element_name="link", + namespace=self.NS["atom"], + ), + original_error=e, + ) + + def parse_optional_fields(self) -> dict: + """ + 解析可选字段 + + Returns: + dict: 可选字段字典 + + Raises: + ParserException: 如果字段无效 + """ + fields = { + "comment": self.entry.find("arxiv:comment", self.NS), + "journal_ref": self.entry.find("arxiv:journal_ref", self.NS), + "doi": self.entry.find("arxiv:doi", self.NS), + } + + return {k: v.text if v is not None else None for k, v in fields.items()} + + @staticmethod + def parse_datetime(date_str: str) -> datetime: + """ + 解析ISO格式的日期时间字符串 + + Args: + date_str: ISO格式的日期时间字符串 + + Returns: + datetime: 解析后的datetime对象 + + Raises: + ValueError: 日期格式无效 + """ + try: + normalized_date = date_str.replace("Z", "+00:00") + return datetime.fromisoformat(normalized_date) + except ValueError as e: + logger.error(f"日期解析失败: {date_str}", exc_info=True) + raise ValueError(f"无效的日期格式: {date_str}") from e + + def build_paper( + self, + index: int, + ) -> Paper: + """统一处理论文解析""" + try: + required_fields = self.parse_required_fields() + return Paper( + id=self.parse_entry_id().split("/")[-1], + title=required_fields["title"], + summary=required_fields["summary"], + authors=self.parse_authors(), + categories=self.parse_categories(), + pdf_url=cast(HttpUrl, self._parse_pdf_url()), + published=self.parse_datetime( + required_fields["published"].replace("Z", "+00:00") + ), + updated=self.parse_datetime( + 
required_fields["updated"].replace("Z", "+00:00") + ), + **self.parse_optional_fields(), + ) + except ParserException: + raise + except Exception as e: + raise ParserException( + url="", + message=f"解析第 {index + 1} 篇论文失败", + context=ParseErrorContext( + raw_content=ET.tostring(self.entry, encoding="unicode"), + position=index, + element_name=self.entry.tag, + ), + original_error=e, + ) + + @classmethod + def _parse_root( + cls, + root: ET.Element, + url: str + ) -> tuple[list[Paper], int]: + """解析根元素""" + # 解析总结果数 + total_element = root.find("opensearch:totalResults", cls.NS) + if total_element is None or total_element.text is None: + raise ParserException( + url=url, + message="缺少总结果数元素", + context=ParseErrorContext( + raw_content=ET.tostring(root, encoding="unicode"), + element_name="totalResults", + namespace=cls.NS["opensearch"], + ), + ) + + total_results = int(total_element.text) + + # 解析论文列表 + papers = [] + for i, entry in enumerate(root.findall("atom:entry", cls.NS)): + try: + parser = cls(entry) + papers.append(parser.build_paper(i)) + except Exception as e: + raise ParserException( + url=url, + message=f"解析第 {i + 1} 篇论文失败", + context=ParseErrorContext( + raw_content=ET.tostring(entry, encoding="unicode"), + position=i, + element_name=entry.tag, + namespace=cls.NS["atom"], + ), + original_error=e, + ) + + return papers, total_results + + @classmethod + async def parse_feed( + cls, + content: str, + url: str = "" + ) -> tuple[list[Paper], int]: + """ + 解析arXiv API的Atom feed内容 + + Args: + content: XML内容 + url: 请求URL,用于错误上下文 + + Returns: + tuple[list[Paper], int]: 论文列表和总结果数 + """ + logger.debug("开始解析feed内容") + try: + root = ET.fromstring(content) + return cls._parse_root(root, url) + except ET.ParseError as e: + logger.error("XML格式错误", exc_info=True) + raise ParserException( + url=url, + message="XML格式错误", + context=ParseErrorContext(raw_content=content), + original_error=e, + ) + except ParserException: + raise + except Exception as e: + raise ParserException( + url=url, + message="未知解析错误", + context=ParseErrorContext(raw_content=content), + original_error=e, + ) diff --git a/src/utils/rate_limiter.py b/src/utils/rate_limiter.py new file mode 100644 index 0000000..4c816ea --- /dev/null +++ b/src/utils/rate_limiter.py @@ -0,0 +1,137 @@ +import asyncio +from dataclasses import dataclass +from types import TracebackType +from typing import ClassVar, Optional + +from ..config import default_config +from .log import logger + + +@dataclass +class RateLimitState: + """速率限制状态""" + + remaining: int + reset_at: float + window_start: float + + +class RateLimiter: + """ + 速率限制器 + + 用于限制请求速率,防止过多请求导致服务器拒绝服务。 + + Attributes: + DEFAULT_CALLS: ClassVar[int]: 默认窗口期内的最大请求数 + DEFAULT_PERIOD: ClassVar[float]: 默认窗口期 + calls: int: 窗口期内的最大请求数 + period: float: 窗口期 + timestamps: list[float]: 请求时间戳列表 + _lock: asyncio.Lock: 锁 + _last_check: Optional[float]: 上次检查时间 + _logger: logging.Logger: 日志记录器 + """ + + # 从配置获取默认值 + DEFAULT_CALLS: ClassVar[int] = default_config.rate_limit_calls + DEFAULT_PERIOD: ClassVar[float] = default_config.rate_limit_period + + def __init__(self, calls: int = DEFAULT_CALLS, period: float = DEFAULT_PERIOD): + """ + 初始化速率限制器 + + Args: + calls: 窗口期内的最大请求数,默认从配置获取 + period: 窗口期,默认从配置获取 + """ + if calls <= 0: + raise ValueError("calls must be positive") + if period <= 0: + raise ValueError("period must be positive") + self.calls = calls + self.period = period + self.timestamps: list[float] = [] + self._lock = asyncio.Lock() + self._last_check: Optional[float] = None + self._logger = logger 
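+
+    # 用法示意(草稿性示例,urls/session 为假设的变量名,非本 PR 既定接口):
+    # `async with limiter` 等价于先 `await limiter.acquire()`(退出时不做额外操作);
+    # 在 period 窗口内超过 calls 次获取许可时,acquire 会自动等待窗口滑动。
+    #
+    #     limiter = RateLimiter(calls=5, period=1.0)
+    #     for url in urls:
+    #         async with limiter:
+    #             await session.get(url)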
+
+    @property
+    def is_limited(self) -> bool:
+        """当前是否处于限制状态"""
+        # 使用当前时间作为参考点
+        loop = asyncio.get_running_loop()
+        now = loop.time()
+
+        # 获取有效时间戳并更新
+        valid_stamps = [t for t in self.timestamps if (now - t) < self.period]
+        self.timestamps = valid_stamps
+
+        return len(valid_stamps) >= self.calls
+
+    def _get_valid_timestamps(self, now: float) -> list[float]:
+        """获取有效的时间戳列表"""
+        return [t for t in self.timestamps if now - t <= self.period]
+
+    @property
+    async def state(self) -> RateLimitState:
+        """获取当前速率限制状态"""
+        async with self._lock:
+            now = asyncio.get_event_loop().time()
+            valid_timestamps = self._get_valid_timestamps(now)
+
+            return RateLimitState(
+                remaining=max(0, self.calls - len(valid_timestamps)),
+                reset_at=min(self.timestamps, default=now) + self.period
+                if self.timestamps
+                else now,
+                window_start=now,
+            )
+
+    async def acquire(self) -> None:
+        """获取访问许可"""
+        async with self._lock:
+            now = asyncio.get_event_loop().time()
+
+            self.timestamps = self._get_valid_timestamps(now)
+
+            # 检查是否需要等待
+            if len(self.timestamps) >= self.calls:
+                sleep_time = self.timestamps[0] + self.period - now
+                if sleep_time > 0:
+                    self._logger.debug(
+                        "触发速率限制",
+                        extra={
+                            "wait_time": f"{sleep_time:.2f}s",
+                            "current_calls": len(self.timestamps),
+                            "max_calls": self.calls,
+                        },
+                    )
+                    await asyncio.sleep(sleep_time)
+                    # 等待结束后以当前时间为准,重新清理过期时间戳,
+                    # 避免把睡眠前的旧时间记录为本次请求的时间
+                    now = asyncio.get_event_loop().time()
+                    self.timestamps = self._get_valid_timestamps(now)
+
+            self.timestamps.append(now)
+            self._last_check = now
+
+            self._logger.debug(
+                "获取访问许可",
+                extra={
+                    "remaining_calls": self.calls - len(self.timestamps),
+                    "window_reset_in": f"{(self.timestamps[0] + self.period - now):.2f}s"
+                    if self.timestamps
+                    else "0s",
+                },
+            )
+
+    async def __aenter__(self) -> "RateLimiter":
+        """进入速率限制上下文"""
+        await self.acquire()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        """退出速率限制上下文"""
+        pass
diff --git a/src/utils/session.py b/src/utils/session.py
new file mode 100644
index 0000000..d5593e7
--- /dev/null
+++ b/src/utils/session.py
@@ -0,0 +1,121 @@
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from types import TracebackType
+from typing import Optional
+
+from aiohttp import (
+    ClientResponse,
+    ClientSession,
+    ClientTimeout,
+    TraceConfig,
+)
+
+from ..config import ArxivConfig, default_config
+from ..utils import create_trace_config
+from .log import logger
+from .rate_limiter import RateLimiter
+
+
+class SessionManager:
+    def __init__(
+        self,
+        config: Optional[ArxivConfig] = None,
+        session: Optional[ClientSession] = None,
+        rate_limiter: Optional[RateLimiter] = None,
+        trace_config: Optional[TraceConfig] = None,
+    ):
+        """
+        初始化会话管理器
+
+        Args:
+            config: arXiv API配置对象
+            session: aiohttp会话
+            rate_limiter: 速率限制器
+            trace_config: 请求追踪配置
+        """
+        self._config = config or default_config
+        self._timeout = ClientTimeout(
+            total=self._config.timeout,
+        )
+        self._session = session
+        self._rate_limiter = rate_limiter or RateLimiter(
+            calls=self._config.rate_limit_calls,
+            period=self._config.rate_limit_period,
+        )
+        self._trace_config = trace_config or create_trace_config()
+
+    @asynccontextmanager
+    async def rate_limited_context(self) -> AsyncIterator[None]:
+        """获取速率限制上下文"""
+        # 复用实例持有的速率限制器,而不是每次新建 RateLimiter,
+        # 否则配置的 rate_limit_calls/period 不会生效
+        await self._rate_limiter.acquire()
+        yield
+
+    async def get_session(self) -> ClientSession:
+        """
+        获取或创建会话
+
+        Returns:
+            ClientSession: aiohttp会话
+        """
+        if self._session is None or self._session.closed:
+            self._session = ClientSession(
+                timeout=self._timeout,
trace_configs=[self._trace_config], + ) + return self._session + + async def request(self, method: str, url: str, **kwargs) -> ClientResponse: + """ + 发送受速率限制的请求 + + Args: + method: HTTP 方法 + url: 请求URL + **kwargs: 传递给 aiohttp.ClientSession.request 的额外参数 + + Returns: + ClientResponse: aiohttp 响应对象 + """ + session = await self.get_session() + + if self._rate_limiter: + await self._rate_limiter.acquire() + + if self._config.proxy: + logger.debug(f"使用代理: {self._config.proxy}") + + return await session.request( + method, url, proxy=self._config.proxy or None, **kwargs + ) + + async def close(self) -> None: + """关闭会话""" + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def __aenter__(self) -> "SessionManager": + """ + 进入会话管理器 + + Returns: + SessionManager: 会话管理器实例 + """ + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """ + 退出会话管理器 + + Args: + exc_type: 异常类型 + exc_val: 异常值 + exc_tb: 异常回溯 + """ + await self.close() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_client/__init__.py b/tests/test_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_rate_limiter.py b/tests/utils/test_rate_limiter.py new file mode 100644 index 0000000..fae1995 --- /dev/null +++ b/tests/utils/test_rate_limiter.py @@ -0,0 +1,92 @@ +import asyncio + +import pytest + +from src.utils.rate_limiter import RateLimiter + + +@pytest.fixture +def rate_limiter(): + return RateLimiter(calls=3, period=1.0) + + +class TestRateLimiter: + def test_init_with_invalid_params(self): + """测试无效的初始化参数""" + with pytest.raises(ValueError, match="calls must be positive"): + RateLimiter(calls=0) + + with pytest.raises(ValueError, match="period must be positive"): + RateLimiter(period=0) + + @pytest.mark.asyncio + async def test_acquire_within_limit(self, rate_limiter): + """测试在限制范围内获取许可""" + for _ in range(rate_limiter.calls): + await rate_limiter.acquire() + assert len(rate_limiter.timestamps) == rate_limiter.calls + + @pytest.mark.asyncio + async def test_acquire_exceeds_limit(self, rate_limiter): + """测试超出限制时的等待""" + start_time = asyncio.get_event_loop().time() + + # 先用完配额 + for _ in range(rate_limiter.calls): + await rate_limiter.acquire() + + # 下一次请求应该等待 + await rate_limiter.acquire() + + elapsed = asyncio.get_event_loop().time() - start_time + assert elapsed >= rate_limiter.period + + @pytest.mark.asyncio + async def test_state(self, rate_limiter): + """测试速率限制状态""" + # 初始状态 + state = await rate_limiter.state + assert state.remaining == rate_limiter.calls + + # 使用一个许可后的状态 + await rate_limiter.acquire() + state = await rate_limiter.state + assert state.remaining == rate_limiter.calls - 1 + + @pytest.mark.asyncio + async def test_context_manager(self, rate_limiter): + """测试上下文管理器""" + async with rate_limiter as limiter: + assert isinstance(limiter, RateLimiter) + assert len(limiter.timestamps) == 1 + + @pytest.mark.asyncio + async def test_window_sliding(self, rate_limiter): + """测试时间窗口滑动""" + # 填满时间窗口 + for _ in range(rate_limiter.calls): + await rate_limiter.acquire() + + # 等待窗口滑动 + await asyncio.sleep(rate_limiter.period) + + # 应该可以立即获取新的许可 + start = asyncio.get_event_loop().time() + await rate_limiter.acquire() + elapsed = 
asyncio.get_event_loop().time() - start + assert elapsed < 0.1 # 不应该有明显延迟 + + @pytest.mark.asyncio + async def test_is_limited(self, rate_limiter): + """测试是否处于限制状态""" + # 初始状态未受限 + assert not rate_limiter.is_limited + + # 填满时间窗口 + for _ in range(rate_limiter.calls): + await rate_limiter.acquire() + assert rate_limiter.is_limited + + # 等待窗口滑动 + await asyncio.sleep(rate_limiter.period) + assert not rate_limiter.is_limited diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py new file mode 100644 index 0000000..4493491 --- /dev/null +++ b/tests/utils/test_session.py @@ -0,0 +1,154 @@ +# tests/utils/test_session.py + +import pytest +from aiohttp import ClientResponse, ClientSession, ClientTimeout +from pytest_mock import MockerFixture + +from src.config import ArxivConfig +from src.utils.rate_limiter import RateLimiter +from src.utils.session import SessionManager + + +@pytest.fixture +def config(): + return ArxivConfig( + timeout=30, + rate_limit_calls=3, + rate_limit_period=1, + ) + + +@pytest.fixture +def rate_limiter(): + return RateLimiter(calls=3, period=1) + + +@pytest.fixture +async def session_manager(config, rate_limiter): + manager = SessionManager(config=config, rate_limiter=rate_limiter) + yield manager + await manager.close() + + +@pytest.mark.asyncio +async def test_init_with_config(config): + """测试使用配置初始化""" + manager = SessionManager(config=config) + assert manager._config == config + assert isinstance(manager._timeout, ClientTimeout) + assert manager._timeout.total == config.timeout + + +@pytest.mark.asyncio +async def test_get_session(session_manager): + """测试获取会话""" + session = await session_manager.get_session() + assert isinstance(session, ClientSession) + assert not session.closed + + +@pytest.mark.asyncio +async def test_reuse_existing_session(session_manager): + """测试复用现有会话""" + session1 = await session_manager.get_session() + session2 = await session_manager.get_session() + assert session1 is session2 + + +@pytest.mark.asyncio +async def test_rate_limited_context(session_manager, mocker: MockerFixture): + """测试速率限制上下文""" + mock_limiter = mocker.patch.object(RateLimiter, "acquire") + + async with session_manager.rate_limited_context(): + pass + + mock_limiter.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_close(session_manager): + """测试关闭会话""" + session = await session_manager.get_session() + assert not session.closed + + await session_manager.close() + assert session.closed + assert session_manager._session is None + + +@pytest.mark.asyncio +async def test_context_manager(session_manager): + """测试上下文管理器""" + async with session_manager as manager: + session = await manager.get_session() + assert not session.closed + + assert session.closed + + +@pytest.mark.asyncio +async def test_request(session_manager, mocker: MockerFixture): + """测试请求方法(包含速率限制)""" + # 1. 模拟响应 + mock_response = mocker.AsyncMock(spec=ClientResponse) + + # 2. 模拟会话 + mock_session = mocker.AsyncMock(spec=ClientSession) + mock_session.request = mocker.AsyncMock(return_value=mock_response) + + # 3. 模拟 get_session + async def mock_get_session(): + return mock_session + + session_manager.get_session = mock_get_session + + # 4. 模拟速率限制器 + mock_acquire = mocker.AsyncMock() + session_manager._rate_limiter.acquire = mock_acquire + + # 5. 执行请求 + response = await session_manager.request("GET", "http://example.com") + + # 6. 
验证 + assert response == mock_response + mock_acquire.assert_awaited_once() + mock_session.request.assert_called_once_with( + "GET", "http://example.com", proxy=None + ) + + +@pytest.mark.asyncio +async def test_request_creates_new_session_if_closed(session_manager): + """测试请求时如果会话已关闭则创建新会话""" + session1 = await session_manager.get_session() + await session1.close() + + session2 = await session_manager.get_session() + assert session2 is not session1 + assert not session2.closed + + +@pytest.mark.asyncio +async def test_request_with_proxy(session_manager, mocker: MockerFixture): + """测试请求方法(包含代理)""" + mock_response = mocker.AsyncMock(spec=ClientResponse) + mock_session = mocker.AsyncMock(spec=ClientSession) + mock_session.request = mocker.AsyncMock(return_value=mock_response) + + async def mock_get_session(): + return mock_session + + session_manager.get_session = mock_get_session + + mock_acquire = mocker.AsyncMock() + session_manager._rate_limiter.acquire = mock_acquire + + session_manager._config.proxy = "http://proxy.com" + response = await session_manager.request("GET", "http://example.com") + + assert response == mock_response + mock_acquire.assert_awaited_once() + mock_session.request.assert_called_once_with( + "GET", "http://example.com", proxy="http://proxy.com" + )
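+
+
+# 补充用法示例(示意性草稿,非本 PR 既定用例;参数 2 / 0.5 仅作演示),
+# 演示 ArxivConfig 中的限流配置如何传递给 SessionManager 内部默认的 RateLimiter。
+@pytest.mark.asyncio
+async def test_rate_limit_config_passthrough():
+    """示例:配置中的 rate_limit_calls/rate_limit_period 用于构造默认 RateLimiter"""
+    config = ArxivConfig(rate_limit_calls=2, rate_limit_period=0.5)
+    manager = SessionManager(config=config)
+
+    assert manager._rate_limiter.calls == 2
+    assert manager._rate_limiter.period == 0.5
+
+    await manager.close()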