diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 0000000..0d064c9
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,67 @@
+name: Bug Report
+title: "Bug: unexpected behavior"
+description: File a bug report to help us improve the project
+labels: ["bug"]
+body:
+  - type: dropdown
+    id: env-os
+    attributes:
+      label: Operating System
+      description: The operating system you run aioarxiv on
+      options:
+        - Windows
+        - MacOS
+        - Linux
+        - Other
+    validations:
+      required: true
+
+  - type: input
+    id: env-python-ver
+    attributes:
+      label: Python Version
+      description: The Python version you run aioarxiv with
+      placeholder: e.g. 3.11.0
+    validations:
+      required: true
+
+  - type: input
+    id: env-nb-ver
+    attributes:
+      label: aioarxiv Version
+      description: The aioarxiv version in use
+      placeholder: e.g. 0.1.0
+    validations:
+      required: true
+
+  - type: textarea
+    id: describe
+    attributes:
+      label: Describe the Bug
+      description: A clear and concise description of what the bug is
+    validations:
+      required: true
+
+  - type: textarea
+    id: reproduction
+    attributes:
+      label: Steps to Reproduce
+      description: Detailed steps to reproduce the behavior
+      placeholder: |
+        1. First ...
+        2. Then ...
+        3. See ...
+    validations:
+      required: true
+
+  - type: textarea
+    id: expected
+    attributes:
+      label: Expected Behavior
+      description: A clear and concise description of what you expected to happen
+
+  - type: textarea
+    id: logs
+    attributes:
+      label: Screenshots or Logs
+      description: Any logs or screenshots that help diagnose the problem
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000..ec4bb38
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1 @@
+blank_issues_enabled: false
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/document.yml b/.github/ISSUE_TEMPLATE/document.yml
new file mode 100644
index 0000000..352a13b
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/document.yml
@@ -0,0 +1,18 @@
+name: Documentation Improvement
+title: "Docs: brief description"
+description: Report documentation errors or suggest improvements
+labels: ["documentation"]
+body:
+  - type: textarea
+    id: problem
+    attributes:
+      label: Describe the problem or topic
+    validations:
+      required: true
+
+  - type: textarea
+    id: improve
+    attributes:
+      label: Suggested changes
+    validations:
+      required: true
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 0000000..4f2e79f
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,20 @@
+name: Feature Request
+title: "Feature: brief description"
+description: Suggest a new feature for the project
+labels: ["enhancement"]
+body:
+  - type: textarea
+    id: problem
+    attributes:
+      label: Problem to solve
+      description: What problem did you run into that calls for a new feature?
+    validations:
+      required: true
+
+  - type: textarea
+    id: feature
+    attributes:
+      label: Describe the desired feature
+      description: Describe the feature or solution you would like
+    validations:
+      required: true
diff --git a/.github/actions/build-api-doc/action.yml b/.github/actions/build-api-doc/action.yml
new file mode 100644
index 0000000..efe4256
--- /dev/null
+++ b/.github/actions/build-api-doc/action.yml
@@ -0,0 +1,36 @@
+name: Build documentation
+description: Build documentation with the Writerside Docker builder
+
+# Instance and artifact names below assume the Writerside instance "a"
+# defined by Writerside/a.tree.
+inputs:
+  instance:
+    description: Writerside instance
+    required: false
+    default: "Writerside/a"
+  artifact:
+    description: Name of the generated artifact
+    required: false
+    default: "webHelpA2-all.zip"
+  docker-version:
+    description: Writerside Docker builder version
+    required: false
+    default: "243.21565"
+
+runs:
+  using: "composite"
+  steps:
+    - name: Build docs using Writerside Docker builder
+      uses: JetBrains/writerside-github-action@v4
+      with:
+        instance: ${{ inputs.instance }}
+        artifact: ${{ inputs.artifact }}
+        docker-version: ${{ inputs.docker-version }}
+
+    - name: Save artifact with build results
+      uses: actions/upload-artifact@v4
+      with:
+        name: docs
+        path: |
+          artifacts/${{ inputs.artifact }}
+        retention-days: 7
\ No newline at end of file
diff --git a/.github/actions/setup-node/action.yml b/.github/actions/setup-node/action.yml
new file mode 100644
index 0000000..24ad58a
--- /dev/null
+++ b/.github/actions/setup-node/action.yml
@@ -0,0 +1,13 @@
+name: Setup Node
+description: Setup Node
+
+runs:
+ using: "composite"
+ steps:
+ - uses: actions/setup-node@v4
+ with:
+ node-version: "18"
+ cache: "yarn"
+
+ - run: yarn install --frozen-lockfile
+ shell: bash
diff --git a/.github/actions/setup-python/action.yml b/.github/actions/setup-python/action.yml
new file mode 100644
index 0000000..e191b53
--- /dev/null
+++ b/.github/actions/setup-python/action.yml
@@ -0,0 +1,37 @@
+name: Setup Python
+description: Setup Python
+
+inputs:
+ python-version:
+ description: Python version
+ required: false
+ default: "3.10"
+ env-dir:
+ description: Environment directory
+ required: false
+ default: "."
+ no-root:
+ description: Do not install package in the environment
+ required: false
+ default: "false"
+
+runs:
+ using: "composite"
+ steps:
+ - name: Install pdm
+ run: pipx install pdm
+ shell: bash
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ inputs.python-version }}
+ cache: "pdm"
+ cache-dependency-path: |
+ ./pdm.lock
+ ${{ inputs.env-dir }}/pdm.lock
+
+  - run: |
+      cd ${{ inputs.env-dir }}
+      if [ "${{ inputs.no-root }}" = "true" ]; then
+        pdm install -G:all --no-self
+      else
+        pdm install -G:all
+      fi
+    shell: bash
diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml
new file mode 100644
index 0000000..5539984
--- /dev/null
+++ b/.github/workflows/codecov.yml
@@ -0,0 +1,54 @@
+name: Code Coverage
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ paths:
+ - "envs/**"
+ - "src/**"
+ - "tests/**"
+ - ".github/workflows/codecov.yml"
+ - "pyproject.toml"
+ - "pdm.lock"
+
+jobs:
+ test:
+ name: Test Coverage
+ runs-on: ${{ matrix.os }}
+ concurrency:
+ group: test-coverage-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python-version }}
+ cancel-in-progress: true
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ os: [ubuntu-latest, windows-latest, macos-latest]
+ env:
+ OS: ${{ matrix.os }}
+ PYTHON_VERSION: ${{ matrix.python-version }}
+
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Setup Python environment
+ uses: ./.github/actions/setup-python
+ with:
+ python-version: ${{ matrix.python-version }}
+ no-root: true
+
+ - name: Run Pytest
+ run: pdm run bash "./scripts/run-tests.sh"
+
+ - name: Upload coverage report
+ uses: codecov/codecov-action@v4
+ with:
+ env_vars: OS,PYTHON_VERSION
+ files: ./tests/coverage.xml
+ flags: unittests
+ env:
+ CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml
new file mode 100644
index 0000000..a0dafd5
--- /dev/null
+++ b/.github/workflows/pyright.yml
@@ -0,0 +1,35 @@
+name: Pyright Lint
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ paths:
+ - "envs/**"
+ - "src/**"
+ - "tests/**"
+ - ".github/actions/setup-python/**"
+ - ".github/workflows/pyright.yml"
+ - "pyproject.toml"
+ - "pdm.lock"
+
+jobs:
+ pyright:
+ name: Pyright Lint
+ runs-on: ubuntu-latest
+ concurrency:
+ group: pyright-${{ github.ref }}
+ cancel-in-progress: true
+
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Setup Python environment
+ uses: ./.github/actions/setup-python
+ with:
+ no-root: true
+
+ - name: Run Pyright
+ uses: jakebailey/pyright-action@v2
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..1c305df
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,29 @@
+name: Ruff Lint
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ paths:
+ - "envs/**"
+ - "src/**"
+ - "tests/**"
+ - ".github/actions/setup-python/**"
+ - ".github/workflows/ruff.yml"
+ - "pyproject.toml"
+ - "pdm.lock"
+
+jobs:
+ ruff:
+ name: Ruff Lint
+ runs-on: ubuntu-latest
+ concurrency:
+ group: ruff-${{ github.ref }}
+ cancel-in-progress: true
+
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Run Ruff Lint
+ uses: chartboost/ruff-action@v1
diff --git a/.github/workflows/website-deploy.yml b/.github/workflows/website-deploy.yml
new file mode 100644
index 0000000..6faf10e
--- /dev/null
+++ b/.github/workflows/website-deploy.yml
@@ -0,0 +1,46 @@
+#name: Site Deploy
+#
+#on:
+# push:
+# branches:
+# - master
+#
+#jobs:
+# publish:
+# runs-on: ubuntu-latest
+# concurrency:
+# group: website-deploy-${{ github.ref }}
+# cancel-in-progress: true
+#
+# steps:
+# - uses: actions/checkout@v4
+# with:
+# fetch-depth: 0
+#
+# - name: Setup Python Environment
+# uses: ./.github/actions/setup-python
+#
+# - name: Setup Node Environment
+# uses: ./.github/actions/setup-node
+#
+# - name: Build API Doc
+# uses: ./.github/actions/build-api-doc
+#
+# - name: Build Doc
+# run: yarn build
+#
+# - name: Get Branch Name
+# run: echo "BRANCH_NAME=$(echo ${GITHUB_REF#refs/heads/})" >> $GITHUB_ENV
+#
+# - name: Deploy to Netlify
+# uses: nwtgck/actions-netlify@v3
+# with:
+# publish-dir: "./website/build"
+# production-deploy: true
+# github-token: ${{ secrets.GITHUB_TOKEN }}
+# deploy-message: "Deploy ${{ env.BRANCH_NAME }}@${{ github.sha }}"
+# enable-commit-comment: false
+# alias: ${{ env.BRANCH_NAME }}
+# env:
+# NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
+# NETLIFY_SITE_ID: ${{ secrets.SITE_ID }}
diff --git a/.github/workflows/website-preview.yml b/.github/workflows/website-preview.yml
new file mode 100644
index 0000000..3657b8b
--- /dev/null
+++ b/.github/workflows/website-preview.yml
@@ -0,0 +1,46 @@
+#name: Site Deploy(Preview)
+#
+#on:
+# pull_request_target:
+#
+#jobs:
+# preview:
+# runs-on: ubuntu-latest
+# concurrency:
+# group: pull-request-preview-${{ github.event.number }}
+# cancel-in-progress: true
+#
+# steps:
+# - uses: actions/checkout@v4
+# with:
+# ref: ${{ github.event.pull_request.head.sha }}
+# fetch-depth: 0
+#
+# - name: Setup Python Environment
+# uses: ./.github/actions/setup-python
+#
+# - name: Setup Node Environment
+# uses: ./.github/actions/setup-node
+#
+# - name: Build API Doc
+# uses: ./.github/actions/build-api-doc
+#
+# - name: Build Doc
+# run: yarn build
+#
+# - name: Get Deploy Name
+# run: |
+# echo "DEPLOY_NAME=deploy-preview-${{ github.event.number }}" >> $GITHUB_ENV
+#
+# - name: Deploy to Netlify
+# uses: nwtgck/actions-netlify@v3
+# with:
+# publish-dir: "./website/build"
+# production-deploy: false
+# github-token: ${{ secrets.GITHUB_TOKEN }}
+# deploy-message: "Deploy ${{ env.DEPLOY_NAME }}@${{ github.sha }}"
+# enable-commit-comment: false
+# alias: ${{ env.DEPLOY_NAME }}
+# env:
+# NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
+# NETLIFY_SITE_ID: ${{ secrets.SITE_ID }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..afd8650
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,166 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+test.py
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm-project.org/#use-with-ide
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# Visual Studio Code
+.vscode/
\ No newline at end of file
diff --git a/.markdownlint.yaml b/.markdownlint.yaml
new file mode 100644
index 0000000..c1b97a3
--- /dev/null
+++ b/.markdownlint.yaml
@@ -0,0 +1,13 @@
+MD013: false
+MD024:
+ siblings_only: true
+MD033: false
+
+# Do not enforce a code block style or a fenced-code language tag
+MD040: false
+MD046: false
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..262c6a4
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,33 @@
+ci:
+ autofix_commit_msg: ":rotating_light: auto fix by pre-commit hooks"
+ autofix_prs: true
+ autoupdate_branch: main
+ autoupdate_schedule: quarterly
+ autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks"
+repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.7.2
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+ stages: [pre-commit]
+ - id: ruff-format
+
+ - repo: https://github.com/pycqa/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+ stages: [pre-commit]
+
+ - repo: https://github.com/psf/black
+ rev: 24.10.0
+ hooks:
+ - id: black
+ stages: [pre-commit]
+
+ - repo: https://github.com/pre-commit/mirrors-prettier
+ rev: v4.0.0-alpha.8
+ hooks:
+ - id: prettier
+ types_or: [javascript, jsx, ts, tsx, markdown, yaml, json]
+ stages: [pre-commit]
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..530fe91
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.11.9
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..bab4669
--- /dev/null
+++ b/README.md
@@ -0,0 +1,73 @@
+# Aioarxiv
+
+An async Python client for the arXiv API with enhanced performance and flexible configuration options.
+
+> ⚠️ Warning: This project is currently in beta. Not recommended for production use.
+
+## Features
+
+- Asynchronous API calls for better performance
+- Flexible search and download capabilities
+- Customizable rate limiting and concurrent requests
+- Simple error handling
+
+## Installation
+
+```bash
+pip install aioarxiv
+```
+
+## Quick Start
+
+```python
+import asyncio
+from aioarxiv import ArxivClient
+
+async def main():
+    async with ArxivClient() as client:
+        async for paper in client.search("quantum computing", max_results=1):
+            print(f"Title: {paper.title}")
+            print(f"Authors: {', '.join(a.name for a in paper.authors)}")
+            print(f"Summary: {paper.summary[:200]}...")
+
+            # Download the PDF
+            file_path = await client.download_paper(paper)
+            print(f"Downloaded to: {file_path}")
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```
+
+## Configuration
+
+```python
+from aioarxiv import ArxivConfig, ArxivClient
+
+config = ArxivConfig(
+ rate_limit_calls=3, # Rate limit per window
+ rate_limit_period=1.0, # Window period in seconds
+ max_concurrent_requests=3 # Max concurrent requests
+)
+
+client = ArxivClient(config=config)
+```
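+
+`ArxivConfig` defines several more fields in `src/config.py`, including `timeout`, `max_retries`, `proxy`, `log_level`, `page_size`, and `min_wait`; they are passed the same way. A short sketch (the values shown are illustrative, not recommended settings):
+
+```python
+from aioarxiv import ArxivConfig, ArxivClient
+
+config = ArxivConfig(
+    timeout=30.0,                   # request timeout in seconds
+    max_retries=3,                  # retry attempts for failed requests
+    proxy="http://127.0.0.1:7890",  # optional HTTP/HTTPS proxy URL
+    log_level="DEBUG",              # logging verbosity
+    page_size=10,                   # results fetched per API page
+)
+
+client = ArxivClient(config=config)
+```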
+
+## Error Handling
+
+```python
+from aioarxiv import ArxivClient, SearchCompleteException
+
+async def search_all():
+    try:
+        async with ArxivClient() as client:
+            async for paper in client.search("quantum computing"):
+                print(paper.title)
+    except SearchCompleteException:
+        print("Search complete")
+```
+
+## Requirements
+* Python 3.9 or higher
+
+## License
+[MIT License (c) 2024 BalconyJH](LICENSE)
+
+## Links
+* Documentation for aioarxiv is WIP
+* [ArXiv API](https://info.arxiv.org/help/api/index.html)
\ No newline at end of file
diff --git a/Writerside/a.tree b/Writerside/a.tree
new file mode 100644
index 0000000..1458c58
--- /dev/null
+++ b/Writerside/a.tree
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Writerside/c.list b/Writerside/c.list
new file mode 100644
index 0000000..fadecfa
--- /dev/null
+++ b/Writerside/c.list
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Writerside/cfg/buildprofiles.xml b/Writerside/cfg/buildprofiles.xml
new file mode 100644
index 0000000..e0663a7
--- /dev/null
+++ b/Writerside/cfg/buildprofiles.xml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+ true
+
+
+
+
diff --git a/Writerside/cfg/glossary.xml b/Writerside/cfg/glossary.xml
new file mode 100644
index 0000000..bb59ee0
--- /dev/null
+++ b/Writerside/cfg/glossary.xml
@@ -0,0 +1,43 @@
+
+
+
+
+ 异步arXiv API客户端的主类,用于执行论文搜索和获取操作。
+
+
+
+ 表示单篇arXiv论文的数据模型,包含标题、作者、摘要等信息。
+
+
+
+ 搜索参数配置类,用于定义查询条件和结果限制。
+
+
+
+ 客户端配置类,用于设置API访问参数如并发限制、请求频率等。
+
+
+
+ 批量处理上下文类,用于跟踪批量获取论文的状态和进度。
+
+
+
+ 论文分类信息模型,包含主分类和次分类。
+
+
+
+ 会话管理器,负责维护HTTP会话和实现访问频率限制。
+
+
+
+ XML解析器,用于解析arXiv API返回的Atom feed格式数据。
+
+
+
+ 搜索完成异常,用于标识搜索达到结果限制或无结果时的情况。
+
+
+
+ 解析异常,在XML数据解析失败时抛出。
+
+
\ No newline at end of file
diff --git a/Writerside/topics/Installation-guide.md b/Writerside/topics/Installation-guide.md
new file mode 100644
index 0000000..880a156
--- /dev/null
+++ b/Writerside/topics/Installation-guide.md
@@ -0,0 +1,108 @@
+# Installation Guide
+
+## Introduction
+
+This guide helps you install and configure **aioarxiv**, either on its own or as a project dependency.
+
+> ⚠️ Warning: this project is still in beta and has not reached a stable release. Do not use it in production.
+
+---
+
+## Installation types
+
+We strongly recommend managing project dependencies with a package manager such as [PDM](https://pdm-project.org/) or [Poetry](https://python-poetry.org/).
+
+| **Type** | **Description** | **More information** |
+|----------------|---------------------------------------------------------|--------------------------------------------------|
+| PDM install | Add `aioarxiv` to your project dependencies. | [Go to installation steps](#installation-steps) |
+| Poetry install | Add `aioarxiv` to your project dependencies. | [Go to installation steps](#installation-steps) |
+| PyPI install | Install with `pip` on a supported Python version. | [Go to installation steps](#installation-steps) |
+| Source install | Clone the repository from GitHub and install manually. | [Go to installation steps](#installation-steps) |
+
+---
+
+## Overview
+
+### Version information
+
+The available releases of this library are listed below. We do not yet recommend depending on it in production:
+
+| **Version** | **Build** | **Release Date** | **Status** | **Notes** |
+|-------------|--------------|------------------|------------|-----------------------|
+| 0.1.0 | PyPI Release | 15/11/2024 | Latest | Initial testing phase |
+
+## System requirements
+
+### Supported operating systems
+
+- **Windows**: Windows 10 or later.
+- **MacOS**: macOS 10.15 or later.
+- **Linux**: Ubuntu 18.04 or later.
+
+### Supported Python versions
+
+- Python 3.9 or later.
+
+## Before you begin
+
+Before installing this library, make sure that:
+
+- **Python is installed**:
+  - A supported Python version is available on your system.
+  - Check with `python --version`.
+  - If Python is not installed, download it from the
+    [Python website](https://www.python.org/downloads/release/python-3920/)
+    and install it.
+- **pip is installed**:
+  - Make sure `pip` is available and, ideally, up to date.
+  - Check the version with `pip --version` and update with `pip install --upgrade pip`.
+- **Git is installed**:
+  - If you plan to install from source, make sure Git is installed.
+  - Check the version with `git --version`.
+- **A package manager is installed**:
+  - If you plan to use PDM or Poetry, make sure it is installed correctly.
+  - Check with:
+    ```bash
+    pdm --version     # check PDM
+    poetry --version  # check Poetry
+    ```
+
+## Installation steps
+
+### Installing with PDM
+
+1. Open a terminal.
+2. Run the following command to add the library:
+   ```bash
+   pdm add aioarxiv
+   ```
+
+### Installing with Poetry
+
+1. Open a terminal.
+2. Run the following command to add the library:
+   ```bash
+   poetry add aioarxiv
+   ```
+
+### Installing from PyPI with pip
+
+1. Open a terminal.
+2. Run the following command to install the library:
+   ```bash
+   pip install aioarxiv
+   ```
+
+### Installing from source
+
+1. Open a terminal.
+2. Run the following commands to clone and install:
+   ```bash
+   git clone https://github.com/BalconyJH/aioarxiv.git
+   cd aioarxiv
+   pip install .
+   ```
+
\ No newline at end of file
diff --git a/Writerside/topics/starter.md b/Writerside/topics/starter.md
new file mode 100644
index 0000000..4924928
--- /dev/null
+++ b/Writerside/topics/starter.md
@@ -0,0 +1,99 @@
+# Starter
+
+## Installation
+
+> ⚠️ Warning: this project is still in beta and has not reached a stable release. Do not use it in production.
+
+Install with `pip`: `pip install aioarxiv`
+
+## Basic concepts
+
+`aioarxiv` is an asynchronous Python client built for the arXiv API, designed to make efficient paper search and download straightforward. Its main capabilities are:
+
+- Asynchronous API calls for better performance under frequent requests.
+- Flexible search and download functionality.
+- Custom configuration to fit specific needs.
+
+## Main classes
+
+* `ArxivClient`: the main client class, used to run searches
+* `Paper`: the paper data model
+* `SearchParams`: the search parameter model (see the sketch below)
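+
+A minimal sketch of building `SearchParams` directly (field and enum names are taken from `src/models/__init__.py`; the `aioarxiv.models` import path is an assumption about how the package is exposed):
+
+```python
+from aioarxiv.models import SearchParams, SortCriterion, SortOrder
+
+# sort_by / sort_order map to the arXiv API's sortBy / sortOrder parameters.
+params = SearchParams(
+    query="cat:cs.CL AND ti:transformer",
+    max_results=10,
+    sort_by=SortCriterion.SUBMITTED,
+    sort_order=SortOrder.DESCENDING,
+)
+```
+
+Note that `ArxivClient.search()` currently accepts only `query` and `max_results`; the sorting fields are read when the client builds request parameters but are not yet exposed through `search()`.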
+
+## Usage example
+
+```python
+import asyncio
+
+from aioarxiv.client.arxiv_client import ArxivClient
+from aioarxiv.config import ArxivConfig
+
+
+async def search():
+    config = ArxivConfig(
+        log_level="DEBUG",
+    )
+    client = ArxivClient(config=config)
+
+    query = "quantum computing"
+    count = 0
+
+    print(f"Searching for: {query}")
+    print("-" * 50)
+
+    async for paper in client.search(query=query, max_results=1):
+        count += 1
+        print(f"\nPaper {count}:")
+        print(f"Title: {paper.title}")
+        print(f"Authors: {paper.authors}")
+        print(f"Summary: {paper.summary[:200]}...")
+        print("-" * 50)
+        file_path = await client.download_paper(paper)
+        print(f"Downloaded to: {file_path}")
+
+
+if __name__ == "__main__":
+    asyncio.run(search())
+```
+
+## Custom configuration
+
+```python
+from aioarxiv import ArxivClient, ArxivConfig
+
+config = ArxivConfig(
+    rate_limit_calls=3,         # rate limit per window
+    max_concurrent_requests=3,  # maximum concurrent requests
+)
+client = ArxivClient(config=config)
+```
+
+## Error handling
+
+The `ArxivClient.search()` method returns an async generator and relies on the
+`SearchCompleteException` exception to signal that a search has finished.
+
+You should therefore structure search handling as follows:
+
+```python
+from aioarxiv import ArxivClient, SearchCompleteException
+
+
+async def search_and_handle_error():
+    try:
+        async with ArxivClient() as client:
+            query = "quantum computing"
+            async for paper in client.search(query=query, max_results=1):
+                print(paper)
+    except SearchCompleteException:
+        print("Search complete.")
+```
+
+## Feedback
+
+Please file feature requests and bug reports on [GitHub Issues](https://github.com/BalconyJH/aioarxiv/issues/new).
+
+## See also
+
+- arXiv API
+- aioarxiv
diff --git a/Writerside/v.list b/Writerside/v.list
new file mode 100644
index 0000000..2d12cb3
--- /dev/null
+++ b/Writerside/v.list
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/Writerside/writerside.cfg b/Writerside/writerside.cfg
new file mode 100644
index 0000000..d84ff0c
--- /dev/null
+++ b/Writerside/writerside.cfg
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..d8f87e5
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,108 @@
+[project]
+name = "aioarxiv"
+version = "0.1.0"
+description = "Asynchronous arXiv API client library"
+authors = [
+ { name = "BalconyJH", email = "balconyjh@gmail.com" },
+]
+dependencies = [
+ "feedparser~=6.0",
+ "loguru~=0.7",
+ "pydantic>=2.9.2",
+ "aiohttp>=3.10.10",
+ "tenacity>=9.0.0",
+ "aiofiles>=24.1.0",
+]
+requires-python = ">=3.9"
+readme = "README.md"
+license = { text = "MIT" }
+
+[tool.pdm]
+distribution = true
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+addopts = "--cov-report=term-missing --ignore=.venv/"
+asyncio_default_fixture_loop_scope = "session"
+
+[tool.black]
+line-length = 88
+target-version = ["py39", "py310", "py311", "py312"]
+include = '\.pyi?$'
+extend-exclude = '''
+'''
+
+[tool.isort]
+profile = "black"
+line_length = 88
+length_sort = true
+skip_gitignore = true
+force_sort_within_sections = true
+src_paths = ["src", "tests"]
+extra_standard_library = ["typing_extensions"]
+
+[tool.ruff]
+line-length = 88
+target-version = "py39"
+
+[tool.ruff.lint]
+select = [
+ "F", # Pyflakes
+ "W", # pycodestyle warnings
+ "E", # pycodestyle errors
+ "UP", # pyupgrade
+ "ASYNC", # flake8-async
+ "C4", # flake8-comprehensions
+ "T10", # flake8-debugger
+ "T20", # flake8-print
+ "PYI", # flake8-pyi
+ "PT", # flake8-pytest-style
+ "Q", # flake8-quotes
+ "RUF", # Ruff-specific rules
+ "I", # isort
+]
+ignore = [
+ "E402", # module-import-not-at-top-of-file
+ "E501", # line-too-long
+ "UP037", # quoted-annotation
+ "RUF001", # ambiguous-unicode-character-string
+ "RUF002", # ambiguous-unicode-character-docstring
+ "RUF003", # ambiguous-unicode-character-comment
+ "ASYNC230", # blocking-open-call-in-async-function
+]
+
+[tool.ruff.lint.flake8-pytest-style]
+fixture-parentheses = false
+mark-parentheses = false
+
+[tool.pyright]
+pythonVersion = "3.9"
+pythonPlatform = "All"
+executionEnvironments = [
+ { root = "./tests", extraPaths = [
+ "./",
+ ] },
+ { root = "./src" },
+]
+typeCheckingMode = "standard"
+disableBytesTypePromotions = true
+
+[dependency-groups]
+test = [
+ "pytest-cov~=5.0",
+ "pytest-xdist~=3.6",
+ "pytest-asyncio~=0.23",
+ "pytest-mock>=3.14.0",
+]
+dev = [
+ "black~=24.4",
+ "ruff~=0.4",
+ "isort~=5.13",
+ "pre-commit~=4.0",
+ "bump-my-version>=0.28.1",
+]
+
+[tool.coverage.run]
+omit = [
+ ".venv/*",
+]
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/client/__init__.py b/src/client/__init__.py
new file mode 100644
index 0000000..3935534
--- /dev/null
+++ b/src/client/__init__.py
@@ -0,0 +1,348 @@
+# import sys
+# from collections.abc import AsyncGenerator
+# from dataclasses import dataclass
+# from types import TracebackType
+# from typing import Optional
+#
+# from aiohttp import ClientError, ClientResponse
+# from anyio import create_memory_object_stream, create_task_group, sleep
+# from tenacity import (
+# retry,
+# retry_if_exception_type,
+# stop_after_attempt,
+# wait_exponential,
+# )
+#
+# from ..config import ArxivConfig, default_config
+# from ..exception import (
+# HTTPException,
+# ParseErrorContext,
+# ParserException,
+# QueryBuildError,
+# QueryContext, SearchCompleteException,
+# )
+# from ..models import Paper, SearchParams
+# from ..utils.log import logger
+# from ..utils.parser import ArxivParser
+# from ..utils.session import SessionManager
+#
+#
+# class ArxivClient:
+# """
+# arXiv API异步客户端
+#
+# 用于执行arXiv API搜索请求。
+#
+# Attributes:
+# _config: arXiv API配置
+# _session_manager: 会话管理器
+# """
+# def __init__(
+# self,
+# config: Optional[ArxivConfig] = None,
+# session_manager: Optional[SessionManager] = None,
+# ):
+# self._config = config or default_config
+# self._session_manager = session_manager or SessionManager()
+#
+# async def search(
+# self,
+# query: str,
+# max_results: Optional[int] = None,
+# ) -> AsyncGenerator[Paper, None]:
+# """
+# 执行搜索
+#
+# Args:
+# query: 搜索查询
+# max_results: 最大结果数
+#
+# Yields:
+# Paper: 论文对象
+# """
+# params = SearchParams(query=query, max_results=max_results)
+#
+# async with self._session_manager:
+# async for paper in self._iter_papers(params):
+# yield paper
+#
+# def _calculate_page_size(self, start: int, max_results: Optional[int]) -> int:
+# """计算页面大小"""
+# if max_results is None:
+# return self._config.page_size
+# return min(self._config.page_size, max_results - start)
+#
+# def _has_more_results(
+# self, start: int, total: float, max_results: Optional[int]
+# ) -> bool:
+# """检查是否有更多结果"""
+# max_total = int(total) if total != float("inf") else sys.maxsize
+# max_allowed = max_results or sys.maxsize
+# return start < min(max_total, max_allowed)
+#
+# async def _iter_papers(self, params: SearchParams) -> AsyncGenerator[Paper, None]:
+# """迭代获取论文
+#
+# Args:
+# params: 搜索参数
+#
+# Yields:
+# Paper: 论文对象
+# """
+# start = 0
+# total_results = float("inf")
+# concurrent_limit = min(
+# self._config.max_concurrent_requests,
+# self._config.rate_limit_calls
+# )
+# results_count = 0
+# first_batch = True
+#
+# logger.info("开始搜索", extra={
+# "query": params.query,
+# "max_results": params.max_results,
+# "concurrent_limit": concurrent_limit
+# })
+#
+# send_stream, receive_stream = create_memory_object_stream[
+# tuple[list[Paper], int]
+# ](max_buffer_size=concurrent_limit)
+#
+# @retry(
+# retry=retry_if_exception_type((ClientError, TimeoutError, OSError)),
+# stop=stop_after_attempt(self._config.max_retries),
+# wait=wait_exponential(
+# multiplier=self._config.rate_limit_period,
+# min=self._config.min_wait,
+# max=self._config.timeout
+# ),
+# )
+# async def fetch_page(page_start: int):
+# """获取单页并发送结果"""
+# try:
+# async with self._session_manager.rate_limited_context():
+# response = await self._fetch_page(params, page_start)
+# result = await self.parse_response(response)
+# await send_stream.send(result)
+# except Exception as e:
+# logger.error(f"获取页面失败: {e!s}", extra={"page": page_start})
+# raise
+#
+# async def fetch_batch(batch_start: int, size: int):
+# """批量获取论文"""
+# try:
+# async with create_task_group() as batch_tg:
+# # 计算这一批次应获取的数量
+# batch_size = min(
+# size,
+# self._config.page_size,
+# params.max_results - results_count if params.max_results else size
+# )
+#
+# # 计算结束位置
+# end = min(
+# batch_start + batch_size,
+# int(total_results) if total_results != float(
+# "inf") else sys.maxsize
+# )
+#
+# # 启动单页获取任务
+# for offset in range(batch_start, end):
+# if params.max_results and offset >= params.max_results:
+# break
+# batch_tg.start_soon(fetch_page, offset)
+# except Exception as e:
+# logger.error(
+# "批量获取失败",
+# extra={
+# "start": batch_start,
+# "size": size,
+# "error": str(e)
+# }
+# )
+# raise
+#
+# try:
+# async with create_task_group() as tg:
+# tg.start_soon(
+# fetch_batch,
+# start,
+# concurrent_limit * self._config.page_size
+# )
+#
+# async for papers, total in receive_stream:
+# total_results = total
+#
+# if first_batch and total == 0:
+# logger.info("未找到结果", extra={"query": params.query})
+# raise SearchCompleteException(0)
+# first_batch = False
+#
+# for paper in papers:
+# yield paper
+# results_count += 1
+# if params.max_results and results_count >= params.max_results:
+# raise SearchCompleteException(results_count)
+#
+# start += len(papers)
+# if start < min(
+# int(total_results) if total_results != float(
+# "inf") else sys.maxsize,
+# params.max_results or sys.maxsize
+# ):
+# await sleep(self._config.rate_limit_period)
+# tg.start_soon(
+# fetch_batch,
+# start,
+# concurrent_limit * self._config.page_size
+# )
+# finally:
+# await receive_stream.aclose()
+# logger.info("搜索结束", extra={"total_results": results_count})
+#
+# async def _fetch_page(self, params: SearchParams, start: int) -> ClientResponse:
+# """
+# 获取单页结果
+#
+# Args:
+# params: 搜索参数
+# start: 起始位置
+#
+# Returns:
+# ClientResponse: API响应
+#
+# Raises:
+# QueryBuildError: 如果构建查询参数失败
+# HTTPException: 如果请求失败
+# """
+# try:
+# # 构建查询参数
+# query_params = self._build_query_params(params, start)
+#
+# # 发送请求
+# response = await self._session_manager.request(
+# "GET", str(self._config.base_url), params=query_params
+# )
+#
+# if response.status != 200:
+# logger.error(
+# "搜索请求失败", extra={"status_code": response.status}
+# )
+# raise HTTPException(response.status)
+#
+# return response
+#
+# except QueryBuildError as e:
+# logger.error("查询参数构建失败", extra={"error_context": e.context})
+# raise
+# except Exception as e:
+# logger.error("未预期的错误", exc_info=True)
+# raise QueryBuildError(
+# message="构建查询参数失败",
+# context=QueryContext(
+# params={"query": params.query, "start": start},
+# field_name="query_params",
+# ),
+# original_error=e,
+# )
+#
+# def _build_query_params(self, params: SearchParams, start: int) -> dict:
+# """
+# 构建查询参数
+#
+# Args:
+# params: 搜索参数
+# start: 起始位置
+#
+# Returns:
+# dict: 查询参数
+#
+# Raises:
+# QueryBuildError: 如果构建查询参数失败
+# """
+# if not params.query:
+# raise QueryBuildError(
+# message="搜索查询不能为空",
+# context=QueryContext(
+# params={"query": None}, field_name="query", constraint="required"
+# ),
+# )
+#
+# if start < 0:
+# raise QueryBuildError(
+# message="起始位置不能为负数",
+# context=QueryContext(
+# params={"start": start},
+# field_name="start",
+# constraint="non_negative",
+# ),
+# )
+#
+# try:
+# page_size = min(
+# self._config.page_size,
+# params.max_results - start if params.max_results else float("inf"),
+# )
+# except Exception as e:
+# raise QueryBuildError(
+# message="计算页面大小失败",
+# context=QueryContext(
+# params={
+# "page_size": self._config.page_size,
+# "max_results": params.max_results,
+# "start": start,
+# },
+# field_name="page_size",
+# ),
+# original_error=e,
+# )
+#
+# query_params = {
+# "search_query": params.query,
+# "start": start,
+# "max_results": page_size,
+# }
+#
+# if params.sort_by:
+# query_params["sortBy"] = params.sort_by.value
+#
+# if params.sort_order:
+# query_params["sortOrder"] = params.sort_order.value
+#
+# return query_params
+#
+# async def parse_response(
+# self,
+# response: ClientResponse
+# ) -> tuple[list[Paper], int]:
+# """解析API响应"""
+# content = await response.text()
+# try:
+# return await ArxivParser.parse_feed(
+# content=content,
+# url=str(response.url)
+# )
+# except ParserException:
+# raise
+# except Exception as e:
+# raise ParserException(
+# url=str(response.url),
+# message="解析响应失败",
+# context=ParseErrorContext(raw_content=content),
+# original_error=e,
+# )
+#
+# async def __aenter__(self) -> "ArxivClient":
+# return self
+#
+# async def __aexit__(
+# self,
+# exc_type: Optional[type[BaseException]],
+# exc_val: Optional[BaseException],
+# exc_tb: Optional[TracebackType]
+# ) -> None:
+# await self._session_manager.close()
+#
+# async def close(self) -> None:
+# """关闭客户端并清理资源"""
+# await self._session_manager.close()
diff --git a/src/client/arxiv_client.py b/src/client/arxiv_client.py
new file mode 100644
index 0000000..a6fb652
--- /dev/null
+++ b/src/client/arxiv_client.py
@@ -0,0 +1,217 @@
+from collections.abc import AsyncGenerator
+from pathlib import Path
+from types import TracebackType
+from typing import Optional
+
+from aiohttp import ClientResponse
+
+from ..config import ArxivConfig, default_config
+from ..exception import (
+ HTTPException,
+ ParseErrorContext,
+ ParserException,
+ QueryBuildError,
+ QueryContext,
+ SearchCompleteException,
+)
+from ..models import Paper, SearchParams
+from ..utils import logger
+from ..utils.parser import ArxivParser
+from ..utils.session import SessionManager
+from .base import BaseSearchManager
+from .downloader import ArxivDownloader
+from .search import ArxivSearchManager
+
+
+class ArxivClient:
+ def __init__(
+ self,
+ config: Optional[ArxivConfig] = None,
+ session_manager: Optional[SessionManager] = None,
+ *,
+ search_manager_class: type[BaseSearchManager] = ArxivSearchManager,
+ ):
+ self._config = config or default_config
+ self._session_manager = session_manager or SessionManager(config=self._config)
+ self._search_manager = search_manager_class(self)
+ self._downloader = ArxivDownloader(self._session_manager)
+
+ async def search(
+ self,
+ query: str,
+ max_results: Optional[int] = None,
+ ) -> AsyncGenerator[Paper, None]:
+ """执行搜索"""
+ params = SearchParams(query=query, max_results=max_results)
+
+ try:
+ async with self._session_manager:
+ async for paper in self._search_manager.execute_search(params):
+ yield paper
+ except SearchCompleteException:
+ return
+
+ async def _fetch_page(self, params: SearchParams, start: int) -> ClientResponse:
+ """
+ 获取单页结果
+
+ Args:
+ params: 搜索参数
+ start: 起始位置
+
+ Returns:
+ 响应对象
+
+ Raises:
+ QueryBuildError: 如果构建查询参数失败
+ """
+ try:
+ query_params = self._build_query_params(params, start)
+ response = await self._session_manager.request(
+ "GET", str(self._config.base_url), params=query_params
+ )
+
+ if response.status != 200:
+ logger.error("搜索请求失败", extra={"status_code": response.status})
+ raise HTTPException(response.status)
+
+ return response
+
+ except QueryBuildError:
+ raise
+ except Exception as e:
+ logger.error("未预期的错误", exc_info=True)
+ raise QueryBuildError(
+ message="构建查询参数失败",
+ context=QueryContext(
+ params={"query": params.query, "start": start},
+ field_name="query_params",
+ ),
+ original_error=e,
+ )
+
+ def _build_query_params(self, params: SearchParams, start: int) -> dict:
+ """
+ 构建查询参数
+
+ Args:
+ params: 搜索参数
+ start: 起始位置
+
+ Returns:
+ dict: 查询参数
+
+ Raises:
+ QueryBuildError: 如果构建查询参数失败
+ """
+ self._validate_params(params, start)
+
+ try:
+ page_size = min(
+ self._config.page_size,
+ params.max_results - start if params.max_results else float("inf"),
+ )
+
+ query_params = {
+ "search_query": params.query,
+ "start": start,
+ "max_results": page_size,
+ }
+
+ if params.sort_by:
+ query_params["sortBy"] = params.sort_by.value
+
+ if params.sort_order:
+ query_params["sortOrder"] = params.sort_order.value
+
+ return query_params
+
+ except Exception as e:
+ raise QueryBuildError(
+ message="计算页面大小失败",
+ context=QueryContext(
+ params={
+ "page_size": self._config.page_size,
+ "max_results": params.max_results,
+ "start": start,
+ },
+ field_name="page_size",
+ ),
+ original_error=e,
+ )
+
+ def _validate_params(self, params: SearchParams, start: int) -> None:
+ """验证查询参数"""
+ if not params.query:
+ raise QueryBuildError(
+ message="搜索查询不能为空",
+ context=QueryContext(
+ params={"query": None}, field_name="query", constraint="required"
+ ),
+ )
+
+ if start < 0:
+ raise QueryBuildError(
+ message="起始位置不能为负数",
+ context=QueryContext(
+ params={"start": start},
+ field_name="start",
+ constraint="non_negative",
+ ),
+ )
+
+ async def parse_response(self, response: ClientResponse) -> tuple[list[Paper], int]:
+ """
+ 解析API响应
+
+ Args:
+ response: 响应对象
+
+ Returns:
+ 解析后的论文列表和总结果数
+
+ Raises:
+ ParserException: 如果解析失败
+ """
+ content = await response.text()
+ try:
+ return await ArxivParser.parse_feed(content=content, url=str(response.url))
+ except ParserException:
+ raise
+ except Exception as e:
+ raise ParserException(
+ url=str(response.url),
+ message="解析响应失败",
+ context=ParseErrorContext(raw_content=content),
+ original_error=e,
+ )
+
+ async def download_paper(
+ self, paper: Paper, filename: Optional[str] = None
+ ) -> Path:
+ """
+ 下载论文
+
+ Args:
+ paper: 论文对象
+ filename: 文件名
+
+ Returns:
+ 下载文件的存放路径
+ """
+ return await self._downloader.download_paper(str(paper.pdf_url), filename)
+
+ async def close(self) -> None:
+ """关闭客户端"""
+ await self._session_manager.close()
+
+ async def __aenter__(self) -> "ArxivClient":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
diff --git a/src/client/base.py b/src/client/base.py
new file mode 100644
index 0000000..3c7c74a
--- /dev/null
+++ b/src/client/base.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+from typing import Protocol
+
+from aiohttp import ClientResponse
+
+from ..config import ArxivConfig
+from ..models import Paper, SearchParams
+
+
+class ClientProtocol(Protocol):
+ """客户端协议"""
+ async def _fetch_page(self, params: SearchParams, start: int) -> ClientResponse:
+ ...
+
+ async def parse_response(
+ self,
+ response: ClientResponse
+ ) -> tuple[list[Paper], int]:
+ ...
+
+ @property
+ def _config(self) -> "ArxivConfig":
+ ...
+
+class BaseSearchManager(ABC):
+ """搜索管理器基类"""
+ @abstractmethod
+ def __init__(self, client: ClientProtocol):
+ pass
+
+ @abstractmethod
+ def execute_search(self, params: SearchParams) -> AsyncGenerator[Paper, None]:
+ pass
diff --git a/src/client/downloader.py b/src/client/downloader.py
new file mode 100644
index 0000000..cda4469
--- /dev/null
+++ b/src/client/downloader.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+from types import TracebackType
+from typing import Optional
+
+import aiofiles
+
+from ..utils import get_project_root
+from ..utils.log import logger
+from ..utils.session import SessionManager
+
+
+class ArxivDownloader:
+ """arXiv论文下载器"""
+ def __init__(
+ self,
+ session_manager: Optional[SessionManager] = None,
+ download_dir: Optional[str] = None
+ ):
+ """
+ 初始化下载器
+
+ Args:
+ session_manager: 会话管理器,可选
+ download_dir: 下载目录,可选
+ """
+ project_root = get_project_root()
+ self._session_manager = session_manager
+ self._own_session = False
+ self.download_dir = Path(download_dir) if download_dir else project_root / "downloads"
+ self.download_dir.mkdir(parents=True, exist_ok=True)
+
+ @property
+ def session_manager(self) -> SessionManager:
+ """懒加载会话管理器"""
+ if self._session_manager is None:
+ self._session_manager = SessionManager()
+ self._own_session = True
+ return self._session_manager
+
+ async def download_paper(
+ self,
+ pdf_url: str,
+ filename: Optional[str] = None
+ ) -> Path:
+ """下载论文PDF文件"""
+ if not filename:
+ filename = pdf_url.split("/")[-1]
+ if not filename.endswith(".pdf"):
+ filename += ".pdf"
+
+ file_path = self.download_dir / filename
+ temp_path = file_path.with_suffix(file_path.suffix + ".tmp")
+
+ logger.info(f"开始下载论文: {pdf_url}")
+ try:
+ async with self.session_manager.rate_limited_context():
+ # 先获取响应
+ response = await self.session_manager.request("GET", pdf_url)
+ # 然后使用 async with 管理响应生命周期
+ async with response:
+ response.raise_for_status()
+
+ async with aiofiles.open(temp_path, "wb") as f:
+ total_size = int(response.headers.get("content-length", 0))
+ downloaded_size = 0
+
+ async for chunk, _ in response.content.iter_chunks():
+ if chunk:
+ await f.write(chunk)
+ downloaded_size += len(chunk)
+
+ if 0 < total_size != downloaded_size:
+ raise RuntimeError(
+ f"下载不完整: 预期 {total_size} 字节, 实际下载 {downloaded_size} 字节")
+
+ temp_path.rename(file_path)
+
+ logger.info(f"下载完成: {file_path}")
+ return file_path
+
+ except Exception as e:
+ logger.error(f"下载失败: {e!s}")
+ # 清理临时文件和目标文件
+ if temp_path.exists():
+ temp_path.unlink()
+ if file_path.exists():
+ file_path.unlink()
+ raise
+
+ async def __aenter__(self) -> "ArxivDownloader":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]
+ ) -> None:
+ if self._own_session and self._session_manager:
+ await self._session_manager.close()
diff --git a/src/client/processor.py b/src/client/processor.py
new file mode 100644
index 0000000..e13a03e
--- /dev/null
+++ b/src/client/processor.py
@@ -0,0 +1,136 @@
+import sys
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from typing import Optional
+
+from anyio import create_memory_object_stream, create_task_group, sleep
+
+from ..config import ArxivConfig
+from ..exception import SearchCompleteException
+from ..models import Paper, SearchParams
+from ..utils.log import logger
+
+
+@dataclass
+class BatchContext:
+ """批处理上下文"""
+ start: int
+ size: int
+ total_results: float
+ results_count: int
+ max_results: Optional[int]
+
+class ResultProcessor:
+ """处理搜索结果的处理器类"""
+ def __init__(self, concurrent_limit: int, config: ArxivConfig):
+ self.concurrent_limit = concurrent_limit
+ self.config = config
+
+ async def create_streams(self):
+ """创建内存流"""
+ return create_memory_object_stream[tuple[list[Paper], int]](
+ max_buffer_size=self.concurrent_limit
+ )
+
+ def calculate_batch_size(self, ctx: BatchContext) -> int:
+ """计算批次大小"""
+ return min(
+ ctx.size,
+ self.config.page_size,
+ ctx.max_results - ctx.results_count if ctx.max_results else ctx.size
+ )
+
+ def calculate_batch_end(self, ctx: BatchContext, batch_size: int) -> int:
+ """计算批次结束位置"""
+ return min(
+ ctx.start + batch_size,
+ int(ctx.total_results) if ctx.total_results != float("inf") else sys.maxsize
+ )
+
+async def _iter_papers(self, params: SearchParams) -> AsyncGenerator[Paper, None]:
+ """迭代获取论文"""
+ # 初始化上下文
+ concurrent_limit = min(
+ self._config.max_concurrent_requests,
+ self._config.rate_limit_calls
+ )
+ processor = ResultProcessor(concurrent_limit, self._config)
+ ctx = BatchContext(
+ start=0,
+ size=concurrent_limit * self._config.page_size,
+ total_results=float("inf"),
+ results_count=0,
+ max_results=params.max_results
+ )
+
+ logger.info("开始搜索", extra={
+ "query": params.query,
+ "max_results": params.max_results,
+ "concurrent_limit": concurrent_limit
+ })
+
+ send_stream, receive_stream = await processor.create_streams()
+
+ async def fetch_page(page_start: int):
+ """获取单页数据"""
+ try:
+ async with self._session_manager.rate_limited_context():
+ response = await self._fetch_page(params, page_start)
+ result = await self.parse_response(response)
+ await send_stream.send(result)
+ except Exception as e:
+ logger.error(f"获取页面失败: {e!s}", extra={"page": page_start})
+ raise
+
+ async def process_batch(batch_ctx: BatchContext):
+ """处理单个批次"""
+ try:
+ async with create_task_group() as batch_tg:
+ batch_size = processor.calculate_batch_size(batch_ctx)
+ end = processor.calculate_batch_end(batch_ctx, batch_size)
+
+ for offset in range(batch_ctx.start, end):
+ if batch_ctx.max_results and offset >= batch_ctx.max_results:
+ break
+ batch_tg.start_soon(fetch_page, offset)
+ except Exception as e:
+ logger.error("批量获取失败", extra={
+ "start": batch_ctx.start,
+ "size": batch_ctx.size,
+ "error": str(e)
+ })
+ raise
+
+ try:
+ async with create_task_group() as tg:
+ tg.start_soon(process_batch, ctx)
+
+ async for papers, total in receive_stream:
+ # 更新总结果数
+ ctx.total_results = total
+
+ # 检查首次批次是否有结果
+ if ctx.results_count == 0 and total == 0:
+ logger.info("未找到结果", extra={"query": params.query})
+ raise SearchCompleteException(0)
+
+ # 处理论文
+ for paper in papers:
+ yield paper
+ ctx.results_count += 1
+ if ctx.max_results and ctx.results_count >= ctx.max_results:
+ raise SearchCompleteException(ctx.results_count)
+
+ # 准备下一批次
+ ctx.start += len(papers)
+ if ctx.start < min(
+ int(ctx.total_results) if ctx.total_results != float("inf")
+ else sys.maxsize,
+ ctx.max_results or sys.maxsize
+ ):
+ await sleep(self._config.rate_limit_period)
+ tg.start_soon(process_batch, ctx)
+
+ finally:
+ await receive_stream.aclose()
+ logger.info("搜索结束", extra={"total_results": ctx.results_count})
diff --git a/src/client/search.py b/src/client/search.py
new file mode 100644
index 0000000..3e1de5b
--- /dev/null
+++ b/src/client/search.py
@@ -0,0 +1,169 @@
+import sys
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from typing import Any
+
+from anyio import create_memory_object_stream, create_task_group, sleep
+
+from ..exception import SearchCompleteException
+from ..models import Paper, SearchParams
+from ..utils.log import logger
+from .base import BaseSearchManager, ClientProtocol
+
+
+@dataclass
+class SearchContext:
+ """搜索上下文"""
+
+ params: SearchParams
+ start: int = 0
+ total_results: float = float("inf")
+ results_count: int = 0
+ first_batch: bool = True
+
+ def update_with_papers(self, papers: list[Paper]) -> None:
+ """更新处理进度"""
+ self.start += len(papers)
+ self.results_count += len(papers)
+
+ def should_continue(self) -> bool:
+ """检查是否应继续处理"""
+ max_total = (
+ int(self.total_results)
+ if self.total_results != float("inf")
+ else sys.maxsize
+ )
+ max_allowed = self.params.max_results or sys.maxsize
+ return self.start < min(max_total, max_allowed)
+
+ def reached_limit(self) -> bool:
+ """检查是否达到最大结果限制"""
+ return (
+ self.params.max_results is not None
+ and self.results_count >= self.params.max_results
+ )
+
+
+class ArxivSearchManager(BaseSearchManager):
+ def __init__(self, client: ClientProtocol):
+ self.client = client
+ self.config = client._config
+ self.concurrent_limit = min(
+ self.config.max_concurrent_requests, self.config.rate_limit_calls
+ )
+
+ def execute_search(self, params: SearchParams) -> AsyncGenerator[Paper, None]:
+ """执行搜索"""
+
+ async def search_implementation():
+ ctx = SearchContext(params=params)
+ send_stream, receive_stream = create_memory_object_stream[
+ tuple[list[Paper], int]
+ ](max_buffer_size=self.concurrent_limit)
+
+ try:
+ async with create_task_group() as tg:
+ await self._start_batch(tg, ctx, send_stream)
+ try:
+ async for paper in self._process_results(
+ ctx, tg, send_stream, receive_stream
+ ):
+ yield paper
+ except SearchCompleteException:
+ # 正常的搜索完成,记录日志后继续传播
+ logger.info(f"搜索完成,共获取{ctx.results_count}条结果")
+ raise
+ finally:
+ # 确保资源清理
+ await send_stream.aclose()
+ await receive_stream.aclose()
+
+ # 返回异步生成器
+ return search_implementation()
+
+ async def _start_batch(self, tg: Any, ctx: SearchContext, send_stream: Any) -> None:
+ """启动批量获取"""
+ tg.start_soon(
+ self._fetch_batch,
+ ctx,
+ self.concurrent_limit * self.config.page_size,
+ send_stream,
+ )
+
+ async def _fetch_batch(
+ self, ctx: SearchContext, size: int, send_stream: Any
+ ) -> None:
+ """获取一批论文"""
+ try:
+ async with create_task_group() as batch_tg:
+ batch_size = self._calculate_batch_size(ctx, size)
+ end = self._calculate_batch_end(ctx, batch_size)
+
+ for offset in range(ctx.start, end):
+ if ctx.params.max_results and offset >= ctx.params.max_results:
+ break
+ batch_tg.start_soon(self._fetch_page, offset, ctx, send_stream)
+ except Exception as e:
+ logger.error(
+ "批量获取失败",
+ extra={"start": ctx.start, "size": size, "error": str(e)},
+ )
+ raise
+
+ def _calculate_batch_size(self, ctx: SearchContext, size: int) -> int:
+ """计算批次大小"""
+ return min(
+ size,
+ self.config.page_size,
+ ctx.params.max_results - ctx.results_count
+ if ctx.params.max_results
+ else size,
+ )
+
+ def _calculate_batch_end(self, ctx: SearchContext, batch_size: int) -> int:
+ """计算批次结束位置"""
+ return min(
+ ctx.start + batch_size,
+ int(ctx.total_results)
+ if ctx.total_results != float("inf")
+ else sys.maxsize,
+ )
+
+ async def _fetch_page(
+ self, offset: int, ctx: SearchContext, send_stream: Any
+ ) -> None:
+ """获取并发送单页数据"""
+ try:
+ response = await self.client._fetch_page(ctx.params, offset)
+ result = await self.client.parse_response(response)
+ await send_stream.send(result)
+ except Exception as e:
+ logger.error(f"获取页面失败: {e!s}", extra={"page": offset})
+ raise
+
+ async def _process_results(
+ self, ctx: SearchContext, tg: Any, send_stream: Any, receive_stream: Any
+ ) -> AsyncGenerator[Paper, None]:
+ """处理搜索结果"""
+ async for papers, total in receive_stream:
+ # 更新总结果数
+ ctx.total_results = total
+
+ # 检查首次批次是否有结果
+ if ctx.first_batch and total == 0:
+ logger.info("未找到结果", extra={"query": ctx.params.query})
+ raise SearchCompleteException(0)
+ ctx.first_batch = False
+
+ # 处理论文
+ for paper in papers:
+ yield paper
+ ctx.results_count += 1
+ if ctx.reached_limit():
+ raise SearchCompleteException(ctx.results_count)
+
+ # 更新进度并启动下一批次
+ ctx.start += len(papers)
+ if ctx.should_continue():
+ await sleep(self.config.rate_limit_period)
+ await self._start_batch(tg, ctx, send_stream)
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..630047b
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,25 @@
+from typing import Optional
+
+from pydantic import BaseModel, ConfigDict, Field, HttpUrl
+
+
+class ArxivConfig(BaseModel):
+ """arXiv API 配置类"""
+
+ base_url: HttpUrl = Field(
+ default="http://export.arxiv.org/api/query", description="arXiv API 基础URL"
+ )
+ timeout: float = Field(default=30.0, description="请求超时时间(秒)", gt=0)
+ max_retries: int = Field(default=3, description="最大重试次数", ge=0)
+ rate_limit_calls: int = Field(default=5, description="速率限制窗口内的最大请求数")
+ rate_limit_period: float = Field(default=1.0, description="速率限制窗口期(秒)")
+ max_concurrent_requests: int = Field(default=5, description="最大并发请求数")
+ proxy: Optional[str] = Field(default=None, description="HTTP/HTTPS代理URL")
+ log_level: str = Field(default="INFO", description="日志等级")
+ page_size: int = Field(default=10, description="每页结果数")
+ min_wait: float = Field(default=1.0, description="最小重试等待时间(秒)", gt=0)
+
+ model_config = ConfigDict(extra="allow")
+
+
+default_config = ArxivConfig()
diff --git a/src/exception.py b/src/exception.py
new file mode 100644
index 0000000..086604b
--- /dev/null
+++ b/src/exception.py
@@ -0,0 +1,244 @@
+from dataclasses import dataclass
+from http import HTTPStatus
+from typing import Any, Optional
+
+from pydantic import BaseModel, HttpUrl
+
+
+class ArxivException(Exception):
+ """基础异常类"""
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+
+class HTTPException(ArxivException):
+ """HTTP 请求相关异常"""
+
+ def __init__(self, status_code: int, message: Optional[str] = None):
+ self.status_code = status_code
+ self.message = message or HTTPStatus(status_code).description
+ super().__init__(self.message)
+
+
+class RateLimitException(HTTPException):
+ """达到 API 速率限制时的异常"""
+
+ def __init__(self, retry_after: Optional[int] = None):
+ self.retry_after = retry_after
+ super().__init__(429, "Too Many Requests")
+
+
+class ValidationException(ArxivException):
+ """数据验证异常"""
+
+ def __init__(
+ self,
+ message: str,
+ field_name: str,
+ input_value: Any,
+ expected_type: type,
+ model: Optional[type[BaseModel]] = None,
+ validation_errors: Optional[dict] = None,
+ ):
+ self.field_name = field_name
+ self.input_value = input_value
+ self.expected_type = expected_type
+ self.model = model
+ self.validation_errors = validation_errors
+ super().__init__(message)
+
+ def __str__(self) -> str:
+ error_msg = [
+ f"Validation error for field '{self.field_name}':",
+ f"Input value: {self.input_value!r}",
+ f"Expected type: {self.expected_type.__name__}",
+ ]
+
+ if self.model:
+ error_msg.append(f"Model: {self.model.__name__}")
+
+ if self.validation_errors:
+ error_msg.append("Detailed errors:")
+ for key, err in self.validation_errors.items():
+ error_msg.append(f" - {key}: {err}")
+
+ return "\n".join(error_msg)
+
+
+class TimeoutException(ArxivException):
+ """请求超时异常"""
+
+ def __init__(
+ self,
+ timeout: float,
+ message: Optional[str] = None,
+ proxy: Optional[HttpUrl] = None,
+ link: Optional[HttpUrl] = None,
+ ):
+ self.timeout = timeout
+ self.proxy = proxy
+ self.link = link
+ self.message = message or f"Request timed out after {timeout} seconds"
+ super().__init__(self.message)
+
+ def __str__(self) -> str:
+ error_msg = [
+ f"Request timed out after {self.timeout} seconds",
+ self.message,
+ ]
+
+ if self.proxy:
+ error_msg.append(f"Proxy: {self.proxy}")
+
+ if self.link:
+ error_msg.append(f"Link: {self.link}")
+
+ return "\n".join(error_msg)
+
+
+class RetryError(ArxivException):
+ """重试次数用尽异常"""
+
+ def __init__(self, attempts: int, last_error: Exception):
+ self.attempts = attempts
+ self.last_error = last_error
+ super().__init__(f"Failed after {attempts} attempts. Last error: {last_error}")
+
+
+@dataclass
+class ConfigError:
+ """配置错误详情"""
+
+ property_name: str
+ input_value: Any
+ expected_type: type
+ message: str
+
+
+class ConfigurationError(ArxivException):
+ """配置错误异常"""
+
+ def __init__(
+ self,
+ message: str,
+ property_name: str,
+ input_value: Any,
+ expected_type: type,
+ config_class: Optional[type] = None,
+ ):
+ self.property_name = property_name
+ self.input_value = input_value
+ self.expected_type = expected_type
+ self.config_class = config_class
+ self.message = message
+ super().__init__(message)
+
+ def __str__(self) -> str:
+ error_parts = [
+ f"Configuration error for '{self.property_name}':",
+ f"Input value: {self.input_value!r}",
+ f"Expected type: {self.expected_type.__name__}",
+ f"Message: {self.message}",
+ ]
+
+ if self.config_class:
+ error_parts.append(f"Config class: {self.config_class.__name__}")
+
+ return "\n".join(error_parts)
+
+
+@dataclass
+class QueryContext:
+ """查询构建上下文"""
+
+ params: dict[str, Any] # 查询参数
+ field_name: Optional[str] = None # 出错的字段
+ value: Optional[Any] = None # 问题值
+ constraint: Optional[str] = None # 违反的约束
+
+
+class QueryBuildError(ArxivException):
+ """查询构建错误"""
+
+ def __init__(
+ self,
+ message: str,
+ context: Optional[QueryContext] = None,
+ original_error: Optional[Exception] = None,
+ ):
+ self.message = message
+ self.context = context
+ self.original_error = original_error
+ super().__init__(message)
+
+ def __str__(self) -> str:
+ parts = [f"查询构建错误: {self.message}"]
+
+ if self.context:
+ parts.append("查询参数:")
+ for k, v in self.context.params.items():
+ parts.append(f" {k}: {v}")
+
+ if self.context.field_name:
+ parts.append(f"问题字段: {self.context.field_name}")
+ if self.context.value is not None:
+ parts.append(f"问题值: {self.context.value}")
+ if self.context.constraint:
+ parts.append(f"违反约束: {self.context.constraint}")
+
+ if self.original_error:
+ parts.append(f"原始错误: {self.original_error!s}")
+
+ return "\n".join(parts)
+
+
+@dataclass
+class ParseErrorContext:
+ """解析错误上下文"""
+
+ raw_content: Optional[str] = None
+ position: Optional[int] = None
+ element_name: Optional[str] = None
+ namespace: Optional[str] = None
+
+
+class ParserException(Exception):
+ """XML解析异常"""
+
+ def __init__(
+ self,
+ url: str,
+ message: str,
+ context: Optional[ParseErrorContext] = None,
+ original_error: Optional[Exception] = None,
+ ):
+ self.url = url
+ self.message = message
+ self.context = context
+ self.original_error = original_error
+ super().__init__(self.message)
+
+ def __str__(self) -> str:
+ parts = [f"解析错误: {self.message}", f"URL: {self.url}"]
+
+ if self.context:
+ if self.context.element_name:
+ parts.append(f"元素: {self.context.element_name}")
+ if self.context.namespace:
+ parts.append(f"命名空间: {self.context.namespace}")
+ if self.context.position is not None:
+ parts.append(f"位置: {self.context.position}")
+ if self.context.raw_content:
+ parts.append(f"原始内容: \n{self.context.raw_content[:200]}...")
+
+ if self.original_error:
+ parts.append(f"原始错误: {self.original_error!s}")
+
+ return "\n".join(parts)
+
+class SearchCompleteException(ArxivException):
+ """搜索完成异常"""
+ def __init__(self, total_results: int):
+ self.total_results = total_results
+ super().__init__(f"搜索完成,共获取{total_results}条结果")
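A minimal, illustrative sketch of how calling code might surface these exceptions. The `src.exception` module path is inferred from the relative imports later in this diff; the constructed values are invented purely for demonstration. Each class overrides `__str__`, so printing an instance yields the multi-line diagnostic built above.

```python
from src.exception import QueryBuildError, QueryContext, TimeoutException


def report(error: Exception) -> None:
    # __str__ produces the multi-line diagnostic (parameters, context, cause).
    print(f"[{type(error).__name__}]\n{error}\n")


# Hypothetical values, only to show the context object being filled in.
ctx = QueryContext(
    params={"search_query": "cat:cs.CL", "max_results": -1},
    field_name="max_results",
    value=-1,
    constraint="max_results must be positive",
)
report(QueryBuildError("invalid pagination parameters", context=ctx))
report(TimeoutException(timeout=30.0, link="https://export.arxiv.org/api/query"))
```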
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..353bea1
--- /dev/null
+++ b/src/models/__init__.py
@@ -0,0 +1,61 @@
+from datetime import datetime
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, Field, HttpUrl
+
+
+class SortCriterion(str, Enum):
+ """排序标准"""
+
+ RELEVANCE = "relevance"
+ LAST_UPDATED = "lastUpdatedDate"
+ SUBMITTED = "submittedDate"
+
+
+class SortOrder(str, Enum):
+ """排序方向"""
+
+ ASCENDING = "ascending"
+ DESCENDING = "descending"
+
+
+class Author(BaseModel):
+ """作者模型"""
+
+ name: str
+ affiliation: Optional[str] = None
+
+
+class Category(BaseModel):
+ """分类模型"""
+
+ primary: str
+ secondary: list[str] = Field(default_factory=list)
+
+
+class Paper(BaseModel):
+ """论文模型"""
+
+ id: str = Field(description="arXiv ID")
+ title: str
+ summary: str
+ authors: list[Author]
+ categories: Category
+ doi: Optional[str] = None
+ journal_ref: Optional[str] = None
+ pdf_url: Optional[HttpUrl] = None
+ published: datetime
+ updated: datetime
+ comment: Optional[str] = None
+
+
+class SearchParams(BaseModel):
+ """搜索参数"""
+
+ query: str
+ id_list: list[str] = Field(default_factory=list)
+ max_results: Optional[int] = Field(default=None, gt=0)
+ start: int = Field(default=0, ge=0)
+ sort_by: Optional[SortCriterion] = None
+ sort_order: Optional[SortOrder] = None
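For reference, a small sketch of constructing these models directly; the field values (arXiv id, title, authors) are invented, and the `src.models` import path mirrors the test imports elsewhere in this diff. In practice `Paper` instances come from the feed parser further below.

```python
from datetime import datetime, timezone

from src.models import Author, Category, Paper, SearchParams, SortCriterion, SortOrder

# Validated search parameters as they might be built before a request.
params = SearchParams(
    query="cat:cs.CL AND abs:transformer",
    max_results=10,
    sort_by=SortCriterion.SUBMITTED,
    sort_order=SortOrder.DESCENDING,
)

# A hand-built Paper for illustration only.
paper = Paper(
    id="2401.00001v1",
    title="An Example Title",
    summary="An example abstract.",
    authors=[Author(name="Ada Lovelace")],
    categories=Category(primary="cs.CL", secondary=["cs.LG"]),
    pdf_url="https://arxiv.org/pdf/2401.00001v1",
    published=datetime(2024, 1, 1, tzinfo=timezone.utc),
    updated=datetime(2024, 1, 2, tzinfo=timezone.utc),
)

print(params.query, params.sort_by.value)
print(paper.id, [a.name for a in paper.authors], paper.categories.primary)
```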
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..90f2a64
--- /dev/null
+++ b/src/utils/__init__.py
@@ -0,0 +1,44 @@
+from pathlib import Path
+from time import monotonic
+from types import SimpleNamespace
+
+import aiohttp
+
+from .log import logger
+
+
+def create_trace_config() -> aiohttp.TraceConfig:
+ """
+ 创建请求追踪配置。
+
+ Returns:
+ aiohttp.TraceConfig: 请求追踪配置
+ """
+
+ async def _on_request_start(
+ session: aiohttp.ClientSession,
+ trace_config_ctx: SimpleNamespace,
+ params: aiohttp.TraceRequestStartParams,
+ ) -> None:
+ logger.debug(f"Starting request: {params.method} {params.url}")
+ trace_config_ctx.start_time = monotonic()
+
+ async def _on_request_end(
+ session: aiohttp.ClientSession,
+ trace_config_ctx: SimpleNamespace,
+ params: aiohttp.TraceRequestEndParams,
+ ) -> None:
+ elapsed_time = monotonic() - trace_config_ctx.start_time
+ logger.debug(
+ f"Ending request: {params.response.status} {params.url} - Time elapsed: "
+ f"{elapsed_time:.2f} seconds"
+ )
+
+ trace_config = aiohttp.TraceConfig()
+ trace_config.on_request_start.append(_on_request_start)
+ trace_config.on_request_end.append(_on_request_end)
+ return trace_config
+
+def get_project_root() -> Path:
+ """获取项目根目录"""
+ return Path(__file__).parent.parent.parent
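A short usage sketch for the trace-config helper, assuming the `src.utils` import path used by the tests; the request URL is illustrative. Note the trace callbacks log at DEBUG level, so they only appear when the logger is configured to emit DEBUG records.

```python
import asyncio

import aiohttp

from src.utils import create_trace_config, get_project_root


async def main() -> None:
    # Every request on this session triggers the on_request_start/end hooks,
    # which log the method, URL, status and elapsed time at DEBUG level.
    async with aiohttp.ClientSession(
        trace_configs=[create_trace_config()]
    ) as session:
        async with session.get("https://example.org") as resp:
            print(resp.status)

    print("project root:", get_project_root())


if __name__ == "__main__":
    asyncio.run(main())
```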
diff --git a/src/utils/log.py b/src/utils/log.py
new file mode 100644
index 0000000..c983e53
--- /dev/null
+++ b/src/utils/log.py
@@ -0,0 +1,75 @@
+import inspect
+import logging
+import sys
+from typing import TYPE_CHECKING
+
+import loguru
+
+if TYPE_CHECKING:
+ # avoid Sphinx autodoc failing to resolve the annotation,
+ # because the loguru module does not actually expose a `Logger` class at runtime
+ from loguru import Logger, Record
+
+logger: "Logger" = loguru.logger
+"""日志记录器对象。
+
+default:
+
+- 格式: `{time:MM-DD HH:mm:ss} [{level}] {name} | {message}`
+- 等级: `INFO` ,根据 `config.log_level` 配置改变
+- 输出: 输出至 stdout
+
+usage:
+ ```python
+ from src.utils.log import logger
+ ```
+"""
+
+
+# https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
+class LoguruHandler(logging.Handler): # pragma: no cover
+ """logging 与 loguru 之间的桥梁,将 logging 的日志转发到 loguru。"""
+
+ def emit(self, record: logging.LogRecord):
+ try:
+ level = logger.level(record.levelname).name
+ except ValueError:
+ level = record.levelno
+
+ frame, depth = inspect.currentframe(), 0
+ while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
+ frame = frame.f_back
+ depth += 1
+
+ logger.opt(depth=depth, exception=record.exc_info).log(
+ level, record.getMessage()
+ )
+
+
+def default_filter(record: "Record"):
+ """默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。"""
+ log_level = record["extra"].get("log_level", "INFO")
+ levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
+ return record["level"].no >= levelno
+
+
+default_format: str = (
+ "{time:MM-DD HH:mm:ss} "
+ "[{level}] "
+ "{name} | "
+ # "{function}:{line}| "
+ "{message}"
+)
+"""默认日志格式"""
+
+logger.remove()
+logger_id = logger.add(
+ sys.stdout,
+ level=0,
+ diagnose=False,
+ filter=default_filter,
+ format=default_format,
+)
+"""默认日志处理器 id"""
+
+__autodoc__ = {"logger_id": False}
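A hedged sketch of wiring the standard-library `logging` module into this loguru setup; the `src.utils.log` path follows the test imports, and the `aiohttp.client` logger name is just an example of a third-party logger you might forward.

```python
import logging

from src.utils.log import LoguruHandler, logger

# Replace the stdlib root handlers so third-party logging output is routed
# through loguru and shares the format configured above.
logging.basicConfig(handlers=[LoguruHandler()], level=logging.INFO, force=True)
logging.getLogger("aiohttp.client").info("stdlib record forwarded to loguru")

# default_filter reads `log_level` from the record's extra; binding it on a
# logger instance lowers the threshold for that instance only.
debug_logger = logger.bind(log_level="DEBUG")
debug_logger.debug("visible because this bound logger allows DEBUG records")
```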
diff --git a/src/utils/parser.py b/src/utils/parser.py
new file mode 100644
index 0000000..a2acbe2
--- /dev/null
+++ b/src/utils/parser.py
@@ -0,0 +1,339 @@
+import xml.etree.ElementTree as ET
+from datetime import datetime
+from typing import ClassVar, Optional, cast
+
+from pydantic import HttpUrl
+
+from ..exception import ParseErrorContext, ParserException
+from ..models import Author, Category, Paper
+from .log import logger
+
+
+class ArxivParser:
+ """
+ arXiv API响应解析器
+
+ Attributes:
+ NS (ClassVar[dict[str, str]]): XML命名空间
+
+ Args:
+ entry (ET.Element): Atom feed 中的单个 entry 元素
+
+ Raises:
+ ParserException: 如果解析失败
+ """
+ NS: ClassVar[dict[str, str]] = {
+ "atom": "http://www.w3.org/2005/Atom",
+ "opensearch": "http://a9.com/-/spec/opensearch/1.1/",
+ "arxiv": "http://arxiv.org/schemas/atom",
+ }
+ def __init__(self, entry: ET.Element):
+ self.entry = entry
+
+ def _create_parser_exception(
+ self, message: str, url: str = "", error: Optional[Exception] = None
+ ) -> ParserException:
+ """创建解析异常"""
+ return ParserException(
+ url=url,
+ message=message,
+ context=ParseErrorContext(
+ raw_content=ET.tostring(self.entry, encoding="unicode"),
+ element_name=self.entry.tag,
+ ),
+ original_error=error,
+ )
+
+ def parse_authors(self) -> list[Author]:
+ """
+ 解析作者信息
+
+ Returns:
+ list[Author]: 作者列表
+
+ Raises:
+ ParserException: 如果作者信息无效
+ """
+ logger.debug("开始解析作者信息")
+ authors = []
+ for author_elem in self.entry.findall("atom:author", self.NS):
+ name = author_elem.find("atom:name", self.NS)
+ if name is not None and name.text:
+ authors.append(Author(name=name.text))
+ else:
+ logger.warning("发现作者信息不完整")
+ return authors
+
+ def parse_entry_id(self) -> str:
+ """解析论文ID
+
+ Returns:
+ str: 论文ID
+
+ Raises:
+ ParserException: 如果ID元素缺失或无效
+ """
+ id_elem = self.entry.find("atom:id", self.NS)
+ if id_elem is None or id_elem.text is None:
+ raise ParserException(
+ url="",
+ message="缺少论文ID",
+ context=ParseErrorContext(
+ raw_content=ET.tostring(self.entry, encoding="unicode"),
+ element_name="id",
+ namespace=self.NS["atom"],
+ ),
+ )
+
+ return id_elem.text
+
+ def parse_categories(self) -> Category:
+ """
+ 解析分类信息
+
+ Returns:
+ Category: 分类信息
+
+ Raises:
+ ParserException: 如果分类信息无效
+ """
+ logger.debug("开始解析分类信息")
+ primary = self.entry.find("arxiv:primary_category", self.NS)
+ categories = self.entry.findall("atom:category", self.NS)
+
+ if primary is None or "term" not in primary.attrib:
+ logger.warning("未找到主分类信息,使用默认分类")
+ primary_category = "unknown"
+ else:
+ primary_category = primary.attrib["term"]
+
+ return Category(
+ primary=primary_category,
+ secondary=[c.attrib["term"] for c in categories if "term" in c.attrib],
+ )
+
+ def parse_required_fields(self) -> dict:
+ """
+ 解析必要字段
+
+ Returns:
+ dict: 必要字段字典
+
+ Raises:
+ ParserException: 如果字段缺失
+ """
+ fields = {
+ "title": self.entry.find("atom:title", self.NS),
+ "summary": self.entry.find("atom:summary", self.NS),
+ "published": self.entry.find("atom:published", self.NS),
+ "updated": self.entry.find("atom:updated", self.NS),
+ }
+
+ missing = [k for k, v in fields.items() if v is None or v.text is None]
+ if missing:
+ raise self._create_parser_exception(
+ f"缺少必要字段: {', '.join(missing)}"
+ )
+
+ return {
+ k: v.text for k, v in fields.items() if v is not None and v.text is not None
+ }
+
+ def _parse_pdf_url(self) -> Optional[str]:
+ """
+ 解析PDF链接
+
+ Returns:
+ Optional[str]: PDF链接或None
+
+ Raises:
+ ParserException: 如果PDF链接无效
+ """
+ try:
+ links = self.entry.findall("atom:link", self.NS)
+ if not links:
+ logger.warning("未找到任何链接")
+ return None
+
+ pdf_url = next(
+ (
+ link.attrib["href"]
+ for link in links
+ if link.attrib.get("type") == "application/pdf"
+ ),
+ None,
+ )
+
+ if pdf_url is None:
+ logger.warning("未找到PDF链接")
+
+ return pdf_url
+
+ except (KeyError, AttributeError) as e:
+ logger.error("解析PDF链接失败", exc_info=True)
+ raise ParserException(
+ url="",
+ message="解析PDF链接失败",
+ context=ParseErrorContext(
+ raw_content=ET.tostring(self.entry, encoding="unicode"),
+ element_name="link",
+ namespace=self.NS["atom"],
+ ),
+ original_error=e,
+ )
+
+ def parse_optional_fields(self) -> dict:
+ """
+ 解析可选字段
+
+ Returns:
+ dict: 可选字段字典
+
+ Raises:
+ ParserException: 如果字段无效
+ """
+ fields = {
+ "comment": self.entry.find("arxiv:comment", self.NS),
+ "journal_ref": self.entry.find("arxiv:journal_ref", self.NS),
+ "doi": self.entry.find("arxiv:doi", self.NS),
+ }
+
+ return {k: v.text if v is not None else None for k, v in fields.items()}
+
+ @staticmethod
+ def parse_datetime(date_str: str) -> datetime:
+ """
+ 解析ISO格式的日期时间字符串
+
+ Args:
+ date_str: ISO格式的日期时间字符串
+
+ Returns:
+ datetime: 解析后的datetime对象
+
+ Raises:
+ ValueError: 日期格式无效
+ """
+ try:
+ normalized_date = date_str.replace("Z", "+00:00")
+ return datetime.fromisoformat(normalized_date)
+ except ValueError as e:
+ logger.error(f"日期解析失败: {date_str}", exc_info=True)
+ raise ValueError(f"无效的日期格式: {date_str}") from e
+
+ def build_paper(
+ self,
+ index: int,
+ ) -> Paper:
+ """统一处理论文解析"""
+ try:
+ required_fields = self.parse_required_fields()
+ return Paper(
+ id=self.parse_entry_id().split("/")[-1],
+ title=required_fields["title"],
+ summary=required_fields["summary"],
+ authors=self.parse_authors(),
+ categories=self.parse_categories(),
+ pdf_url=cast(HttpUrl, self._parse_pdf_url()),
+ published=self.parse_datetime(
+ required_fields["published"]
+ ),
+ updated=self.parse_datetime(
+ required_fields["updated"]
+ ),
+ **self.parse_optional_fields(),
+ )
+ except ParserException:
+ raise
+ except Exception as e:
+ raise ParserException(
+ url="",
+ message=f"解析第 {index + 1} 篇论文失败",
+ context=ParseErrorContext(
+ raw_content=ET.tostring(self.entry, encoding="unicode"),
+ position=index,
+ element_name=self.entry.tag,
+ ),
+ original_error=e,
+ )
+
+ @classmethod
+ def _parse_root(
+ cls,
+ root: ET.Element,
+ url: str
+ ) -> tuple[list[Paper], int]:
+ """解析根元素"""
+ # 解析总结果数
+ total_element = root.find("opensearch:totalResults", cls.NS)
+ if total_element is None or total_element.text is None:
+ raise ParserException(
+ url=url,
+ message="缺少总结果数元素",
+ context=ParseErrorContext(
+ raw_content=ET.tostring(root, encoding="unicode"),
+ element_name="totalResults",
+ namespace=cls.NS["opensearch"],
+ ),
+ )
+
+ total_results = int(total_element.text)
+
+ # 解析论文列表
+ papers = []
+ for i, entry in enumerate(root.findall("atom:entry", cls.NS)):
+ try:
+ parser = cls(entry)
+ papers.append(parser.build_paper(i))
+ except Exception as e:
+ raise ParserException(
+ url=url,
+ message=f"解析第 {i + 1} 篇论文失败",
+ context=ParseErrorContext(
+ raw_content=ET.tostring(entry, encoding="unicode"),
+ position=i,
+ element_name=entry.tag,
+ namespace=cls.NS["atom"],
+ ),
+ original_error=e,
+ )
+
+ return papers, total_results
+
+ @classmethod
+ async def parse_feed(
+ cls,
+ content: str,
+ url: str = ""
+ ) -> tuple[list[Paper], int]:
+ """
+ 解析arXiv API的Atom feed内容
+
+ Args:
+ content: XML内容
+ url: 请求URL,用于错误上下文
+
+ Returns:
+ tuple[list[Paper], int]: 论文列表和总结果数
+ """
+ logger.debug("开始解析feed内容")
+ try:
+ root = ET.fromstring(content)
+ return cls._parse_root(root, url)
+ except ET.ParseError as e:
+ logger.error("XML格式错误", exc_info=True)
+ raise ParserException(
+ url=url,
+ message="XML格式错误",
+ context=ParseErrorContext(raw_content=content),
+ original_error=e,
+ )
+ except ParserException:
+ raise
+ except Exception as e:
+ raise ParserException(
+ url=url,
+ message="未知解析错误",
+ context=ParseErrorContext(raw_content=content),
+ original_error=e,
+ )
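To make the parsing flow concrete, a self-contained sketch that feeds a hand-written, trimmed Atom payload into `ArxivParser.parse_feed`; the feed content is invented for illustration and only covers the namespaces and fields this parser reads.

```python
import asyncio

from src.utils.parser import ArxivParser

# Minimal Atom payload with the atom/opensearch/arxiv namespaces used above.
FEED = """<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom"
      xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/"
      xmlns:arxiv="http://arxiv.org/schemas/atom">
  <opensearch:totalResults>1</opensearch:totalResults>
  <entry>
    <id>http://arxiv.org/abs/2401.00001v1</id>
    <title>An Example Title</title>
    <summary>An example abstract.</summary>
    <published>2024-01-01T00:00:00Z</published>
    <updated>2024-01-02T00:00:00Z</updated>
    <author><name>Ada Lovelace</name></author>
    <arxiv:primary_category term="cs.CL"/>
    <category term="cs.CL"/>
    <link type="application/pdf" href="http://arxiv.org/pdf/2401.00001v1"/>
  </entry>
</feed>"""


async def main() -> None:
    papers, total = await ArxivParser.parse_feed(FEED)
    paper = papers[0]
    print(total, paper.id, paper.categories.primary, paper.authors[0].name)


if __name__ == "__main__":
    asyncio.run(main())
```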
diff --git a/src/utils/rate_limiter.py b/src/utils/rate_limiter.py
new file mode 100644
index 0000000..4c816ea
--- /dev/null
+++ b/src/utils/rate_limiter.py
@@ -0,0 +1,137 @@
+import asyncio
+from dataclasses import dataclass
+from types import TracebackType
+from typing import ClassVar, Optional
+
+from ..config import default_config
+from .log import logger
+
+
+@dataclass
+class RateLimitState:
+ """速率限制状态"""
+
+ remaining: int
+ reset_at: float
+ window_start: float
+
+
+class RateLimiter:
+ """
+ 速率限制器
+
+ 用于限制请求速率,防止过多请求导致服务器拒绝服务。
+
+ Attributes:
+ DEFAULT_CALLS (ClassVar[int]): 默认窗口期内的最大请求数
+ DEFAULT_PERIOD (ClassVar[float]): 默认窗口期(秒)
+ calls (int): 窗口期内的最大请求数
+ period (float): 窗口期(秒)
+ timestamps (list[float]): 请求时间戳列表
+ _lock (asyncio.Lock): 互斥锁
+ _last_check (Optional[float]): 上次检查时间
+ _logger: loguru 日志记录器
+ """
+
+ # 从配置获取默认值
+ DEFAULT_CALLS: ClassVar[int] = default_config.rate_limit_calls
+ DEFAULT_PERIOD: ClassVar[float] = default_config.rate_limit_period
+
+ def __init__(self, calls: int = DEFAULT_CALLS, period: float = DEFAULT_PERIOD):
+ """
+ 初始化速率限制器
+
+ Args:
+ calls: 窗口期内的最大请求数,默认从配置获取
+ period: 窗口期,默认从配置获取
+ """
+ if calls <= 0:
+ raise ValueError("calls must be positive")
+ if period <= 0:
+ raise ValueError("period must be positive")
+ self.calls = calls
+ self.period = period
+ self.timestamps: list[float] = []
+ self._lock = asyncio.Lock()
+ self._last_check: Optional[float] = None
+ self._logger = logger
+
+ @property
+ def is_limited(self) -> bool:
+ """当前是否处于限制状态"""
+ # 使用当前时间作为参考点
+ loop = asyncio.get_running_loop()
+ now = loop.time()
+
+ # 获取有效时间戳并更新
+ valid_stamps = self._get_valid_timestamps(now)
+ self.timestamps = valid_stamps
+
+ return len(valid_stamps) >= self.calls
+
+ def _get_valid_timestamps(self, now: float) -> list[float]:
+ """获取有效的时间戳列表"""
+ return [t for t in self.timestamps if now - t <= self.period]
+
+ @property
+ async def state(self) -> RateLimitState:
+ """获取当前速率限制状态"""
+ async with self._lock:
+ now = asyncio.get_event_loop().time()
+ valid_timestamps = self._get_valid_timestamps(now)
+
+ return RateLimitState(
+ remaining=max(0, self.calls - len(valid_timestamps)),
+ reset_at=min(self.timestamps, default=now) + self.period
+ if self.timestamps
+ else now,
+ window_start=now,
+ )
+
+ async def acquire(self) -> None:
+ """获取访问许可"""
+ async with self._lock:
+ now = asyncio.get_event_loop().time()
+
+ self.timestamps = self._get_valid_timestamps(now)
+
+ # 检查是否需要等待
+ if len(self.timestamps) >= self.calls:
+ sleep_time = self.timestamps[0] + self.period - now
+ if sleep_time > 0:
+ self._logger.debug(
+ "触发速率限制",
+ extra={
+ "wait_time": f"{sleep_time:.2f}s",
+ "current_calls": len(self.timestamps),
+ "max_calls": self.calls,
+ },
+ )
+ await asyncio.sleep(sleep_time)
+ # 等待结束后刷新当前时间并清理过期时间戳,避免把过期的时间点计入窗口
+ now = asyncio.get_event_loop().time()
+ self.timestamps = self._get_valid_timestamps(now)
+
+ self.timestamps.append(now)
+ self._last_check = now
+
+ self._logger.debug(
+ "获取访问许可",
+ extra={
+ "remaining_calls": self.calls - len(self.timestamps),
+ "window_reset_in": f"{(self.timestamps[0] + self.period - now):.2f}s"
+ if self.timestamps
+ else "0s",
+ },
+ )
+
+ async def __aenter__(self) -> "RateLimiter":
+ """进入速率限制上下文"""
+ await self.acquire()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ """退出速率限制上下文"""
+ pass
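A usage sketch for the limiter as an async context manager, assuming the `src.utils.rate_limiter` import path from the tests; the 3-calls-per-second numbers are arbitrary.

```python
import asyncio

from src.utils.rate_limiter import RateLimiter


async def main() -> None:
    # At most 3 acquisitions per 1-second window; the 4th and 5th workers
    # block inside acquire() until the sliding window has room again.
    limiter = RateLimiter(calls=3, period=1.0)

    async def worker(i: int) -> None:
        async with limiter:
            print(f"request {i} at t={asyncio.get_running_loop().time():.2f}")

    await asyncio.gather(*(worker(i) for i in range(5)))

    state = await limiter.state
    print("remaining in current window:", state.remaining)


if __name__ == "__main__":
    asyncio.run(main())
```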
diff --git a/src/utils/session.py b/src/utils/session.py
new file mode 100644
index 0000000..d5593e7
--- /dev/null
+++ b/src/utils/session.py
@@ -0,0 +1,121 @@
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from types import TracebackType
+from typing import Optional
+
+from aiohttp import (
+ ClientResponse,
+ ClientSession,
+ ClientTimeout,
+ TraceConfig,
+)
+
+from ..config import ArxivConfig, default_config
+from ..utils import create_trace_config
+from .log import logger
+from .rate_limiter import RateLimiter
+
+
+class SessionManager:
+ def __init__(
+ self,
+ config: Optional[ArxivConfig] = None,
+ session: Optional[ClientSession] = None,
+ rate_limiter: Optional[RateLimiter] = None,
+ trace_config: Optional[TraceConfig] = None,
+ ):
+ """
+ 初始化会话管理器
+
+ Args:
+ config: arXiv API配置对象
+ session: aiohttp会话
+ rate_limiter: 速率限制器
+ trace_config: 请求追踪配置
+ """
+ self._config = config or default_config
+ self._timeout = ClientTimeout(
+ total=self._config.timeout,
+ )
+ self._session = session
+ self._rate_limiter = rate_limiter or RateLimiter(
+ calls=self._config.rate_limit_calls,
+ period=self._config.rate_limit_period,
+ )
+ self._trace_config = trace_config or create_trace_config()
+
+ @asynccontextmanager
+ async def rate_limited_context(self) -> AsyncIterator[None]:
+ """获取速率限制上下文"""
+ await self._rate_limiter.acquire()
+ yield
+
+ async def get_session(self) -> ClientSession:
+ """
+ 获取或创建会话
+
+ Returns:
+ ClientSession: aiohttp会话
+ """
+ if self._session is None or self._session.closed:
+ self._session = ClientSession(
+ timeout=self._timeout,
+ trace_configs=[self._trace_config],
+ )
+ return self._session
+
+ async def request(self, method: str, url: str, **kwargs) -> ClientResponse:
+ """
+ 发送受速率限制的请求
+
+ Args:
+ method: HTTP 方法
+ url: 请求URL
+ **kwargs: 传递给 aiohttp.ClientSession.request 的额外参数
+
+ Returns:
+ ClientResponse: aiohttp 响应对象
+ """
+ session = await self.get_session()
+
+ if self._rate_limiter:
+ await self._rate_limiter.acquire()
+
+ if self._config.proxy:
+ logger.debug(f"使用代理: {self._config.proxy}")
+
+ return await session.request(
+ method, url, proxy=self._config.proxy or None, **kwargs
+ )
+
+ async def close(self) -> None:
+ """关闭会话"""
+ if self._session and not self._session.closed:
+ await self._session.close()
+ self._session = None
+
+ async def __aenter__(self) -> "SessionManager":
+ """
+ 进入会话管理器
+
+ Returns:
+ SessionManager: 会话管理器实例
+ """
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ """
+ 退出会话管理器
+
+ Args:
+ exc_type: 异常类型
+ exc_val: 异常值
+ exc_tb: 异常回溯
+ """
+ await self.close()
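Finally, a hedged end-to-end sketch of the session manager: `SessionManager()` falls back to `default_config`, and the arXiv query URL and parameters shown here are illustrative rather than part of this module.

```python
import asyncio

from src.utils.session import SessionManager


async def main() -> None:
    # Requests issued through the manager reuse one ClientSession and pass
    # through RateLimiter.acquire() before being sent.
    async with SessionManager() as manager:
        response = await manager.request(
            "GET",
            "https://export.arxiv.org/api/query",
            params={"search_query": "cat:cs.CL", "max_results": "1"},
        )
        print(response.status)
        print((await response.text())[:120])


if __name__ == "__main__":
    asyncio.run(main())
```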
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_client/__init__.py b/tests/test_client/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/utils/test_rate_limiter.py b/tests/utils/test_rate_limiter.py
new file mode 100644
index 0000000..fae1995
--- /dev/null
+++ b/tests/utils/test_rate_limiter.py
@@ -0,0 +1,92 @@
+import asyncio
+
+import pytest
+
+from src.utils.rate_limiter import RateLimiter
+
+
+@pytest.fixture
+def rate_limiter():
+ return RateLimiter(calls=3, period=1.0)
+
+
+class TestRateLimiter:
+ def test_init_with_invalid_params(self):
+ """测试无效的初始化参数"""
+ with pytest.raises(ValueError, match="calls must be positive"):
+ RateLimiter(calls=0)
+
+ with pytest.raises(ValueError, match="period must be positive"):
+ RateLimiter(period=0)
+
+ @pytest.mark.asyncio
+ async def test_acquire_within_limit(self, rate_limiter):
+ """测试在限制范围内获取许可"""
+ for _ in range(rate_limiter.calls):
+ await rate_limiter.acquire()
+ assert len(rate_limiter.timestamps) == rate_limiter.calls
+
+ @pytest.mark.asyncio
+ async def test_acquire_exceeds_limit(self, rate_limiter):
+ """测试超出限制时的等待"""
+ start_time = asyncio.get_event_loop().time()
+
+ # 先用完配额
+ for _ in range(rate_limiter.calls):
+ await rate_limiter.acquire()
+
+ # 下一次请求应该等待
+ await rate_limiter.acquire()
+
+ elapsed = asyncio.get_event_loop().time() - start_time
+ assert elapsed >= rate_limiter.period
+
+ @pytest.mark.asyncio
+ async def test_state(self, rate_limiter):
+ """测试速率限制状态"""
+ # 初始状态
+ state = await rate_limiter.state
+ assert state.remaining == rate_limiter.calls
+
+ # 使用一个许可后的状态
+ await rate_limiter.acquire()
+ state = await rate_limiter.state
+ assert state.remaining == rate_limiter.calls - 1
+
+ @pytest.mark.asyncio
+ async def test_context_manager(self, rate_limiter):
+ """测试上下文管理器"""
+ async with rate_limiter as limiter:
+ assert isinstance(limiter, RateLimiter)
+ assert len(limiter.timestamps) == 1
+
+ @pytest.mark.asyncio
+ async def test_window_sliding(self, rate_limiter):
+ """测试时间窗口滑动"""
+ # 填满时间窗口
+ for _ in range(rate_limiter.calls):
+ await rate_limiter.acquire()
+
+ # 等待窗口滑动
+ await asyncio.sleep(rate_limiter.period)
+
+ # 应该可以立即获取新的许可
+ start = asyncio.get_event_loop().time()
+ await rate_limiter.acquire()
+ elapsed = asyncio.get_event_loop().time() - start
+ assert elapsed < 0.1 # 不应该有明显延迟
+
+ @pytest.mark.asyncio
+ async def test_is_limited(self, rate_limiter):
+ """测试是否处于限制状态"""
+ # 初始状态未受限
+ assert not rate_limiter.is_limited
+
+ # 填满时间窗口
+ for _ in range(rate_limiter.calls):
+ await rate_limiter.acquire()
+ assert rate_limiter.is_limited
+
+ # 等待窗口滑动
+ await asyncio.sleep(rate_limiter.period)
+ assert not rate_limiter.is_limited
diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py
new file mode 100644
index 0000000..4493491
--- /dev/null
+++ b/tests/utils/test_session.py
@@ -0,0 +1,154 @@
+# tests/utils/test_session.py
+
+import pytest
+from aiohttp import ClientResponse, ClientSession, ClientTimeout
+from pytest_mock import MockerFixture
+
+from src.config import ArxivConfig
+from src.utils.rate_limiter import RateLimiter
+from src.utils.session import SessionManager
+
+
+@pytest.fixture
+def config():
+ return ArxivConfig(
+ timeout=30,
+ rate_limit_calls=3,
+ rate_limit_period=1,
+ )
+
+
+@pytest.fixture
+def rate_limiter():
+ return RateLimiter(calls=3, period=1)
+
+
+@pytest.fixture
+async def session_manager(config, rate_limiter):
+ manager = SessionManager(config=config, rate_limiter=rate_limiter)
+ yield manager
+ await manager.close()
+
+
+@pytest.mark.asyncio
+async def test_init_with_config(config):
+ """测试使用配置初始化"""
+ manager = SessionManager(config=config)
+ assert manager._config == config
+ assert isinstance(manager._timeout, ClientTimeout)
+ assert manager._timeout.total == config.timeout
+
+
+@pytest.mark.asyncio
+async def test_get_session(session_manager):
+ """测试获取会话"""
+ session = await session_manager.get_session()
+ assert isinstance(session, ClientSession)
+ assert not session.closed
+
+
+@pytest.mark.asyncio
+async def test_reuse_existing_session(session_manager):
+ """测试复用现有会话"""
+ session1 = await session_manager.get_session()
+ session2 = await session_manager.get_session()
+ assert session1 is session2
+
+
+@pytest.mark.asyncio
+async def test_rate_limited_context(session_manager, mocker: MockerFixture):
+ """测试速率限制上下文"""
+ mock_limiter = mocker.patch.object(RateLimiter, "acquire")
+
+ async with session_manager.rate_limited_context():
+ pass
+
+ mock_limiter.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_close(session_manager):
+ """测试关闭会话"""
+ session = await session_manager.get_session()
+ assert not session.closed
+
+ await session_manager.close()
+ assert session.closed
+ assert session_manager._session is None
+
+
+@pytest.mark.asyncio
+async def test_context_manager(session_manager):
+ """测试上下文管理器"""
+ async with session_manager as manager:
+ session = await manager.get_session()
+ assert not session.closed
+
+ assert session.closed
+
+
+@pytest.mark.asyncio
+async def test_request(session_manager, mocker: MockerFixture):
+ """测试请求方法(包含速率限制)"""
+ # 1. 模拟响应
+ mock_response = mocker.AsyncMock(spec=ClientResponse)
+
+ # 2. 模拟会话
+ mock_session = mocker.AsyncMock(spec=ClientSession)
+ mock_session.request = mocker.AsyncMock(return_value=mock_response)
+
+ # 3. 模拟 get_session
+ async def mock_get_session():
+ return mock_session
+
+ session_manager.get_session = mock_get_session
+
+ # 4. 模拟速率限制器
+ mock_acquire = mocker.AsyncMock()
+ session_manager._rate_limiter.acquire = mock_acquire
+
+ # 5. 执行请求
+ response = await session_manager.request("GET", "http://example.com")
+
+ # 6. 验证
+ assert response == mock_response
+ mock_acquire.assert_awaited_once()
+ mock_session.request.assert_called_once_with(
+ "GET", "http://example.com", proxy=None
+ )
+
+
+@pytest.mark.asyncio
+async def test_request_creates_new_session_if_closed(session_manager):
+ """测试请求时如果会话已关闭则创建新会话"""
+ session1 = await session_manager.get_session()
+ await session1.close()
+
+ session2 = await session_manager.get_session()
+ assert session2 is not session1
+ assert not session2.closed
+
+
+@pytest.mark.asyncio
+async def test_request_with_proxy(session_manager, mocker: MockerFixture):
+ """测试请求方法(包含代理)"""
+ mock_response = mocker.AsyncMock(spec=ClientResponse)
+ mock_session = mocker.AsyncMock(spec=ClientSession)
+ mock_session.request = mocker.AsyncMock(return_value=mock_response)
+
+ async def mock_get_session():
+ return mock_session
+
+ session_manager.get_session = mock_get_session
+
+ mock_acquire = mocker.AsyncMock()
+ session_manager._rate_limiter.acquire = mock_acquire
+
+ session_manager._config.proxy = "http://proxy.com"
+ response = await session_manager.request("GET", "http://example.com")
+
+ assert response == mock_response
+ mock_acquire.assert_awaited_once()
+ mock_session.request.assert_called_once_with(
+ "GET", "http://example.com", proxy="http://proxy.com"
+ )