diff --git a/.dockerignore b/.dockerignore index b456f25..df5f9db 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,5 @@ # Version control .git -.gitignore # Python __pycache__ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a2cbe1a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,51 @@ +name: CI + +on: + push: + branches: [ "develop", "master" ] + pull_request: + branches: [ "develop", "master" ] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + fail-fast: false + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Set up pip cache + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install PyTorch CPU + run: | + python -m pip install --upgrade pip + pip install torch --index-url https://download.pytorch.org/whl/cpu + + - name: Install dependencies + run: | + pip install ruff pytest-cov + pip install -r requirements.txt + pip install -r requirements-test.txt + + - name: Lint with ruff + run: | + ruff check . + + + - name: Test with pytest + run: | + pytest --asyncio-mode=auto --cov=api --cov-report=term-missing diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index eb97779..c9d860b 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -49,10 +49,10 @@ jobs: flavor: | suffix=-cpu tags: | - type=semver,pattern=v{{version}}-cpu - type=semver,pattern=v{{major}}.{{minor}}-cpu - type=semver,pattern=v{{major}}-cpu - type=raw,value=latest-cpu + type=semver,pattern=v{{version}} + type=semver,pattern=v{{major}}.{{minor}} + type=semver,pattern=v{{major}} + type=raw,value=latest # Build and push GPU version - name: Build and push GPU Docker image @@ -85,10 +85,10 @@ jobs: flavor: | suffix=-ui tags: | - type=semver,pattern=v{{version}}-ui - type=semver,pattern=v{{major}}.{{minor}}-ui - type=semver,pattern=v{{major}}-ui - type=raw,value=latest-ui + type=semver,pattern=v{{version}} + type=semver,pattern=v{{major}}.{{minor}} + type=semver,pattern=v{{major}} + type=raw,value=latest # Build and push UI version - name: Build and push UI Docker image diff --git a/.github/workflows/sync-develop.yml b/.github/workflows/sync-develop.yml new file mode 100644 index 0000000..56b881f --- /dev/null +++ b/.github/workflows/sync-develop.yml @@ -0,0 +1,55 @@ +# name: Sync develop with master + +# on: +# push: +# branches: +# - master + +# jobs: +# sync-develop: +# runs-on: ubuntu-latest +# permissions: +# contents: write +# issues: write +# steps: +# - name: Checkout repository +# uses: actions/checkout@v4 +# with: +# fetch-depth: 0 +# ref: develop + +# - name: Configure Git +# run: | +# git config user.name "GitHub Actions" +# git config user.email "actions@github.com" + +# - name: Merge master into develop +# run: | +# git fetch origin master:master +# git merge --no-ff origin/master -m "chore: Merge master into develop branch" + +# - name: Push changes +# run: | +# if ! git push origin develop; then +# echo "Failed to push to develop branch" +# exit 1 +# fi + +# - name: Handle Failure +# if: failure() +# uses: actions/github-script@v7 +# with: +# script: | +# const issueBody = `Automatic merge from master to develop failed. 
+ +# Please resolve this manually + +# Workflow run: ${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`; + +# await github.rest.issues.create({ +# owner: context.repo.owner, +# repo: context.repo.repo, +# title: '🔄 Automatic master to develop merge failed', +# body: issueBody, +# labels: ['merge-failed', 'automation'] +# }); diff --git a/.gitignore b/.gitignore index cf4c28f..781794b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,29 +1,52 @@ +# Version control +.git -output/* -output_audio/* -ui/data/* - -*.db +# Python +__pycache__ *.pyc -*.pth +*.pyo +*.pyd *.pt - -Kokoro-82M/* -__pycache__/ -.vscode/ -env/ .Python - - +*.py[cod] +*$py.class +.pytest_cache .coverage +.coveragerc -examples/assorted_checks/benchmarks/output_audio/* -examples/assorted_checks/test_combinations/output/* -examples/assorted_checks/test_openai/output/* - -examples/assorted_checks/test_voices/output/* -examples/assorted_checks/test_formats/output/* -examples/assorted_checks/benchmarks/output_audio_stream/* -ui/RepoScreenshot.png -examples/assorted_checks/benchmarks/output_audio_stream_openai/* - +# Environment +# .env +.venv +env/ +venv/ +ENV/ + +# IDE +.idea +.vscode +*.swp +*.swo + +# Project specific +*examples/*.wav +*examples/*.pcm +*examples/*.mp3 +*examples/*.flac +*examples/*.acc +*examples/*.ogg + +Kokoro-82M/ +ui/data +tests/ +*.md +*.txt +requirements.txt + +# Docker +Dockerfile* +docker-compose* + +*.egg-info +*.pt +*.wav +*.tar* diff --git a/CHANGELOG.md b/CHANGELOG.md index 6303af1..25643d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,20 @@ Notable changes to this project will be documented in this file. +## [v0.0.5post1] - 2025-01-11 +### Fixed +- Docker image tagging and versioning improvements (-gpu, -cpu, -ui) +- Minor vram management improvements +- Gradio bugfix causing crashes and errant warnings +- Updated GPU and UI container configurations + +## [v0.0.5] - 2025-01-10 +### Fixed +- Stabilized issues with images tagging and structures from v0.0.4 +- Added automatic master to develop branch synchronization +- Improved release tagging and structures +- Initial CI/CD setup + ## 2025-01-04 ### Added - ONNX Support: diff --git a/Dockerfile b/Dockerfile index 7d70af9..3cc5689 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ espeak-ng \ git \ libsndfile1 \ + curl \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* diff --git a/Dockerfile.cpu b/Dockerfile.cpu index e9f2d3b..ed7b792 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -7,12 +7,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ espeak-ng \ git \ libsndfile1 \ + curl \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* # Install PyTorch CPU version and ONNX runtime -RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu && \ - pip3 install --no-cache-dir onnxruntime==1.20.1 +RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu # Install all other dependencies from requirements.txt COPY requirements.txt . diff --git a/README.md b/README.md index 743136e..fddea39 100644 --- a/README.md +++ b/README.md @@ -3,18 +3,20 @@
# Kokoro TTS API -[![Tests](https://img.shields.io/badge/tests-111%20passed-darkgreen)]() +[![Tests](https://img.shields.io/badge/tests-117%20passed-darkgreen)]() [![Coverage](https://img.shields.io/badge/coverage-75%25-darkgreen)]() -[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero) +[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero) [![Buy Me A Coffee](https://img.shields.io/badge/BMC-✨☕-gray?style=flat-square)](https://www.buymeacoffee.com/remsky) Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model - OpenAI-compatible Speech endpoint, with inline voice combination functionality -- NVIDIA GPU accelerated inference (or CPU Onnx) option +- NVIDIA GPU accelerated or CPU Onnx inference - very fast generation time - - 35x+ real time speed via 4060Ti, ~300ms latency - - 5x+ real time spead via M3 Pro CPU, ~1000ms latency + - 100x+ real time speed via HF A100 + - 35-50x+ real time speed via 4060Ti + - 5x+ real time speed via M3 Pro CPU - streaming support w/ variable chunking to control latency & artifacts - simple audio generation web ui utility +- (new) phoneme endpoints for conversion and generation ## Quick Start @@ -27,13 +29,14 @@ The service can be accessed through either the API endpoints or the Gradio web i ```bash git clone https://github.com/remsky/Kokoro-FastAPI.git cd Kokoro-FastAPI - docker compose up --build + docker compose up --build # for GPU + #docker compose -f docker-compose.cpu.yml up --build # for CPU ``` 2. 
Run locally as an OpenAI-Compatible Speech Endpoint ```python from openai import OpenAI client = OpenAI( - base_url="http://localhost:8880", + base_url="http://localhost:8880/v1", api_key="not-needed" ) @@ -58,7 +61,7 @@ The service can be accessed through either the API endpoints or the Gradio web i ```python # Using OpenAI's Python library from openai import OpenAI -client = OpenAI(base_url="http://localhost:8880", api_key="not-needed") +client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed") response = client.audio.speech.create( model="kokoro", # Not used but required for compatibility, also accepts library defaults voice="af_bella+af_sky", @@ -95,8 +98,8 @@ with open("output.mp3", "wb") as f: Quick tests (run from another terminal): ```bash -python examples/test_openai_tts.py # Test OpenAI Compatibility -python examples/test_all_voices.py # Test all available voices +python examples/assorted_checks/test_openai/test_openai_tts.py # Test OpenAI Compatibility +python examples/assorted_checks/test_voices/test_all_voices.py # Test all available voices ``` @@ -229,8 +232,9 @@ for chunk in response.iter_content(chunk_size=1024): Key Streaming Metrics: - First token latency @ chunksize - - ~300ms (GPU) @ 400 - - ~3500ms (CPU) @ 200 + - ~300ms (GPU) @ 400 + - ~3500ms (CPU) @ 200 (older i7) + - ~<1s (CPU) @ 200 (M3 Pro) - Adjustable chunking settings for real-time playback *Note: Artifacts in intonation can increase with smaller chunks* @@ -277,6 +281,90 @@ docker compose -f docker-compose.cpu.yml up --build - Helps to reduce artifacts and allow long form processing as the base model is only currently configured for approximately 30s output +
+Phoneme & Token Routes + +Convert text to phonemes and/or generate audio directly from phonemes: +```python +import requests + +# Convert text to phonemes +response = requests.post( + "http://localhost:8880/dev/phonemize", + json={ + "text": "Hello world!", + "language": "a" # "a" for American English + } +) +result = response.json() +phonemes = result["phonemes"] # Phoneme string e.g ðɪs ɪz ˈoʊnli ɐ tˈɛst +tokens = result["tokens"] # Token IDs including start/end tokens + +# Generate audio from phonemes +response = requests.post( + "http://localhost:8880/dev/generate_from_phonemes", + json={ + "phonemes": phonemes, + "voice": "af_bella", + "speed": 1.0 + } +) + +# Save WAV audio +with open("speech.wav", "wb") as f: + f.write(response.content) +``` + +See `examples/phoneme_examples/generate_phonemes.py` for a sample script. +
+
+## Known Issues
+
+Linux GPU Permissions + +Some Linux users may encounter GPU permission issues when running as non-root. +Can't guarantee anything, but here are some common solutions, consider your security requirements carefully + +### Option 1: Container Groups (Likely the best option) +```yaml +services: + kokoro-tts: + # ... existing config ... + group_add: + - "video" + - "render" +``` + +### Option 2: Host System Groups +```yaml +services: + kokoro-tts: + # ... existing config ... + user: "${UID}:${GID}" + group_add: + - "video" +``` +Note: May require adding host user to groups: `sudo usermod -aG docker,video $USER` and system restart. + +### Option 3: Device Permissions (Use with caution) +```yaml +services: + kokoro-tts: + # ... existing config ... + devices: + - /dev/nvidia0:/dev/nvidia0 + - /dev/nvidiactl:/dev/nvidiactl + - /dev/nvidia-uvm:/dev/nvidia-uvm +``` +⚠️ Warning: Reduces system security. Use only in development environments. + +Prerequisites: NVIDIA GPU, drivers, and container toolkit must be properly configured. + +Visit [NVIDIA Container Toolkit installation](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) for more detailed information + +
+
 ## Model and License
diff --git a/api/src/core/config.py b/api/src/core/config.py index ad0ef1c..2174bce 100644 --- a/api/src/core/config.py +++ b/api/src/core/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): sample_rate: int = 24000 max_chunk_size: int = 300 # Maximum size of text chunks for processing gap_trim_ms: int = 250 # Amount to trim from streaming chunk ends in milliseconds - + # ONNX Optimization Settings onnx_num_threads: int = 4 # Number of threads for intra-op parallelism onnx_inter_op_threads: int = 4 # Number of threads for inter-op parallelism diff --git a/api/src/main.py b/api/src/main.py index fc51043..93f5f34 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -2,6 +2,7 @@ FastAPI OpenAI Compatible API """ +import sys from contextlib import asynccontextmanager import uvicorn @@ -11,11 +12,33 @@ from .core.config import settings from .services.tts_model import TTSModel +from .routers.development import router as dev_router from .services.tts_service import TTSService from .routers.openai_compatible import router as openai_router -from .routers.text_processing import router as text_router +def setup_logger(): + """Configure loguru logger with custom formatting""" + config = { + "handlers": [ + { + "sink": sys.stdout, + "format": "{time:hh:mm:ss A} | " + "{level: <8} | " + "{message}", + "colorize": True, + "level": "INFO", + }, + ], + } + logger.remove() + logger.configure(**config) + logger.level("ERROR", color="") + + +# Configure logger +setup_logger() + @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for model initialization""" @@ -25,7 +48,7 @@ async def lifespan(app: FastAPI): voicepack_count = await TTSModel.setup() # boundary = "█████╗"*9 boundary = "░" * 24 - startup_msg =f""" + startup_msg = f""" {boundary} @@ -67,7 +90,8 @@ async def lifespan(app: FastAPI): # Include routers app.include_router(openai_router, prefix="/v1") -app.include_router(text_router) +app.include_router(dev_router) # New development endpoints +# app.include_router(text_router) # Deprecated but still live for backwards compatibility # Health check endpoint diff --git a/api/src/routers/development.py b/api/src/routers/development.py new file mode 100644 index 0000000..c7c938b --- /dev/null +++ b/api/src/routers/development.py @@ -0,0 +1,130 @@ +from typing import List + +import numpy as np +from loguru import logger +from fastapi import Depends, Response, APIRouter, HTTPException + +from ..services.audio import AudioService +from ..services.tts_model import TTSModel +from ..services.tts_service import TTSService +from ..structures.text_schemas import ( + PhonemeRequest, + PhonemeResponse, + GenerateFromPhonemesRequest, +) +from ..services.text_processing import tokenize, phonemize + +router = APIRouter(tags=["text processing"]) + + +def get_tts_service() -> TTSService: + """Dependency to get TTSService instance""" + return TTSService() + + +@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"]) +@router.post("/dev/phonemize", response_model=PhonemeResponse) +async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse: + """Convert text to phonemes and tokens + + Args: + request: Request containing text and language + tts_service: Injected TTSService instance + + Returns: + Phonemes and token IDs + """ + try: + if not request.text: + raise ValueError("Text cannot be empty") + + # Get phonemes + phonemes = phonemize(request.text, request.language) + if not phonemes: + raise ValueError("Failed to generate phonemes") + + # Get tokens 
+ tokens = tokenize(phonemes) + tokens = [0] + tokens + [0] # Add start/end tokens + + return PhonemeResponse(phonemes=phonemes, tokens=tokens) + except ValueError as e: + logger.error(f"Error in phoneme generation: {str(e)}") + raise HTTPException( + status_code=500, detail={"error": "Server error", "message": str(e)} + ) + except Exception as e: + logger.error(f"Error in phoneme generation: {str(e)}") + raise HTTPException( + status_code=500, detail={"error": "Server error", "message": str(e)} + ) + + +@router.post("/text/generate_from_phonemes", tags=["deprecated"]) +@router.post("/dev/generate_from_phonemes") +async def generate_from_phonemes( + request: GenerateFromPhonemesRequest, + tts_service: TTSService = Depends(get_tts_service), +) -> Response: + """Generate audio directly from phonemes + + Args: + request: Request containing phonemes and generation parameters + tts_service: Injected TTSService instance + + Returns: + WAV audio bytes + """ + # Validate phonemes first + if not request.phonemes: + raise HTTPException( + status_code=400, + detail={"error": "Invalid request", "message": "Phonemes cannot be empty"}, + ) + + # Validate voice exists + voice_path = tts_service._get_voice_path(request.voice) + if not voice_path: + raise HTTPException( + status_code=400, + detail={ + "error": "Invalid request", + "message": f"Voice not found: {request.voice}", + }, + ) + + try: + # Load voice + voicepack = tts_service._load_voice(voice_path) + + # Convert phonemes to tokens + tokens = tokenize(request.phonemes) + tokens = [0] + tokens + [0] # Add start/end tokens + + # Generate audio directly from tokens + audio = TTSModel.generate_from_tokens(tokens, voicepack, request.speed) + + # Convert to WAV bytes + wav_bytes = AudioService.convert_audio( + audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False + ) + + return Response( + content=wav_bytes, + media_type="audio/wav", + headers={ + "Content-Disposition": "attachment; filename=speech.wav", + "Cache-Control": "no-cache", + }, + ) + + except ValueError as e: + logger.error(f"Invalid request: {str(e)}") + raise HTTPException( + status_code=400, detail={"error": "Invalid request", "message": str(e)} + ) + except Exception as e: + logger.error(f"Error generating audio: {str(e)}") + raise HTTPException( + status_code=500, detail={"error": "Server error", "message": str(e)} + ) diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index b790b4b..d86e00a 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -1,13 +1,12 @@ -from typing import List, Union +from typing import List, Union, AsyncGenerator from loguru import logger -from fastapi import Depends, Response, APIRouter, HTTPException -from fastapi import Header +from fastapi import Header, Depends, Response, APIRouter, HTTPException from fastapi.responses import StreamingResponse -from ..services.tts_service import TTSService + from ..services.audio import AudioService from ..structures.schemas import OpenAISpeechRequest -from typing import AsyncGenerator +from ..services.tts_service import TTSService router = APIRouter( tags=["OpenAI Compatible TTS"], @@ -20,7 +19,9 @@ def get_tts_service() -> TTSService: return TTSService() # Initialize TTSService with default settings -async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str: +async def process_voices( + voice_input: Union[str, List[str]], tts_service: TTSService +) -> str: """Process voice input into a 
combined voice, handling both string and list formats""" # Convert input to list of voices if isinstance(voice_input, str): @@ -35,7 +36,9 @@ async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSSer available_voices = await tts_service.list_voices() for voice in voices: if voice not in available_voices: - raise ValueError(f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}") + raise ValueError( + f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}" + ) # If single voice, return it directly if len(voices) == 1: @@ -45,21 +48,23 @@ async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSSer return await tts_service.combine_voices(voices=voices) -async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]: +async def stream_audio_chunks( + tts_service: TTSService, request: OpenAISpeechRequest +) -> AsyncGenerator[bytes, None]: """Stream audio chunks as they're generated""" voice_to_use = await process_voices(request.voice, tts_service) async for chunk in tts_service.generate_audio_stream( text=request.input, voice=voice_to_use, speed=request.speed, - output_format=request.response_format + output_format=request.response_format, ): yield chunk @router.post("/audio/speech") async def create_speech( - request: OpenAISpeechRequest, + request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service), x_raw_response: str = Header(None, alias="x-raw-response"), ): @@ -88,6 +93,7 @@ async def create_speech( "Content-Disposition": f"attachment; filename=speech.{request.response_format}", "X-Accel-Buffering": "no", # Disable proxy buffering "Cache-Control": "no-cache", # Prevent caching + "Transfer-Encoding": "chunked", # Enable chunked transfer encoding }, ) else: @@ -101,11 +107,8 @@ async def create_speech( # Convert to requested format content = AudioService.convert_audio( - audio, - 24000, - request.response_format, - is_first_chunk=True, - stream=False) + audio, 24000, request.response_format, is_first_chunk=True, stream=False + ) return Response( content=content, diff --git a/api/src/routers/text_processing.py b/api/src/routers/text_processing.py deleted file mode 100644 index 9e1ce3a..0000000 --- a/api/src/routers/text_processing.py +++ /dev/null @@ -1,30 +0,0 @@ -from fastapi import APIRouter -from ..structures.text_schemas import PhonemeRequest, PhonemeResponse -from ..services.text_processing import phonemize, tokenize - -router = APIRouter( - prefix="/text", - tags=["text processing"] -) - -@router.post("/phonemize", response_model=PhonemeResponse) -async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse: - """Convert text to phonemes and tokens: Rough attempt - - Args: - request: Request containing text and language - - Returns: - Phonemes and token IDs - """ - # Get phonemes - phonemes = phonemize(request.text, request.language) - - # Get tokens - tokens = tokenize(phonemes) - tokens = [0] + tokens + [0] # Add start/end tokens - - return PhonemeResponse( - phonemes=phonemes, - tokens=tokens - ) diff --git a/api/src/services/audio.py b/api/src/services/audio.py index dcb2a72..4c5a415 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -6,35 +6,41 @@ import soundfile as sf import scipy.io.wavfile as wavfile from loguru import logger + from ..core.config import settings + class AudioNormalizer: """Handles audio normalization state for a single stream""" + def __init__(self): 
self.int16_max = np.iinfo(np.int16).max self.chunk_trim_ms = settings.gap_trim_ms self.sample_rate = 24000 # Sample rate of the audio self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000) - - def normalize(self, audio_data: np.ndarray, is_last_chunk: bool = False) -> np.ndarray: + + def normalize( + self, audio_data: np.ndarray, is_last_chunk: bool = False + ) -> np.ndarray: """Normalize audio data to int16 range and trim chunk boundaries""" # Convert to float32 if not already audio_float = audio_data.astype(np.float32) - + # Normalize to [-1, 1] range first if np.max(np.abs(audio_float)) > 0: audio_float = audio_float / np.max(np.abs(audio_float)) - + # Trim end of non-final chunks to reduce gaps if not is_last_chunk and len(audio_float) > self.samples_to_trim: - audio_float = audio_float[:-self.samples_to_trim] - + audio_float = audio_float[: -self.samples_to_trim] + # Scale to int16 range return (audio_float * self.int16_max).astype(np.int16) + class AudioService: """Service for audio format conversions""" - + # Default audio format settings balanced for speed and compression DEFAULT_SETTINGS = { "mp3": { @@ -46,19 +52,19 @@ class AudioService: }, "flac": { "compression_level": 0.0, # Light compression, still fast - } + }, } - + @staticmethod def convert_audio( - audio_data: np.ndarray, - sample_rate: int, - output_format: str, + audio_data: np.ndarray, + sample_rate: int, + output_format: str, is_first_chunk: bool = True, is_last_chunk: bool = False, normalizer: AudioNormalizer = None, format_settings: dict = None, - stream: bool = True + stream: bool = True, ) -> bytes: """Convert audio data to specified format @@ -88,57 +94,65 @@ def convert_audio( try: # Always normalize audio to ensure proper amplitude scaling - if stream: - if normalizer is None: - normalizer = AudioNormalizer() - normalized_audio = normalizer.normalize(audio_data, is_last_chunk=is_last_chunk) - else: - normalized_audio = audio_data + if normalizer is None: + normalizer = AudioNormalizer() + normalized_audio = normalizer.normalize( + audio_data, is_last_chunk=is_last_chunk + ) if output_format == "pcm": # Raw 16-bit PCM samples, no header buffer.write(normalized_audio.tobytes()) elif output_format == "wav": - if stream: - # Use soundfile for streaming to ensure proper headers - sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16') - else: - # Trying scipy.io.wavfile for non-streaming WAV generation - # seems faster than soundfile - # avoids overhead from header generation and PCM encoding - wavfile.write(buffer, sample_rate, normalized_audio) + # WAV format with headers + sf.write( + buffer, + normalized_audio, + sample_rate, + format="WAV", + subtype="PCM_16", + ) elif output_format == "mp3": - # Use format settings or defaults + # MP3 format with proper framing settings = format_settings.get("mp3", {}) if format_settings else {} settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings} sf.write( - buffer, normalized_audio, - sample_rate, format="MP3", - **settings - ) - + buffer, normalized_audio, sample_rate, format="MP3", **settings + ) elif output_format == "opus": + # Opus format in OGG container settings = format_settings.get("opus", {}) if format_settings else {} settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings} - sf.write(buffer, normalized_audio, sample_rate, format="OGG", - subtype="OPUS", **settings) - + sf.write( + buffer, + normalized_audio, + sample_rate, + format="OGG", + subtype="OPUS", + **settings, + ) elif output_format == 
"flac": + # FLAC format with proper framing if is_first_chunk: logger.info("Starting FLAC stream...") settings = format_settings.get("flac", {}) if format_settings else {} settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings} - sf.write(buffer, normalized_audio, sample_rate, format="FLAC", - subtype='PCM_16', **settings) + sf.write( + buffer, + normalized_audio, + sample_rate, + format="FLAC", + subtype="PCM_16", + **settings, + ) + elif output_format == "aac": + raise ValueError( + "Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm." + ) else: - if output_format == "aac": - raise ValueError( - "Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm." - ) - else: - raise ValueError( - f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm." - ) + raise ValueError( + f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac." + ) buffer.seek(0) return buffer.getvalue() diff --git a/api/src/services/text_processing/__init__.py b/api/src/services/text_processing/__init__.py index f945e18..624ce7c 100644 --- a/api/src/services/text_processing/__init__.py +++ b/api/src/services/text_processing/__init__.py @@ -1,13 +1,13 @@ from .normalizer import normalize_text -from .phonemizer import phonemize, PhonemizerBackend, EspeakBackend -from .vocabulary import tokenize, decode_tokens, VOCAB +from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize +from .vocabulary import VOCAB, tokenize, decode_tokens __all__ = [ - 'normalize_text', - 'phonemize', - 'tokenize', - 'decode_tokens', - 'VOCAB', - 'PhonemizerBackend', - 'EspeakBackend' + "normalize_text", + "phonemize", + "tokenize", + "decode_tokens", + "VOCAB", + "PhonemizerBackend", + "EspeakBackend", ] diff --git a/api/src/services/text_processing/chunker.py b/api/src/services/text_processing/chunker.py index c0c59eb..2bbda79 100644 --- a/api/src/services/text_processing/chunker.py +++ b/api/src/services/text_processing/chunker.py @@ -1,44 +1,45 @@ """Text chunking service""" import re + from ...core.config import settings def split_text(text: str, max_chunk=None): """Split text into chunks on natural pause points - + Args: text: Text to split into chunks max_chunk: Maximum chunk size (defaults to settings.max_chunk_size) """ if max_chunk is None: max_chunk = settings.max_chunk_size - + if not isinstance(text, str): text = str(text) if text is not None else "" - + text = text.strip() if not text: return - + # First split into sentences sentences = re.split(r"(?<=[.!?])\s+", text) - + for sentence in sentences: sentence = sentence.strip() if not sentence: continue - + # For medium-length sentences, split on punctuation if len(sentence) > max_chunk: # Lower threshold for more consistent sizes # First try splitting on semicolons and colons parts = re.split(r"(?<=[;:])\s+", sentence) - + for part in parts: part = part.strip() if not part: continue - + # If part is still long, split on commas if len(part) > max_chunk: subparts = re.split(r"(?<=,)\s+", part) diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index e213e55..d24ce6b 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -10,10 +10,45 @@ # Constants VALID_TLDS = [ - "com", "org", "net", "edu", "gov", "mil", "int", "biz", "info", "name", - "pro", "coop", "museum", "travel", "jobs", "mobi", "tel", "asia", "cat", - "xxx", "aero", 
"arpa", "bg", "br", "ca", "cn", "de", "es", "eu", "fr", - "in", "it", "jp", "mx", "nl", "ru", "uk", "us", "io" + "com", + "org", + "net", + "edu", + "gov", + "mil", + "int", + "biz", + "info", + "name", + "pro", + "coop", + "museum", + "travel", + "jobs", + "mobi", + "tel", + "asia", + "cat", + "xxx", + "aero", + "arpa", + "bg", + "br", + "ca", + "cn", + "de", + "es", + "eu", + "fr", + "in", + "it", + "jp", + "mx", + "nl", + "ru", + "uk", + "us", + "io", ] VALID_UNITS = { @@ -37,16 +72,20 @@ } # Pre-compiled regex patterns for performance -EMAIL_PATTERN = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE) +EMAIL_PATTERN = re.compile( + r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE +) URL_PATTERN = re.compile( - r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:" + - "|".join(VALID_TLDS) + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?", - re.IGNORECASE + r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:" + + "|".join(VALID_TLDS) + + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?", + re.IGNORECASE, ) UNIT_PATTERN = re.compile(r"((??@\[\\\]^_`{\|}~ \n]{1})""",re.IGNORECASE) INFLECT_ENGINE=inflect.engine() + def split_num(num: re.Match[str]) -> str: """Handle number splitting for various formats""" num = num.group() @@ -71,6 +110,7 @@ def split_num(num: re.Match[str]) -> str: return f"{left} oh {right}{s}" return f"{left} {right}{s}" + def handle_money(m: re.Match[str]) -> str: """Convert money expressions to spoken form""" m = m.group() @@ -90,49 +130,57 @@ def handle_money(m: re.Match[str]) -> str: ) return f"{b} {bill}{s} and {c} {coins}" + def handle_decimal(num: re.Match[str]) -> str: """Convert decimal numbers to spoken form""" a, b = num.group().split(".") return " point ".join([a, " ".join(b)]) + def handle_email(m: re.Match[str]) -> str: """Convert email addresses into speakable format""" email = m.group(0) - parts = email.split('@') + parts = email.split("@") if len(parts) == 2: user, domain = parts - domain = domain.replace('.', ' dot ') + domain = domain.replace(".", " dot ") return f"{user} at {domain}" return email + def handle_url(u: re.Match[str]) -> str: """Make URLs speakable by converting special characters to spoken words""" if not u: return "" - + url = u.group(0).strip() - + # Handle protocol first - url = re.sub(r'^https?://', lambda a: 'https ' if 'https' in a.group() else 'http ', url, flags=re.IGNORECASE) - url = re.sub(r'^www\.', 'www ', url, flags=re.IGNORECASE) - + url = re.sub( + r"^https?://", + lambda a: "https " if "https" in a.group() else "http ", + url, + flags=re.IGNORECASE, + ) + url = re.sub(r"^www\.", "www ", url, flags=re.IGNORECASE) + # Handle port numbers before other replacements - url = re.sub(r':(\d+)(?=/|$)', lambda m: f" colon {m.group(1)}", url) - + url = re.sub(r":(\d+)(?=/|$)", lambda m: f" colon {m.group(1)}", url) + # Split into domain and path - parts = url.split('/', 1) + parts = url.split("/", 1) domain = parts[0] - path = parts[1] if len(parts) > 1 else '' - + path = parts[1] if len(parts) > 1 else "" + # Handle dots in domain - domain = domain.replace('.', ' dot ') - + domain = domain.replace(".", " dot ") + # Reconstruct URL if path: url = f"{domain} slash {path}" else: url = domain - + # Replace remaining symbols with words url = url.replace("-", " dash ") url = url.replace("_", " underscore ") @@ -142,9 +190,9 @@ def handle_url(u: re.Match[str]) -> str: url = url.replace("%", " percent ") url = url.replace(":", " colon ") # 
Handle any remaining colons url = url.replace("/", " slash ") # Handle any remaining slashes - + # Clean up extra spaces - return re.sub(r'\s+', ' ', url).strip() + return re.sub(r"\s+", " ", url).strip() def handle_units(u: re.Match[str]) -> str: unit=u.group(6).strip() @@ -158,12 +206,13 @@ def normalize_urls(text: str) -> str: """Pre-process URLs before other text normalization""" # Handle email addresses first text = EMAIL_PATTERN.sub(handle_email, text) - + # Handle URLs text = URL_PATTERN.sub(handle_url, text) - + return text - + + def normalize_text(text: str) -> str: """Normalize text for TTS processing""" # Pre-process numbers with units @@ -171,37 +220,36 @@ def normalize_text(text: str) -> str: # Pre-process URLs first text = normalize_urls(text) + # Replace quotes and brackets text = text.replace(chr(8216), "'").replace(chr(8217), "'") text = text.replace("«", chr(8220)).replace("»", chr(8221)) text = text.replace(chr(8220), '"').replace(chr(8221), '"') text = text.replace("(", "«").replace(")", "»") - + # Handle CJK punctuation and some non standard chars for a, b in zip("、。!,:;?–", ",.!,:;?-"): text = text.replace(a, b + " ") - + # Clean up whitespace text = re.sub(r"[^\S \n]", " ", text) text = re.sub(r" +", " ", text) text = re.sub(r"(?<=\n) +(?=\n)", "", text) - + # Handle titles and abbreviations text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text) text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text) text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text) text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text) text = re.sub(r"\betc\.(?! [A-Z])", "etc", text) - + # Handle common words text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text) - + # Handle numbers and money text = re.sub( - r"\d*\.\d+|\b\d{4}s?\b|(? str: text, ) text = re.sub(r"\d*\.\d+", handle_decimal, text) - + # Handle various formatting text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) text = re.sub(r"(?<=\d)S", " S", text) text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text) text = re.sub(r"(?<=X')S\b", "s", text) text = re.sub( - r"(?:[A-Za-z]\.){2,} [a-z]", - lambda m: m.group().replace(".", "-"), - text + r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text ) text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text) + return text.strip() diff --git a/api/src/services/text_processing/phonemizer.py b/api/src/services/text_processing/phonemizer.py index 0d04d86..a328bb5 100644 --- a/api/src/services/text_processing/phonemizer.py +++ b/api/src/services/text_processing/phonemizer.py @@ -1,97 +1,98 @@ import re from abc import ABC, abstractmethod + import phonemizer + from .normalizer import normalize_text + class PhonemizerBackend(ABC): """Abstract base class for phonemization backends""" - + @abstractmethod def phonemize(self, text: str) -> str: """Convert text to phonemes - + Args: text: Text to convert to phonemes - + Returns: Phonemized text """ pass + class EspeakBackend(PhonemizerBackend): """Espeak-based phonemizer implementation""" - + def __init__(self, language: str): """Initialize espeak backend - + Args: language: Language code ('en-us' or 'en-gb') """ self.backend = phonemizer.backend.EspeakBackend( - language=language, - preserve_punctuation=True, - with_stress=True + language=language, preserve_punctuation=True, with_stress=True ) self.language = language - + def phonemize(self, text: str) -> str: """Convert text to phonemes using espeak - + Args: text: Text to convert to phonemes - + Returns: Phonemized text """ # Phonemize text ps = self.backend.phonemize([text]) 
ps = ps[0] if ps else "" - + # Handle special cases ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ") ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l") ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps) ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»"" ]|$)', "z", ps) - + # Language-specific rules if self.language == "en-us": ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps) - + return ps.strip() + def create_phonemizer(language: str = "a") -> PhonemizerBackend: """Factory function to create phonemizer backend - + Args: language: Language code ('a' for US English, 'b' for British English) - + Returns: Phonemizer backend instance """ # Map language codes to espeak language codes - lang_map = { - "a": "en-us", - "b": "en-gb" - } - + lang_map = {"a": "en-us", "b": "en-gb"} + if language not in lang_map: raise ValueError(f"Unsupported language code: {language}") - + return EspeakBackend(lang_map[language]) + def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: """Convert text to phonemes - + Args: text: Text to convert to phonemes language: Language code ('a' for US English, 'b' for British English) normalize: Whether to normalize text before phonemization - + Returns: Phonemized text """ if normalize: text = normalize_text(text) - + phonemizer = create_phonemizer(language) return phonemizer.phonemize(text) diff --git a/api/src/services/text_processing/vocabulary.py b/api/src/services/text_processing/vocabulary.py index 66af961..7a12892 100644 --- a/api/src/services/text_processing/vocabulary.py +++ b/api/src/services/text_processing/vocabulary.py @@ -4,31 +4,34 @@ def get_vocab(): _punctuation = ';:,.!?¡¿—…"«»"" ' _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" - + # Create vocabulary dictionary symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) return {symbol: i for i, symbol in enumerate(symbols)} + # Initialize vocabulary VOCAB = get_vocab() + def tokenize(phonemes: str) -> list[int]: """Convert phonemes string to token IDs - + Args: phonemes: String of phonemes to tokenize - + Returns: List of token IDs """ return [i for i in map(VOCAB.get, phonemes) if i is not None] + def decode_tokens(tokens: list[int]) -> str: """Convert token IDs back to phonemes string - + Args: tokens: List of token IDs - + Returns: String of phonemes """ diff --git a/api/src/services/tts_base.py b/api/src/services/tts_base.py index 16e8462..6076ebf 100644 --- a/api/src/services/tts_base.py +++ b/api/src/services/tts_base.py @@ -2,12 +2,14 @@ import threading from abc import ABC, abstractmethod from typing import List, Tuple -import torch + import numpy as np +import torch from loguru import logger from ..core.config import settings + class TTSBaseModel(ABC): _instance = None _lock = threading.Lock() @@ -26,7 +28,9 @@ async def setup(cls): # Test CUDA device test_tensor = torch.zeros(1).cuda() logger.info("CUDA test successful") - model_path = os.path.join(settings.model_dir, settings.pytorch_model_path) + model_path = os.path.join( + settings.model_dir, settings.pytorch_model_path + ) cls._device = "cuda" except Exception as e: logger.error(f"CUDA test failed: {e}") @@ -36,9 +40,11 @@ async def setup(cls): model_path = os.path.join(settings.model_dir, settings.onnx_model_path) logger.info(f"Initializing model on {cls._device}") - # Initialize model - if not 
cls.initialize(settings.model_dir, model_path=model_path): + # Initialize model first + model = cls.initialize(settings.model_dir, model_path=model_path) + if model is None: raise RuntimeError(f"Failed to initialize {cls._device.upper()} model") + cls._instance = model # Setup voices directory os.makedirs(cls.VOICES_DIR, exist_ok=True) @@ -52,33 +58,55 @@ async def setup(cls): voice_path = os.path.join(cls.VOICES_DIR, file) if not os.path.exists(voice_path): try: - logger.info(f"Copying base voice {voice_name} to voices directory") + logger.info( + f"Copying base voice {voice_name} to voices directory" + ) base_path = os.path.join(base_voices_dir, file) - voicepack = torch.load(base_path, map_location=cls._device, weights_only=True) + voicepack = torch.load( + base_path, + map_location=cls._device, + weights_only=True, + ) torch.save(voicepack, voice_path) except Exception as e: - logger.error(f"Error copying voice {voice_name}: {str(e)}") + logger.error( + f"Error copying voice {voice_name}: {str(e)}" + ) - # Load warmup text + # Count voices in directory + voice_count = len( + [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")] + ) + + # Now that model and voices are ready, do warmup try: - with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "don_quixote.txt")) as f: + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "core", + "don_quixote.txt", + ) + ) as f: warmup_text = f.read() except Exception as e: logger.warning(f"Failed to load warmup text: {e}") warmup_text = "This is a warmup text that will be split into chunks for processing." - # Use warmup service + # Use warmup service after model is fully initialized from .warmup import WarmupService + warmup = WarmupService() - + # Load and warm up voices loaded_voices = warmup.load_voices() await warmup.warmup_voices(warmup_text, loaded_voices) - + logger.info("Model warm-up complete") # Count voices in directory - voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]) + voice_count = len( + [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")] + ) return voice_count @classmethod @@ -91,11 +119,11 @@ def initialize(cls, model_dir: str, model_path: str = None): @abstractmethod def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]: """Process text into phonemes and tokens - + Args: text: Input text language: Language code - + Returns: tuple[str, list[int]]: Phonemes and token IDs """ @@ -103,15 +131,17 @@ def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]: @classmethod @abstractmethod - def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> Tuple[np.ndarray, str]: + def generate_from_text( + cls, text: str, voicepack: torch.Tensor, language: str, speed: float + ) -> Tuple[np.ndarray, str]: """Generate audio from text - + Args: text: Input text voicepack: Voice tensor language: Language code speed: Speed factor - + Returns: tuple[np.ndarray, str]: Generated audio samples and phonemes """ @@ -119,14 +149,16 @@ def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, s @classmethod @abstractmethod - def generate_from_tokens(cls, tokens: List[int], voicepack: torch.Tensor, speed: float) -> np.ndarray: + def generate_from_tokens( + cls, tokens: List[int], voicepack: torch.Tensor, speed: float + ) -> np.ndarray: """Generate audio from tokens - + Args: tokens: Token IDs voicepack: Voice tensor speed: Speed factor - + Returns: np.ndarray: Generated audio 
samples """ diff --git a/api/src/services/tts_cpu.py b/api/src/services/tts_cpu.py index 0436a24..5284750 100644 --- a/api/src/services/tts_cpu.py +++ b/api/src/services/tts_cpu.py @@ -1,17 +1,31 @@ import os + import numpy as np import torch -from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode from loguru import logger +from onnxruntime import ( + ExecutionMode, + SessionOptions, + InferenceSession, + GraphOptimizationLevel, +) from .tts_base import TTSBaseModel -from .text_processing import phonemize, tokenize from ..core.config import settings +from .text_processing import tokenize, phonemize + class TTSCPUModel(TTSBaseModel): _instance = None _onnx_session = None + @classmethod + def get_instance(cls): + """Get the model instance""" + if cls._onnx_session is None: + raise RuntimeError("ONNX model not initialized. Call initialize() first.") + return cls._onnx_session + @classmethod def initialize(cls, model_dir: str, model_path: str = None): """Initialize ONNX model for CPU inference""" @@ -27,59 +41,63 @@ def initialize(cls, model_dir: str, model_path: str = None): if not onnx_path: return None - logger.info(f"Loading ONNX model from {onnx_path}") - # Configure ONNX session for optimal performance session_options = SessionOptions() - + # Set optimization level if settings.onnx_optimization_level == "all": - session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + session_options.graph_optimization_level = ( + GraphOptimizationLevel.ORT_ENABLE_ALL + ) elif settings.onnx_optimization_level == "basic": - session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC + session_options.graph_optimization_level = ( + GraphOptimizationLevel.ORT_ENABLE_BASIC + ) else: - session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL - + session_options.graph_optimization_level = ( + GraphOptimizationLevel.ORT_DISABLE_ALL + ) + # Configure threading session_options.intra_op_num_threads = settings.onnx_num_threads session_options.inter_op_num_threads = settings.onnx_inter_op_threads - + # Set execution mode session_options.execution_mode = ( - ExecutionMode.ORT_PARALLEL - if settings.onnx_execution_mode == "parallel" + ExecutionMode.ORT_PARALLEL + if settings.onnx_execution_mode == "parallel" else ExecutionMode.ORT_SEQUENTIAL ) - + # Enable/disable memory pattern optimization session_options.enable_mem_pattern = settings.onnx_memory_pattern # Configure CPU provider options provider_options = { - 'CPUExecutionProvider': { - 'arena_extend_strategy': settings.onnx_arena_extend_strategy, - 'cpu_memory_arena_cfg': 'cpu:0' + "CPUExecutionProvider": { + "arena_extend_strategy": settings.onnx_arena_extend_strategy, + "cpu_memory_arena_cfg": "cpu:0", } } - cls._onnx_session = InferenceSession( + session = InferenceSession( onnx_path, sess_options=session_options, - providers=['CPUExecutionProvider'], - provider_options=[provider_options] + providers=["CPUExecutionProvider"], + provider_options=[provider_options], ) - - return cls._onnx_session + cls._onnx_session = session + return session return cls._onnx_session @classmethod def process_text(cls, text: str, language: str) -> tuple[str, list[int]]: """Process text into phonemes and tokens - + Args: text: Input text language: Language code - + Returns: tuple[str, list[int]]: Phonemes and token IDs """ @@ -89,38 +107,42 @@ def process_text(cls, text: str, language: str) -> tuple[str, list[int]]: return phonemes, tokens @classmethod - def 
generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]: + def generate_from_text( + cls, text: str, voicepack: torch.Tensor, language: str, speed: float + ) -> tuple[np.ndarray, str]: """Generate audio from text - + Args: text: Input text voicepack: Voice tensor language: Language code speed: Speed factor - + Returns: tuple[np.ndarray, str]: Generated audio samples and phonemes """ if cls._onnx_session is None: raise RuntimeError("ONNX model not initialized") - + # Process text phonemes, tokens = cls.process_text(text, language) - + # Generate audio audio = cls.generate_from_tokens(tokens, voicepack, speed) - + return audio, phonemes @classmethod - def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray: + def generate_from_tokens( + cls, tokens: list[int], voicepack: torch.Tensor, speed: float + ) -> np.ndarray: """Generate audio from tokens - + Args: tokens: Token IDs voicepack: Voice tensor speed: Speed factor - + Returns: np.ndarray: Generated audio samples """ @@ -129,16 +151,15 @@ def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: # Pre-allocate and prepare inputs tokens_input = np.array([tokens], dtype=np.int64) - style_input = voicepack[len(tokens)-2].numpy() # Already has correct dimensions - speed_input = np.full(1, speed, dtype=np.float32) # More efficient than ones * speed - + style_input = voicepack[ + len(tokens) - 2 + ].numpy() # Already has correct dimensions + speed_input = np.full( + 1, speed, dtype=np.float32 + ) # More efficient than ones * speed + # Run inference with optimized inputs result = cls._onnx_session.run( - None, - { - 'tokens': tokens_input, - 'style': style_input, - 'speed': speed_input - } + None, {"tokens": tokens_input, "style": style_input, "speed": speed_input} ) return result[0] diff --git a/api/src/services/tts_gpu.py b/api/src/services/tts_gpu.py index 51c8424..1e5f4a1 100644 --- a/api/src/services/tts_gpu.py +++ b/api/src/services/tts_gpu.py @@ -1,13 +1,15 @@ import os +import time + import numpy as np import torch -import time from loguru import logger from models import build_model -from .text_processing import phonemize, tokenize from .tts_base import TTSBaseModel from ..core.config import settings +from .text_processing import tokenize, phonemize + # @torch.no_grad() # def forward(model, tokens, ref_s, speed): @@ -36,48 +38,65 @@ # return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy() @torch.no_grad() def forward(model, tokens, ref_s, speed): - """Forward pass through the model with light optimizations that preserve output quality""" + """Forward pass through the model with moderate memory management""" device = ref_s.device - # Keep original token handling but optimize device placement - tokens = torch.LongTensor([[0, *tokens, 0]]).to(device) - input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) - text_mask = length_to_mask(input_lengths).to(device) - - # BERT and encoder pass - bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) - d_en = model.bert_encoder(bert_dur).transpose(-1, -2) - - # Split reference signal once for efficiency - s_content = ref_s[:, 128:] - s_ref = ref_s[:, :128] - - # Predictor forward pass - d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask) - x, _ = model.predictor.lstm(d) - - # Duration prediction - keeping original logic - duration = model.predictor.duration_proj(x) - duration = 
torch.sigmoid(duration).sum(axis=-1) / speed - pred_dur = torch.round(duration).clamp(min=1).long() - - # Alignment matrix construction - keeping original approach for quality - pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device) - c_frame = 0 - for i in range(pred_aln_trg.size(0)): - pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1 - c_frame += pred_dur[0, i].item() - - # Matrix multiplications - reuse unsqueezed tensor - pred_aln_trg = pred_aln_trg.unsqueeze(0) # Do unsqueeze once - en = d.transpose(-1, -2) @ pred_aln_trg - F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content) - - # Text encoding and final decoding - t_en = model.text_encoder(tokens, input_lengths, text_mask) - asr = t_en @ pred_aln_trg - - return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy() + try: + # Initial tensor setup with proper device placement + tokens = torch.LongTensor([[0, *tokens, 0]]).to(device) + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) + text_mask = length_to_mask(input_lengths).to(device) + + # Split and clone reference signals with explicit device placement + s_content = ref_s[:, 128:].clone().to(device) + s_ref = ref_s[:, :128].clone().to(device) + + # BERT and encoder pass + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + + # Predictor forward pass + d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask) + x, _ = model.predictor.lstm(d) + + # Duration prediction + duration = model.predictor.duration_proj(x) + duration = torch.sigmoid(duration).sum(axis=-1) / speed + pred_dur = torch.round(duration).clamp(min=1).long() + # Only cleanup large intermediates + del duration, x + + # Alignment matrix construction + pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1 + c_frame += pred_dur[0, i].item() + pred_aln_trg = pred_aln_trg.unsqueeze(0) + + # Matrix multiplications with selective cleanup + en = d.transpose(-1, -2) @ pred_aln_trg + del d # Free large intermediate tensor + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content) + del en # Free large intermediate tensor + + # Final text encoding and decoding + t_en = model.text_encoder(tokens, input_lengths, text_mask) + asr = t_en @ pred_aln_trg + del t_en # Free large intermediate tensor + + # Final decoding and transfer to CPU + output = model.decoder(asr, F0_pred, N_pred, s_ref) + result = output.squeeze().cpu().numpy() + + return result + + finally: + # Let PyTorch handle most cleanup automatically + # Only explicitly free the largest tensors + del pred_aln_trg, asr + # def length_to_mask(lengths): # """Create attention mask from lengths""" @@ -90,21 +109,32 @@ def forward(model, tokens, ref_s, speed): # mask = torch.gt(mask + 1, lengths.unsqueeze(1)) # return mask + def length_to_mask(lengths): """Create attention mask from lengths - possibly optimized version""" max_len = lengths.max() # Create mask directly on the same device as lengths - mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1) + mask = torch.arange(max_len, device=lengths.device)[None, :].expand( + lengths.shape[0], -1 + ) # Avoid type_as by using the correct dtype from the start if lengths.dtype != mask.dtype: mask = mask.to(dtype=lengths.dtype) # Fuse operations using broadcasting return mask + 1 > lengths[:, 
None] + class TTSGPUModel(TTSBaseModel): _instance = None _device = "cuda" + @classmethod + def get_instance(cls): + """Get the model instance""" + if cls._instance is None: + raise RuntimeError("GPU model not initialized. Call initialize() first.") + return cls._instance + @classmethod def initialize(cls, model_dir: str, model_path: str): """Initialize PyTorch model for GPU inference""" @@ -114,7 +144,7 @@ def initialize(cls, model_dir: str, model_path: str): model_path = os.path.join(model_dir, settings.pytorch_model_path) model = build_model(model_path, cls._device) cls._instance = model - return cls._instance + return model except Exception as e: logger.error(f"Failed to initialize GPU model: {e}") return None @@ -123,11 +153,11 @@ def initialize(cls, model_dir: str, model_path: str): @classmethod def process_text(cls, text: str, language: str) -> tuple[str, list[int]]: """Process text into phonemes and tokens - + Args: text: Input text language: Language code - + Returns: tuple[str, list[int]]: Phonemes and token IDs """ @@ -136,48 +166,97 @@ def process_text(cls, text: str, language: str) -> tuple[str, list[int]]: return phonemes, tokens @classmethod - def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]: + def generate_from_text( + cls, text: str, voicepack: torch.Tensor, language: str, speed: float + ) -> tuple[np.ndarray, str]: """Generate audio from text - + Args: text: Input text voicepack: Voice tensor language: Language code speed: Speed factor - + Returns: tuple[np.ndarray, str]: Generated audio samples and phonemes """ if cls._instance is None: raise RuntimeError("GPU model not initialized") - + # Process text phonemes, tokens = cls.process_text(text, language) - + # Generate audio audio = cls.generate_from_tokens(tokens, voicepack, speed) - + return audio, phonemes @classmethod - def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray: - """Generate audio from tokens - + def generate_from_tokens( + cls, tokens: list[int], voicepack: torch.Tensor, speed: float + ) -> np.ndarray: + """Generate audio from tokens with moderate memory management + Args: tokens: Token IDs voicepack: Voice tensor speed: Speed factor - + Returns: np.ndarray: Generated audio samples """ if cls._instance is None: raise RuntimeError("GPU model not initialized") + + try: + device = cls._device - # Get reference style - ref_s = voicepack[len(tokens)] - - # Generate audio - audio = forward(cls._instance, tokens, ref_s, speed) + # Check memory pressure + if torch.cuda.is_available(): + memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB + if memory_allocated > 2.0: # 2GB limit + logger.info( + f"Memory usage above 2GB threshold:{memory_allocated:.2f}GB " + f"Clearing cache" + ) + torch.cuda.empty_cache() + import gc + gc.collect() + + # Get reference style with proper device placement + ref_s = voicepack[len(tokens)].clone().to(device) + + # Generate audio + audio = forward(cls._instance, tokens, ref_s, speed) + + return audio + + except RuntimeError as e: + if "out of memory" in str(e): + # On OOM, do a full cleanup and retry + if torch.cuda.is_available(): + logger.warning("Out of memory detected, performing full cleanup") + torch.cuda.synchronize() + torch.cuda.empty_cache() + import gc + gc.collect() + + # Log memory stats after cleanup + memory_allocated = torch.cuda.memory_allocated(device) + memory_reserved = torch.cuda.memory_reserved(device) + logger.info( + 
f"Memory after OOM cleanup: " + f"Allocated: {memory_allocated / 1e9:.2f}GB, " + f"Reserved: {memory_reserved / 1e9:.2f}GB" + ) + + # Retry generation + ref_s = voicepack[len(tokens)].clone().to(device) + audio = forward(cls._instance, tokens, ref_s, speed) + return audio + raise - return audio + finally: + # Only synchronize at the top level, no empty_cache + if torch.cuda.is_available(): + torch.cuda.synchronize() diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 4414ea2..61471a8 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -1,5 +1,4 @@ import io -import aiofiles.os import os import re import time @@ -8,25 +7,28 @@ import numpy as np import torch +import aiofiles.os import scipy.io.wavfile as wavfile -from .text_processing import normalize_text, chunker from loguru import logger -from ..core.config import settings -from .tts_model import TTSModel from .audio import AudioService, AudioNormalizer +from .tts_model import TTSModel +from ..core.config import settings +from .text_processing import chunker, normalize_text class TTSService: def __init__(self, output_dir: str = None): self.output_dir = output_dir - + self.model = TTSModel.get_instance() @staticmethod - @lru_cache(maxsize=20) # Cache up to 8 most recently used voices + @lru_cache(maxsize=3) # Cache up to 3 most recently used voices def _load_voice(voice_path: str) -> torch.Tensor: """Load and cache a voice model""" - return torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True) + return torch.load( + voice_path, map_location=TTSModel.get_device(), weights_only=True + ) def _get_voice_path(self, voice_name: str) -> Optional[str]: """Get the path to a voice file""" @@ -37,7 +39,9 @@ def _generate_audio( self, text: str, voice: str, speed: float, stitch_long_output: bool = True ) -> Tuple[torch.Tensor, float]: """Generate complete audio and return with processing time""" - audio, processing_time = self._generate_audio_internal(text, voice, speed, stitch_long_output) + audio, processing_time = self._generate_audio_internal( + text, voice, speed, stitch_long_output + ) return audio, processing_time def _generate_audio_internal( @@ -72,7 +76,9 @@ def _generate_audio_internal( phonemes, tokens = TTSModel.process_text(chunk, voice[0]) chunks_data.append((chunk, tokens)) except Exception as e: - logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}") + logger.error( + f"Failed to process chunk: '{chunk}'. Error: {str(e)}" + ) continue if not chunks_data: @@ -82,20 +88,28 @@ def _generate_audio_internal( audio_chunks = [] for chunk, tokens in chunks_data: try: - chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed) + chunk_audio = TTSModel.generate_from_tokens( + tokens, voicepack, speed + ) if chunk_audio is not None: audio_chunks.append(chunk_audio) else: logger.error(f"No audio generated for chunk: '{chunk}'") except Exception as e: - logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}") + logger.error( + f"Failed to generate audio for chunk: '{chunk}'. 
Error: {str(e)}" + ) continue if not audio_chunks: raise ValueError("No audio chunks were generated successfully") # Concatenate all chunks - audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0] + audio = ( + np.concatenate(audio_chunks) + if len(audio_chunks) > 1 + else audio_chunks[0] + ) else: # Process single chunk phonemes, tokens = TTSModel.process_text(text, voice[0]) @@ -109,14 +123,19 @@ def _generate_audio_internal( raise async def generate_audio_stream( - self, text: str, voice: str, speed: float, output_format: str = "wav", silent=False + self, + text: str, + voice: str, + speed: float, + output_format: str = "wav", + silent=False, ): """Generate and yield audio chunks as they're generated for real-time streaming""" try: stream_start = time.time() # Create normalizer for consistent audio levels stream_normalizer = AudioNormalizer() - + # Input validation and preprocessing if not text: raise ValueError("Text is empty") @@ -125,7 +144,9 @@ async def generate_audio_stream( if not normalized: raise ValueError("Text is empty after preprocessing") text = str(normalized) - logger.debug(f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms") + logger.debug( + f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms" + ) # Voice validation and loading voice_start = time.time() @@ -133,74 +154,56 @@ async def generate_audio_stream( if not voice_path: raise ValueError(f"Voice not found: {voice}") voicepack = self._load_voice(voice_path) - logger.debug(f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms") + logger.debug( + f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms" + ) # Process chunks as they're generated is_first = True chunks_processed = 0 - # last_chunk_end = time.time() - + # Process chunks as they come from generator chunk_gen = chunker.split_text(text) current_chunk = next(chunk_gen, None) - + while current_chunk is not None: next_chunk = next(chunk_gen, None) # Peek at next chunk - # chunk_start = time.time() chunks_processed += 1 try: # Process text and generate audio - # text_process_start = time.time() phonemes, tokens = TTSModel.process_text(current_chunk, voice[0]) - # text_process_time = time.time() - text_process_start - - # audio_gen_start = time.time() - chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed) - # audio_gen_time = time.time() - audio_gen_start - + chunk_audio = TTSModel.generate_from_tokens( + tokens, voicepack, speed + ) + if chunk_audio is not None: - # Convert chunk with proper header handling - convert_start = time.time() + # Convert chunk with proper streaming header handling chunk_bytes = AudioService.convert_audio( chunk_audio, 24000, output_format, is_first_chunk=is_first, normalizer=stream_normalizer, - is_last_chunk=(next_chunk is None) # Last if no next chunk + is_last_chunk=(next_chunk is None), # Last if no next chunk + stream=True # Ensure proper streaming format handling ) - # convert_time = time.time() - convert_start - - # Calculate gap from last chunk - # gap_time = chunk_start - last_chunk_end - - # Log timing details if not silent - # if not silent: - # logger.debug( - # f"\nChunk {chunks_processed} timing:" - # f"\n Gap from last chunk: {gap_time*1000:.1f}ms" - # f"\n Text processing: {text_process_time*1000:.1f}ms" - # f"\n Audio generation: {audio_gen_time*1000:.1f}ms" - # f"\n Audio conversion: {convert_time*1000:.1f}ms" - # f"\n Total chunk time: {(time.time() - chunk_start)*1000:.1f}ms" - # ) - + yield chunk_bytes is_first 
= False - # last_chunk_end = time.time() else: logger.error(f"No audio generated for chunk: '{current_chunk}'") except Exception as e: - logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}") - + logger.error( + f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}" + ) + current_chunk = next_chunk # Move to next chunk - + except Exception as e: logger.error(f"Error in audio generation stream: {str(e)}") raise - def _save_audio(self, audio: torch.Tensor, filepath: str): """Save audio to file""" os.makedirs(os.path.dirname(filepath), exist_ok=True) @@ -257,10 +260,10 @@ async def list_voices(self) -> List[str]: """List all available voices""" voices = [] try: - async with aiofiles.scandir(TTSModel.VOICES_DIR) as it: - async for entry in it: - if entry.name.endswith(".pt"): - voices.append(entry.name[:-3]) # Remove .pt extension + it = await aiofiles.os.scandir(TTSModel.VOICES_DIR) + for entry in it: + if entry.name.endswith(".pt"): + voices.append(entry.name[:-3]) # Remove .pt extension except Exception as e: logger.error(f"Error listing voices: {str(e)}") return sorted(voices) diff --git a/api/src/services/warmup.py b/api/src/services/warmup.py index 67937dd..1be2013 100644 --- a/api/src/services/warmup.py +++ b/api/src/services/warmup.py @@ -1,50 +1,58 @@ import os from typing import List, Tuple + import torch from loguru import logger -from .tts_service import TTSService from .tts_model import TTSModel +from .tts_service import TTSService +from ..core.config import settings class WarmupService: """Service for warming up TTS models and voice caches""" - + def __init__(self): + """Initialize warmup service and ensure model is ready""" + # Initialize model if not already initialized + if TTSModel._instance is None: + TTSModel.initialize(settings.model_dir) self.tts_service = TTSService() - + def load_voices(self) -> List[Tuple[str, torch.Tensor]]: """Load and cache voices up to LRU limit""" # Get all voices sorted by filename length (shorter names first, usually base voices) voice_files = sorted( - [f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], - key=len + [f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len ) - - # Load up to LRU cache limit (20) + + n_voices_cache = 1 loaded_voices = [] - for voice_file in voice_files[:20]: + for voice_file in voice_files[:n_voices_cache]: try: voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file) - voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True) - loaded_voices.append((voice_file[:-3], voicepack)) # Store name and tensor + # load using service, lru cache + voicepack = self.tts_service._load_voice(voice_path) + loaded_voices.append( + (voice_file[:-3], voicepack) + ) # Store name and tensor + # voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True) # logger.info(f"Loaded voice {voice_file[:-3]} into cache") except Exception as e: logger.error(f"Failed to load voice {voice_file}: {e}") logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache") return loaded_voices - - async def warmup_voices(self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]): + + async def warmup_voices( + self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]] + ): """Warm up voice inference and streaming""" n_warmups = 1 for voice_name, _ in loaded_voices[:n_warmups]: try: logger.info(f"Running warmup inference on voice {voice_name}") async for _ in 
self.tts_service.generate_audio_stream( - warmup_text, - voice_name, - 1.0, - "pcm" + warmup_text, voice_name, 1.0, "pcm" ): pass # Process all chunks to properly warm up logger.info(f"Completed warmup for voice {voice_name}") diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index 48bc099..8db014c 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -1,14 +1,15 @@ from enum import Enum -from typing import Literal, Union, List +from typing import List, Union, Literal from pydantic import Field, BaseModel class VoiceCombineRequest(BaseModel): """Request schema for voice combination endpoint that accepts either a string with + or a list""" + voices: Union[str, List[str]] = Field( ..., - description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine" + description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine", ) diff --git a/api/src/structures/text_schemas.py b/api/src/structures/text_schemas.py index 5ae1b08..f820f68 100644 --- a/api/src/structures/text_schemas.py +++ b/api/src/structures/text_schemas.py @@ -1,9 +1,19 @@ -from pydantic import BaseModel +from pydantic import Field, BaseModel + class PhonemeRequest(BaseModel): text: str language: str = "a" # Default to American English + class PhonemeResponse(BaseModel): phonemes: str tokens: list[int] + + +class GenerateFromPhonemesRequest(BaseModel): + phonemes: str + voice: str = Field(..., description="Voice ID to use for generation") + speed: float = Field( + default=1.0, ge=0.1, le=5.0, description="Speed factor for generation" + ) diff --git a/api/tests/conftest.py b/api/tests/conftest.py index c4a295a..900e6ae 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,8 +1,9 @@ import os import sys import shutil -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, MagicMock, patch +import numpy as np import pytest import aiofiles.threadpool @@ -36,6 +37,7 @@ def cleanup(): mock_torch.cuda = Mock() mock_torch.cuda.is_available = Mock(return_value=False) + # Create a mock tensor class that supports basic operations class MockTensor: def __init__(self, data): @@ -45,54 +47,57 @@ def __init__(self, data): elif isinstance(data, MockTensor): self.shape = data.shape else: - self.shape = getattr(data, 'shape', [1]) - + self.shape = getattr(data, "shape", [1]) + def __getitem__(self, idx): if isinstance(self.data, (list, tuple)): if isinstance(idx, slice): return MockTensor(self.data[idx]) return self.data[idx] return self - + def max(self): if isinstance(self.data, (list, tuple)): max_val = max(self.data) return MockTensor(max_val) return 5 # Default for testing - + def item(self): if isinstance(self.data, (list, tuple)): return max(self.data) if isinstance(self.data, (int, float)): return self.data return 5 # Default for testing - + def cuda(self): """Support cuda conversion""" return self - + def any(self): if isinstance(self.data, (list, tuple)): return any(self.data) return False - + def all(self): if isinstance(self.data, (list, tuple)): return all(self.data) return True - + def unsqueeze(self, dim): return self - + def expand(self, *args): return self - + def type_as(self, other): return self + # Add tensor operations to mock torch mock_torch.tensor = lambda x: MockTensor(x) -mock_torch.zeros = lambda *args: MockTensor([0] * (args[0] if isinstance(args[0], int) else args[0][0])) +mock_torch.zeros = lambda *args: MockTensor( + [0] * 
(args[0] if isinstance(args[0], int) else args[0][0]) +) mock_torch.arange = lambda x: MockTensor(list(range(x))) mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0]) @@ -106,22 +111,86 @@ def type_as(self, other): sys.modules["kokoro.generate"] = Mock() sys.modules["kokoro.phonemize"] = Mock() sys.modules["kokoro.tokenize"] = Mock() -sys.modules["onnxruntime"] = Mock() +# Mock ONNX runtime +mock_onnx = Mock() +mock_onnx.InferenceSession = Mock() +mock_onnx.SessionOptions = Mock() +mock_onnx.GraphOptimizationLevel = Mock() +mock_onnx.ExecutionMode = Mock() +sys.modules["onnxruntime"] = mock_onnx -@pytest.fixture(autouse=True) -def mock_tts_model(): - """Mock TTSModel and TTS model initialization""" - with patch("api.src.services.tts_model.TTSModel") as mock_tts_model, \ - patch("api.src.services.tts_base.TTSBaseModel") as mock_base_model: - - # Mock TTSModel - model_instance = Mock() - model_instance.get_instance.return_value = model_instance - model_instance.get_voicepack.return_value = None - mock_tts_model.get_instance.return_value = model_instance - - # Mock TTS model initialization - mock_base_model.setup.return_value = 1 # Return dummy voice count - - yield model_instance +# Create mock settings module +mock_settings_module = Mock() +mock_settings = Mock() +mock_settings.model_dir = "/mock/model/dir" +mock_settings.onnx_model_path = "mock.onnx" +mock_settings_module.settings = mock_settings +sys.modules["api.src.core.config"] = mock_settings_module + + +class MockTTSModel: + _instance = None + _onnx_session = None + VOICES_DIR = "/mock/voices/dir" + + def __init__(self): + self._initialized = False + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def initialize(cls, model_dir): + cls._onnx_session = Mock() + cls._onnx_session.run = Mock(return_value=[np.zeros(48000)]) + cls._instance._initialized = True + return cls._onnx_session + + @classmethod + def setup(cls): + if not cls._instance._initialized: + cls.initialize("/mock/model/dir") + return cls._instance + + @classmethod + def generate_from_tokens(cls, tokens, voicepack, speed): + if not cls._instance._initialized: + raise RuntimeError("Model not initialized. 
Call setup() first.") + return np.zeros(48000) + + @classmethod + def process_text(cls, text, language): + return "mock phonemes", [1, 2, 3] + + @staticmethod + def get_device(): + return "cpu" + + +@pytest.fixture +def mock_tts_service(monkeypatch): + """Mock TTSService for testing""" + mock_service = Mock() + mock_service._get_voice_path.return_value = "/mock/path/voice.pt" + mock_service._load_voice.return_value = np.zeros((1, 192)) + + # Mock TTSModel.generate_from_tokens since we call it directly + mock_generate = Mock(return_value=np.zeros(48000)) + monkeypatch.setattr( + "api.src.routers.development.TTSModel.generate_from_tokens", mock_generate + ) + + return mock_service + + +@pytest.fixture +def mock_audio_service(monkeypatch): + """Mock AudioService""" + mock_service = Mock() + mock_service.convert_audio.return_value = b"mock audio data" + monkeypatch.setattr("api.src.routers.development.AudioService", mock_service) + return mock_service diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py index 32f4300..758e4f4 100644 --- a/api/tests/test_audio_service.py +++ b/api/tests/test_audio_service.py @@ -1,9 +1,19 @@ """Tests for AudioService""" +from unittest.mock import patch + import numpy as np import pytest -from api.src.services.audio import AudioService +from api.src.services.audio import AudioService, AudioNormalizer + + +@pytest.fixture(autouse=True) +def mock_settings(): + """Mock settings for all tests""" + with patch("api.src.services.audio.settings") as mock_settings: + mock_settings.gap_trim_ms = 250 + yield mock_settings @pytest.fixture @@ -53,7 +63,7 @@ def test_convert_to_aac_raises_error(sample_audio): audio_data, sample_rate = sample_audio with pytest.raises( ValueError, - match="Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm.", + match="Failed to convert audio to aac: Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm.", ): AudioService.convert_audio(audio_data, sample_rate, "aac") diff --git a/api/tests/test_chunker.py b/api/tests/test_chunker.py index ed598c0..002da72 100644 --- a/api/tests/test_chunker.py +++ b/api/tests/test_chunker.py @@ -1,9 +1,20 @@ """Tests for text chunking service""" +from unittest.mock import patch + import pytest + from api.src.services.text_processing import chunker +@pytest.fixture(autouse=True) +def mock_settings(): + """Mock settings for all tests""" + with patch("api.src.services.text_processing.chunker.settings") as mock_settings: + mock_settings.max_chunk_size = 300 + yield mock_settings + + def test_split_text(): """Test text splitting into sentences""" text = "First sentence. Second sentence! Third sentence?" 
diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py index bd9e578..c3bcb43 100644 --- a/api/tests/test_endpoints.py +++ b/api/tests/test_endpoints.py @@ -1,16 +1,17 @@ +import asyncio from unittest.mock import Mock, AsyncMock import pytest import pytest_asyncio -import asyncio -from fastapi.testclient import TestClient from httpx import AsyncClient +from fastapi.testclient import TestClient from ..src.main import app # Create test client client = TestClient(app) + # Create async client fixture @pytest_asyncio.fixture async def async_client(): @@ -23,25 +24,28 @@ async def async_client(): def mock_tts_service(monkeypatch): mock_service = Mock() mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0) - + # Create proper async generator mock async def mock_stream(*args, **kwargs): for chunk in [b"chunk1", b"chunk2"]: yield chunk + mock_service.generate_audio_stream = mock_stream - + # Create async mocks - mock_service.list_voices = AsyncMock(return_value=[ - "af", - "bm_lewis", - "bf_isabella", - "bf_emma", - "af_sarah", - "af_bella", - "am_adam", - "am_michael", - "bm_george", - ]) + mock_service.list_voices = AsyncMock( + return_value=[ + "af", + "bm_lewis", + "bf_isabella", + "bf_emma", + "af_sarah", + "af_bella", + "am_adam", + "am_michael", + "bm_george", + ] + ) mock_service.combine_voices = AsyncMock() monkeypatch.setattr( "api.src.routers.openai_compatible.TTSService", @@ -54,9 +58,7 @@ async def mock_stream(*args, **kwargs): def mock_audio_service(monkeypatch): mock_service = Mock() mock_service.convert_audio.return_value = b"converted mock audio data" - monkeypatch.setattr( - "api.src.routers.openai_compatible.AudioService", mock_service - ) + monkeypatch.setattr("api.src.routers.openai_compatible.AudioService", mock_service) return mock_service @@ -68,7 +70,9 @@ def test_health_check(): @pytest.mark.asyncio -async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, async_client): +async def test_openai_speech_endpoint( + mock_tts_service, mock_audio_service, async_client +): """Test the OpenAI-compatible speech endpoint""" test_request = { "model": "kokoro", @@ -76,7 +80,7 @@ async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, asyn "voice": "bm_lewis", "response_format": "wav", "speed": 1.0, - "stream": False # Explicitly disable streaming + "stream": False, # Explicitly disable streaming } response = await async_client.post("/v1/audio/speech", json=test_request) assert response.status_code == 200 @@ -97,7 +101,7 @@ async def test_openai_speech_invalid_voice(mock_tts_service, async_client): "voice": "invalid_voice", "response_format": "wav", "speed": 1.0, - "stream": False # Explicitly disable streaming + "stream": False, # Explicitly disable streaming } response = await async_client.post("/v1/audio/speech", json=test_request) assert response.status_code == 400 # Bad request @@ -113,7 +117,7 @@ async def test_openai_speech_invalid_speed(mock_tts_service, async_client): "voice": "af", "response_format": "wav", "speed": -1.0, # Invalid speed - "stream": False # Explicitly disable streaming + "stream": False, # Explicitly disable streaming } response = await async_client.post("/v1/audio/speech", json=test_request) assert response.status_code == 422 # Validation error @@ -129,7 +133,7 @@ async def test_openai_speech_generation_error(mock_tts_service, async_client): "voice": "af", "response_format": "wav", "speed": 1.0, - "stream": False # Explicitly disable streaming + "stream": False, # Explicitly disable 
streaming } response = await async_client.post("/v1/audio/speech", json=test_request) assert response.status_code == 500 @@ -159,7 +163,9 @@ async def test_combine_voices_string_success(mock_tts_service, async_client): assert response.status_code == 200 assert response.json()["voice"] == "af_bella_af_sarah" - mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"]) + mock_tts_service.combine_voices.assert_called_once_with( + voices=["af_bella", "af_sarah"] + ) @pytest.mark.asyncio @@ -184,7 +190,9 @@ async def test_combine_voices_empty_list(mock_tts_service, async_client): async def test_combine_voices_error(mock_tts_service, async_client): """Test error handling in voice combination""" test_voices = ["af_bella", "af_sarah"] - mock_tts_service.combine_voices = AsyncMock(side_effect=Exception("Combination failed")) + mock_tts_service.combine_voices = AsyncMock( + side_effect=Exception("Combination failed") + ) response = await async_client.post("/v1/audio/voices/combine", json=test_voices) assert response.status_code == 500 @@ -192,50 +200,56 @@ async def test_combine_voices_error(mock_tts_service, async_client): @pytest.mark.asyncio -async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service, async_client): +async def test_speech_with_combined_voice( + mock_tts_service, mock_audio_service, async_client +): """Test speech generation with combined voice using + syntax""" mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah") - + test_request = { "model": "kokoro", "input": "Hello world", "voice": "af_bella+af_sarah", "response_format": "wav", "speed": 1.0, - "stream": False + "stream": False, } - + response = await async_client.post("/v1/audio/speech", json=test_request) - + assert response.status_code == 200 assert response.headers["content-type"] == "audio/wav" mock_tts_service._generate_audio.assert_called_once_with( - text="Hello world", - voice="af_bella_af_sarah", - speed=1.0, - stitch_long_output=True + text="Hello world", + voice="af_bella_af_sarah", + speed=1.0, + stitch_long_output=True, ) @pytest.mark.asyncio -async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_service, async_client): +async def test_speech_with_whitespace_in_voice( + mock_tts_service, mock_audio_service, async_client +): """Test speech generation with whitespace in voice combination""" mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah") - + test_request = { "model": "kokoro", "input": "Hello world", "voice": " af_bella + af_sarah ", "response_format": "wav", "speed": 1.0, - "stream": False + "stream": False, } - + response = await async_client.post("/v1/audio/speech", json=test_request) - + assert response.status_code == 200 assert response.headers["content-type"] == "audio/wav" - mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"]) + mock_tts_service.combine_voices.assert_called_once_with( + voices=["af_bella", "af_sarah"] + ) @pytest.mark.asyncio @@ -247,9 +261,9 @@ async def test_speech_with_empty_voice_combination(mock_tts_service, async_clien "voice": "+", "response_format": "wav", "speed": 1.0, - "stream": False + "stream": False, } - + response = await async_client.post("/v1/audio/speech", json=test_request) assert response.status_code == 400 assert "No voices provided" in response.json()["detail"]["message"] @@ -264,9 +278,9 @@ async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client "voice": "invalid+combination", 
"response_format": "wav", "speed": 1.0, - "stream": False + "stream": False, } - + response = await async_client.post("/v1/audio/speech", json=test_request) assert response.status_code == 400 assert "not found" in response.json()["detail"]["message"] @@ -276,25 +290,28 @@ async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client async def test_speech_streaming_with_combined_voice(mock_tts_service, async_client): """Test streaming speech with combined voice using + syntax""" mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah") - + test_request = { "model": "kokoro", "input": "Hello world", "voice": "af_bella+af_sarah", "response_format": "mp3", - "stream": True + "stream": True, } - + # Create streaming mock async def mock_stream(*args, **kwargs): for chunk in [b"mp3header", b"mp3data"]: yield chunk + mock_tts_service.generate_audio_stream = mock_stream - + # Add streaming header headers = {"x-raw-response": "stream"} - response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) - + response = await async_client.post( + "/v1/audio/speech", json=test_request, headers=headers + ) + assert response.status_code == 200 assert response.headers["content-type"] == "audio/mpeg" assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" @@ -308,19 +325,22 @@ async def test_openai_speech_pcm_streaming(mock_tts_service, async_client): "input": "Hello world", "voice": "af", "response_format": "pcm", - "stream": True + "stream": True, } - + # Create streaming mock for this test async def mock_stream(*args, **kwargs): for chunk in [b"chunk1", b"chunk2"]: yield chunk + mock_tts_service.generate_audio_stream = mock_stream - + # Add streaming header headers = {"x-raw-response": "stream"} - response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) - + response = await async_client.post( + "/v1/audio/speech", json=test_request, headers=headers + ) + assert response.status_code == 200 assert response.headers["content-type"] == "audio/pcm" @@ -333,19 +353,22 @@ async def test_openai_speech_streaming_mp3(mock_tts_service, async_client): "input": "Hello world", "voice": "af", "response_format": "mp3", - "stream": True + "stream": True, } - + # Create streaming mock for this test async def mock_stream(*args, **kwargs): for chunk in [b"mp3header", b"mp3data"]: yield chunk + mock_tts_service.generate_audio_stream = mock_stream - + # Add streaming header headers = {"x-raw-response": "stream"} - response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) - + response = await async_client.post( + "/v1/audio/speech", json=test_request, headers=headers + ) + assert response.status_code == 200 assert response.headers["content-type"] == "audio/mpeg" assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" @@ -359,18 +382,21 @@ async def test_openai_speech_streaming_generator(mock_tts_service, async_client) "input": "Hello world", "voice": "af", "response_format": "pcm", - "stream": True + "stream": True, } - + # Create streaming mock for this test async def mock_stream(*args, **kwargs): for chunk in [b"chunk1", b"chunk2"]: yield chunk + mock_tts_service.generate_audio_stream = mock_stream - + # Add streaming header headers = {"x-raw-response": "stream"} - response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) - + response = await async_client.post( + "/v1/audio/speech", json=test_request, 
headers=headers + ) + assert response.status_code == 200 assert response.headers["content-type"] == "audio/pcm" diff --git a/api/tests/test_main.py b/api/tests/test_main.py index cb7aa8b..f779483 100644 --- a/api/tests/test_main.py +++ b/api/tests/test_main.py @@ -1,6 +1,6 @@ """Tests for FastAPI application""" -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, call, patch import pytest from fastapi.testclient import TestClient @@ -28,14 +28,15 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model): """Test successful model warmup in lifespan""" # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" - + # Create async mock async def async_setup(): return 3 + mock_tts_model.setup = MagicMock() mock_tts_model.setup.side_effect = async_setup mock_tts_model.get_device.return_value = "cuda" - + with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]): # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) @@ -44,7 +45,7 @@ async def async_setup(): # Verify the expected logging sequence mock_logger.info.assert_any_call("Loading TTS model and voice packs...") - + # Check for the startup message containing the required info startup_calls = [call[0][0] for call in mock_logger.info.call_args_list] startup_msg = next(msg for msg in startup_calls if "Model warmed up on" in msg) @@ -86,14 +87,15 @@ async def test_lifespan_cuda_warmup(mock_tts_model): """Test model warmup specifically on CUDA""" # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" - + # Create async mock async def async_setup(): return 2 + mock_tts_model.setup = MagicMock() mock_tts_model.setup.side_effect = async_setup mock_tts_model.get_device.return_value = "cuda" - + with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]): # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) diff --git a/api/tests/test_normalizer.py b/api/tests/test_normalizer.py index 9555e22..9146252 100644 --- a/api/tests/test_normalizer.py +++ b/api/tests/test_normalizer.py @@ -1,43 +1,88 @@ """Tests for text normalization service""" import pytest + from api.src.services.text_processing.normalizer import normalize_text + def test_url_protocols(): """Test URL protocol handling""" - assert normalize_text("Check out https://example.com") == "Check out https example dot com" + assert ( + normalize_text("Check out https://example.com") + == "Check out https example dot com" + ) assert normalize_text("Visit http://site.com") == "Visit http site dot com" - assert normalize_text("Go to https://test.org/path") == "Go to https test dot org slash path" + assert ( + normalize_text("Go to https://test.org/path") + == "Go to https test dot org slash path" + ) + def test_url_www(): """Test www prefix handling""" assert normalize_text("Go to www.example.com") == "Go to www example dot com" - assert normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs" - assert normalize_text("Check www.site.com?q=test") == "Check www site dot com question-mark q equals test" + assert ( + normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs" + ) + assert ( + normalize_text("Check www.site.com?q=test") + == "Check www site dot com question-mark q equals test" + ) + def test_url_localhost(): """Test localhost URL handling""" - assert normalize_text("Running on localhost:7860") == "Running on localhost colon 78 60" - assert 
normalize_text("Server at localhost:8080/api") == "Server at localhost colon 80 80 slash api" - assert normalize_text("Test localhost:3000/test?v=1") == "Test localhost colon 3000 slash test question-mark v equals 1" + assert ( + normalize_text("Running on localhost:7860") + == "Running on localhost colon 78 60" + ) + assert ( + normalize_text("Server at localhost:8080/api") + == "Server at localhost colon 80 80 slash api" + ) + assert ( + normalize_text("Test localhost:3000/test?v=1") + == "Test localhost colon 3000 slash test question-mark v equals 1" + ) + def test_url_ip_addresses(): """Test IP address URL handling""" - assert normalize_text("Access 0.0.0.0:9090/test") == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test" - assert normalize_text("API at 192.168.1.1:8000") == "API at 192 dot 168 dot 1 dot 1 colon 8000" + assert ( + normalize_text("Access 0.0.0.0:9090/test") + == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test" + ) + assert ( + normalize_text("API at 192.168.1.1:8000") + == "API at 192 dot 168 dot 1 dot 1 colon 8000" + ) assert normalize_text("Server 127.0.0.1") == "Server 127 dot 0 dot 0 dot 1" + def test_url_raw_domains(): """Test raw domain handling""" - assert normalize_text("Visit google.com/search") == "Visit google dot com slash search" - assert normalize_text("Go to example.com/path?q=test") == "Go to example dot com slash path question-mark q equals test" + assert ( + normalize_text("Visit google.com/search") == "Visit google dot com slash search" + ) + assert ( + normalize_text("Go to example.com/path?q=test") + == "Go to example dot com slash path question-mark q equals test" + ) assert normalize_text("Check docs.test.com") == "Check docs dot test dot com" + def test_url_email_addresses(): """Test email address handling""" - assert normalize_text("Email me at user@example.com") == "Email me at user at example dot com" + assert ( + normalize_text("Email me at user@example.com") + == "Email me at user at example dot com" + ) assert normalize_text("Contact admin@test.org") == "Contact admin at test dot org" - assert normalize_text("Send to test.user@site.com") == "Send to test dot user at site dot com" + assert ( + normalize_text("Send to test.user@site.com") + == "Send to test dot user at site dot com" + ) + def test_non_url_text(): """Test that non-URL text is unaffected""" diff --git a/api/tests/test_text_processing.py b/api/tests/test_text_processing.py new file mode 100644 index 0000000..aacb973 --- /dev/null +++ b/api/tests/test_text_processing.py @@ -0,0 +1,122 @@ +"""Tests for text processing endpoints""" + +from unittest.mock import Mock, patch + +import numpy as np +import pytest +import pytest_asyncio +from httpx import AsyncClient + +from .conftest import MockTTSModel +from ..src.main import app + + +@pytest_asyncio.fixture +async def async_client(): + async with AsyncClient(app=app, base_url="http://test") as ac: + yield ac + + +@pytest.mark.asyncio +async def test_phonemize_endpoint(async_client): + """Test phoneme generation endpoint""" + with patch("api.src.routers.development.phonemize") as mock_phonemize, patch( + "api.src.routers.development.tokenize" + ) as mock_tokenize: + # Setup mocks + mock_phonemize.return_value = "həlˈoʊ" + mock_tokenize.return_value = [1, 2, 3] + + # Test request + response = await async_client.post( + "/text/phonemize", json={"text": "hello", "language": "a"} + ) + + # Verify response + assert response.status_code == 200 + result = response.json() + assert result["phonemes"] == "həlˈoʊ" + assert result["tokens"] == 
[0, 1, 2, 3, 0] # Should add start/end tokens + + +@pytest.mark.asyncio +async def test_phonemize_empty_text(async_client): + """Test phoneme generation with empty text""" + response = await async_client.post( + "/text/phonemize", json={"text": "", "language": "a"} + ) + + assert response.status_code == 500 + assert "error" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_generate_from_phonemes( + async_client, mock_tts_service, mock_audio_service +): + """Test audio generation from phonemes""" + with patch( + "api.src.routers.development.TTSService", return_value=mock_tts_service + ): + response = await async_client.post( + "/text/generate_from_phonemes", + json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": 1.0}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + assert ( + response.headers["content-disposition"] == "attachment; filename=speech.wav" + ) + assert response.content == b"mock audio data" + + +@pytest.mark.asyncio +async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service): + """Test audio generation with invalid voice""" + mock_tts_service._get_voice_path.return_value = None + with patch( + "api.src.routers.development.TTSService", return_value=mock_tts_service + ): + response = await async_client.post( + "/text/generate_from_phonemes", + json={"phonemes": "həlˈoʊ", "voice": "invalid_voice", "speed": 1.0}, + ) + + assert response.status_code == 400 + assert "Voice not found" in response.json()["detail"]["message"] + + +@pytest.mark.asyncio +async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch): + """Test audio generation with invalid speed""" + # Mock TTSModel initialization + mock_model = Mock() + mock_model.generate_from_tokens = Mock(return_value=np.zeros(48000)) + monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model) + monkeypatch.setattr( + "api.src.services.tts_model.TTSModel.get_instance", + Mock(return_value=mock_model), + ) + + response = await async_client.post( + "/text/generate_from_phonemes", + json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": -1.0}, + ) + + assert response.status_code == 422 # Validation error + + +@pytest.mark.asyncio +async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service): + """Test audio generation with empty phonemes""" + with patch( + "api.src.routers.development.TTSService", return_value=mock_tts_service + ): + response = await async_client.post( + "/text/generate_from_phonemes", + json={"phonemes": "", "voice": "af_bella", "speed": 1.0}, + ) + + assert response.status_code == 400 + assert "Invalid request" in response.json()["detail"]["error"] diff --git a/api/tests/test_tts_implementations.py b/api/tests/test_tts_implementations.py index 9e92392..99b28bf 100644 --- a/api/tests/test_tts_implementations.py +++ b/api/tests/test_tts_implementations.py @@ -1,13 +1,16 @@ """Tests for TTS model implementations""" + import os +from unittest.mock import MagicMock, patch + +import numpy as np import torch import pytest -import numpy as np -from unittest.mock import patch, MagicMock -from api.src.services.tts_base import TTSBaseModel from api.src.services.tts_cpu import TTSCPUModel from api.src.services.tts_gpu import TTSGPUModel, length_to_mask +from api.src.services.tts_base import TTSBaseModel + # Base Model Tests def test_get_device_error(): @@ -16,14 +19,17 @@ def test_get_device_error(): with pytest.raises(RuntimeError, match="Model not initialized"): 
TTSBaseModel.get_device() + @pytest.mark.asyncio -@patch('torch.cuda.is_available') -@patch('os.path.exists') -@patch('os.path.join') -@patch('os.listdir') -@patch('torch.load') -@patch('torch.save') -async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available): +@patch("torch.cuda.is_available") +@patch("os.path.exists") +@patch("os.path.join") +@patch("os.listdir") +@patch("torch.load") +@patch("torch.save") +async def test_setup_cuda_available( + mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available +): """Test setup with CUDA available""" TTSBaseModel._device = None mock_cuda_available.return_value = True @@ -31,24 +37,32 @@ async def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_joi mock_load.return_value = torch.zeros(1) mock_listdir.return_value = ["voice1.pt", "voice2.pt"] mock_join.return_value = "/mocked/path" - - # Mock the abstract methods - TTSBaseModel.initialize = MagicMock(return_value=True) - TTSBaseModel.process_text = MagicMock(return_value=("dummy", [1,2,3])) - TTSBaseModel.generate_from_tokens = MagicMock(return_value=np.zeros(1000)) - + + # Create mock model + mock_model = MagicMock() + mock_model.bert = MagicMock() + mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3])) + mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000)) + + # Mock initialize to return our mock model + TTSBaseModel.initialize = MagicMock(return_value=mock_model) + TTSBaseModel._instance = mock_model + voice_count = await TTSBaseModel.setup() assert TTSBaseModel._device == "cuda" assert voice_count == 2 + @pytest.mark.asyncio -@patch('torch.cuda.is_available') -@patch('os.path.exists') -@patch('os.path.join') -@patch('os.listdir') -@patch('torch.load') -@patch('torch.save') -async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available): +@patch("torch.cuda.is_available") +@patch("os.path.exists") +@patch("os.path.join") +@patch("os.listdir") +@patch("torch.load") +@patch("torch.save") +async def test_setup_cuda_unavailable( + mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available +): """Test setup with CUDA unavailable""" TTSBaseModel._device = None mock_cuda_available.return_value = False @@ -56,91 +70,105 @@ async def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_j mock_load.return_value = torch.zeros(1) mock_listdir.return_value = ["voice1.pt", "voice2.pt"] mock_join.return_value = "/mocked/path" - - # Mock the abstract methods - TTSBaseModel.initialize = MagicMock(return_value=True) - TTSBaseModel.process_text = MagicMock(return_value=("dummy", [1,2,3])) - TTSBaseModel.generate_from_tokens = MagicMock(return_value=np.zeros(1000)) - + + # Create mock model + mock_model = MagicMock() + mock_model.bert = MagicMock() + mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3])) + mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000)) + + # Mock initialize to return our mock model + TTSBaseModel.initialize = MagicMock(return_value=mock_model) + TTSBaseModel._instance = mock_model + voice_count = await TTSBaseModel.setup() assert TTSBaseModel._device == "cpu" assert voice_count == 2 + # CPU Model Tests def test_cpu_initialize_missing_model(): """Test CPU initialize with missing model""" - with patch('os.path.exists', return_value=False): + TTSCPUModel._onnx_session = None # Reset the session + with patch("os.path.exists", 
return_value=False), patch( + "onnxruntime.InferenceSession", return_value=None + ): result = TTSCPUModel.initialize("dummy_dir") assert result is None + def test_cpu_generate_uninitialized(): """Test CPU generate methods with uninitialized model""" TTSCPUModel._onnx_session = None - + with pytest.raises(RuntimeError, match="ONNX model not initialized"): TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0) - + with pytest.raises(RuntimeError, match="ONNX model not initialized"): - TTSCPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0) + TTSCPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0) + def test_cpu_process_text(): """Test CPU process_text functionality""" - with patch('api.src.services.tts_cpu.phonemize') as mock_phonemize, \ - patch('api.src.services.tts_cpu.tokenize') as mock_tokenize: - + with patch("api.src.services.tts_cpu.phonemize") as mock_phonemize, patch( + "api.src.services.tts_cpu.tokenize" + ) as mock_tokenize: mock_phonemize.return_value = "test phonemes" mock_tokenize.return_value = [1, 2, 3] - + phonemes, tokens = TTSCPUModel.process_text("test", "en") assert phonemes == "test phonemes" assert tokens == [0, 1, 2, 3, 0] # Should add start/end tokens + # GPU Model Tests -@patch('torch.cuda.is_available') +@patch("torch.cuda.is_available") def test_gpu_initialize_cuda_unavailable(mock_cuda_available): """Test GPU initialize with CUDA unavailable""" mock_cuda_available.return_value = False TTSGPUModel._instance = None - + result = TTSGPUModel.initialize("dummy_dir", "dummy_path") assert result is None -@patch('api.src.services.tts_gpu.length_to_mask') + +@patch("api.src.services.tts_gpu.length_to_mask") def test_gpu_length_to_mask(mock_length_to_mask): """Test length_to_mask function""" # Setup mock return value - expected_mask = torch.tensor([ - [False, False, False, True, True], - [False, False, False, False, False] - ]) + expected_mask = torch.tensor( + [[False, False, False, True, True], [False, False, False, False, False]] + ) mock_length_to_mask.return_value = expected_mask - + # Call function with test input lengths = torch.tensor([3, 5]) mask = mock_length_to_mask(lengths) - + # Verify mock was called with correct input mock_length_to_mask.assert_called_once() assert torch.equal(mask, expected_mask) + def test_gpu_generate_uninitialized(): """Test GPU generate methods with uninitialized model""" TTSGPUModel._instance = None - + with pytest.raises(RuntimeError, match="GPU model not initialized"): TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0) - + with pytest.raises(RuntimeError, match="GPU model not initialized"): - TTSGPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0) + TTSGPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0) + def test_gpu_process_text(): """Test GPU process_text functionality""" - with patch('api.src.services.tts_gpu.phonemize') as mock_phonemize, \ - patch('api.src.services.tts_gpu.tokenize') as mock_tokenize: - + with patch("api.src.services.tts_gpu.phonemize") as mock_phonemize, patch( + "api.src.services.tts_gpu.tokenize" + ) as mock_tokenize: mock_phonemize.return_value = "test phonemes" mock_tokenize.return_value = [1, 2, 3] - + phonemes, tokens = TTSGPUModel.process_text("test", "en") assert phonemes == "test phonemes" assert tokens == [1, 2, 3] # GPU implementation doesn't add start/end tokens diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py index e3c3da9..0f613da 100644 --- a/api/tests/test_tts_service.py +++ b/api/tests/test_tts_service.py @@ 
-9,15 +9,30 @@ from onnxruntime import InferenceSession from api.src.core.config import settings -from api.src.services.tts_model import TTSModel -from api.src.services.tts_service import TTSService from api.src.services.tts_cpu import TTSCPUModel from api.src.services.tts_gpu import TTSGPUModel +from api.src.services.tts_model import TTSModel +from api.src.services.tts_service import TTSService @pytest.fixture -def tts_service(): +def tts_service(monkeypatch): """Create a TTSService instance for testing""" + # Mock TTSModel initialization + mock_model = MagicMock() + mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(48000)) + mock_model.process_text = MagicMock(return_value=("mock phonemes", [1, 2, 3])) + + # Set up model instance + monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model) + monkeypatch.setattr( + "api.src.services.tts_model.TTSModel.get_instance", + MagicMock(return_value=mock_model), + ) + monkeypatch.setattr( + "api.src.services.tts_model.TTSModel.get_device", MagicMock(return_value="cpu") + ) + return TTSService() @@ -41,13 +56,15 @@ def test_audio_to_bytes(tts_service, sample_audio): @pytest.mark.asyncio async def test_list_voices(tts_service): """Test listing available voices""" - # Override list_voices for testing - # # TODO: + + # Override list_voices for testing + # # TODO: # Whatever aiofiles does here pathing aiofiles vs aiofiles.os - # I am thoroughly confused by it. + # I am thoroughly confused by it. # Cheating the test as it seems to work in the real world (for now) async def mock_list_voices(): return ["voice1", "voice2"] + tts_service.list_voices = mock_list_voices voices = await tts_service.list_voices() @@ -59,10 +76,12 @@ async def mock_list_voices(): @pytest.mark.asyncio async def test_list_voices_error(tts_service): """Test error handling in list_voices""" + # Override list_voices for testing # TODO: See above. 
async def mock_list_voices(): return [] + tts_service.list_voices = mock_list_voices voices = await tts_service.list_voices() @@ -83,7 +102,7 @@ def mock_model_setup(cuda_available=False): # Set device based on CUDA availability TTSModel._device = "cuda" if cuda_available else "cpu" - + return 3 # Return voice count (including af.pt) @@ -91,7 +110,7 @@ def test_model_initialization_cuda(): """Test model initialization with CUDA""" # Simulate CUDA availability voice_count = mock_model_setup(cuda_available=True) - + assert TTSModel.get_device() == "cuda" assert voice_count == 3 # voice1.pt, voice2.pt, af.pt @@ -100,7 +119,7 @@ def test_model_initialization_cpu(): """Test model initialization with CPU""" # Simulate no CUDA availability voice_count = mock_model_setup(cuda_available=False) - + assert TTSModel.get_device() == "cpu" assert voice_count == 3 # voice1.pt, voice2.pt, af.pt @@ -111,6 +130,14 @@ def test_generate_audio_empty_text(tts_service): tts_service._generate_audio("", "af", 1.0) +@pytest.fixture(autouse=True) +def mock_settings(): + """Mock settings for all tests""" + with patch("api.src.services.text_processing.chunker.settings") as mock_settings: + mock_settings.max_chunk_size = 300 + yield mock_settings + + @patch("api.src.services.tts_model.TTSModel.get_instance") @patch("api.src.services.tts_model.TTSModel.get_device") @patch("os.path.exists") @@ -133,7 +160,10 @@ def test_generate_audio_phonemize_error( """Test handling phonemization error""" mock_normalize.return_value = "Test text" mock_phonemize.side_effect = Exception("Phonemization failed") - mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency + mock_instance.return_value = ( + mock_generate, + "cpu", + ) # Use the same mock for consistency mock_get_device.return_value = "cpu" mock_exists.return_value = True mock_torch_load.return_value = torch.zeros((10, 24000)) @@ -167,7 +197,10 @@ def test_generate_audio_error( mock_phonemize.return_value = "Test text" mock_tokenize.return_value = [1, 2] # Return integers instead of strings mock_generate.side_effect = Exception("Generation failed") - mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency + mock_instance.return_value = ( + mock_generate, + "cpu", + ) # Use the same mock for consistency mock_get_device.return_value = "cpu" mock_exists.return_value = True mock_torch_load.return_value = torch.zeros((10, 24000)) @@ -188,12 +221,11 @@ def test_save_audio(tts_service, sample_audio, tmp_path): async def test_combine_voices(tts_service): """Test combining multiple voices""" # Setup mocks for torch operations - with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \ - patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \ - patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \ - patch('torch.save'), \ - patch('os.path.exists', return_value=True): - + with patch("torch.load", return_value=torch.tensor([1.0, 2.0])), patch( + "torch.stack", return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + ), patch("torch.mean", return_value=torch.tensor([2.0, 3.0])), patch( + "torch.save" + ), patch("os.path.exists", return_value=True): # Test combining two voices result = await tts_service.combine_voices(["voice1", "voice2"]) diff --git a/docker-compose.cpu.yml b/docker-compose.cpu.yml index 5bccbe2..32c8710 100644 --- a/docker-compose.cpu.yml +++ b/docker-compose.cpu.yml @@ -1,3 +1,4 @@ +name: kokoro-fastapi services: model-fetcher: image: datamachines/git-lfs:latest @@ -6,6 +7,8 
@@ services: working_dir: /app/Kokoro-82M command: > sh -c " + mkdir -p /app/Kokoro-82M; + cd /app/Kokoro-82M; rm -f .git/index.lock; if [ -z \"$(ls -A .)\" ]; then git clone https://huggingface.co/hexgrad/Kokoro-82M . @@ -26,11 +29,11 @@ services: start_period: 1s kokoro-tts: - image: ghcr.io/remsky/kokoro-fastapi:latest-cpu - # Uncomment below to build from source instead of using the released image - # build: - # context: . - # dockerfile: Dockerfile.cpu + image: ghcr.io/remsky/kokoro-fastapi-cpu:v0.0.5post1 + # Uncomment below (and comment out above) to build from source instead of using the released image + build: + context: . + dockerfile: Dockerfile.cpu volumes: - ./api/src:/app/api/src - ./Kokoro-82M:/app/Kokoro-82M @@ -46,16 +49,23 @@ services: - ONNX_MEMORY_PATTERN=true - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8880/health"] + interval: 10s + timeout: 5s + retries: 30 + start_period: 30s depends_on: model-fetcher: condition: service_healthy + # Gradio UI service [Comment out everything below if you don't need it] gradio-ui: - image: ghcr.io/remsky/kokoro-fastapi:latest-ui - # Uncomment below to build from source instead of using the released image + image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1 + # Uncomment below (and comment out above) to build from source instead of using the released image # build: - # context: ./ui + # context: ./ui ports: - "7860:7860" volumes: @@ -63,3 +73,7 @@ services: - ./ui/app.py:/app/app.py # Mount app.py for hot reload environment: - GRADIO_WATCH=True # Enable hot reloading + - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered + depends_on: + kokoro-tts: + condition: service_healthy diff --git a/docker-compose.yml b/docker-compose.yml index 1958f72..9ffe144 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,3 +1,4 @@ +name: kokoro-fastapi services: model-fetcher: image: datamachines/git-lfs:latest @@ -32,10 +33,10 @@ services: start_period: 1s kokoro-tts: - # image: ghcr.io/remsky/kokoro-fastapi:latest - # Uncomment below to build from source instead of using the released image - build: - context: . + image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.0.5post1 + # Uncomment below (and comment out above) to build from source instead of using the released image + # build: + # context: . 
volumes: - ./api/src:/app/api/src - ./Kokoro-82M:/app/Kokoro-82M @@ -50,14 +51,20 @@ services: - driver: nvidia count: 1 capabilities: [gpu] + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8880/health"] + interval: 10s + timeout: 5s + retries: 30 + start_period: 30s depends_on: model-fetcher: condition: service_healthy # Gradio UI service [Comment out everything below if you don't need it] gradio-ui: - image: ghcr.io/remsky/kokoro-fastapi:latest-ui - # Uncomment below to build from source instead of using the released image + image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1 + # Uncomment below (and comment out above) to build from source instead of using the released image # build: # context: ./ui ports: @@ -67,3 +74,7 @@ services: - ./ui/app.py:/app/app.py # Mount app.py for hot reload environment: - GRADIO_WATCH=True # Enable hot reloading + - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered + depends_on: + kokoro-tts: + condition: service_healthy diff --git a/examples/assorted_checks/benchmarks/benchmark_first_token_stream_unified.py b/examples/assorted_checks/benchmarks/benchmark_first_token_stream_unified.py index 0b673ae..df1855c 100644 --- a/examples/assorted_checks/benchmarks/benchmark_first_token_stream_unified.py +++ b/examples/assorted_checks/benchmarks/benchmark_first_token_stream_unified.py @@ -166,7 +166,7 @@ def measure_first_token_openai( def main(): script_dir = os.path.dirname(os.path.abspath(__file__)) - prefix='cpu' + prefix = "cpu" # Run requests benchmark print("\n=== Running Direct Requests Benchmark ===") run_benchmark( @@ -176,7 +176,7 @@ def main(): output_plots_dir=os.path.join(script_dir, "output_plots"), suffix="_stream", plot_title_suffix="(Streaming)", - prefix=prefix + prefix=prefix, ) # Run OpenAI benchmark print("\n=== Running OpenAI Library Benchmark ===") @@ -187,7 +187,7 @@ def main(): output_plots_dir=os.path.join(script_dir, "output_plots"), suffix="_stream_openai", plot_title_suffix="(OpenAI Streaming)", - prefix=prefix + prefix=prefix, ) diff --git a/examples/assorted_checks/benchmarks/lib/stream_utils.py b/examples/assorted_checks/benchmarks/lib/stream_utils.py index 623b18a..d2decec 100644 --- a/examples/assorted_checks/benchmarks/lib/stream_utils.py +++ b/examples/assorted_checks/benchmarks/lib/stream_utils.py @@ -149,19 +149,19 @@ def run_benchmark( result["run_number"] = i + 1 # Handle time to first audio - first_chunk = result.get('time_to_first_chunk') + first_chunk = result.get("time_to_first_chunk") print( f"Time to First Audio: {f'{first_chunk:.3f}s' if first_chunk is not None else 'N/A'}" ) - + # Handle total time - total_time = result.get('total_time') + total_time = result.get("total_time") print( f"Time to Save Complete: {f'{total_time:.3f}s' if total_time is not None else 'N/A'}" ) - + # Handle audio length - audio_length = result.get('audio_length') + audio_length = result.get("audio_length") print( f"Audio length: {f'{audio_length:.3f}s' if audio_length is not None else 'N/A'}" ) @@ -191,10 +191,18 @@ def run_benchmark( # Print paths print("\nResults and plots saved to:") - print(f"- {os.path.join(output_data_dir, f'{prefix}first_token_benchmark{suffix}.json')}") - print(f"- {os.path.join(output_plots_dir, f'{prefix}first_token_latency{suffix}.png')}") - print(f"- {os.path.join(output_plots_dir, f'{prefix}total_time_latency{suffix}.png')}") - print(f"- {os.path.join(output_plots_dir, f'{prefix}first_token_timeline{suffix}.png')}") + print( + f"- {os.path.join(output_data_dir, 
f'{prefix}first_token_benchmark{suffix}.json')}" + ) + print( + f"- {os.path.join(output_plots_dir, f'{prefix}first_token_latency{suffix}.png')}" + ) + print( + f"- {os.path.join(output_plots_dir, f'{prefix}total_time_latency{suffix}.png')}" + ) + print( + f"- {os.path.join(output_plots_dir, f'{prefix}first_token_timeline{suffix}.png')}" + ) # Print silence check summary if silent_files: diff --git a/examples/openai_streaming_audio.py b/examples/openai_streaming_audio.py index dc16c55..35ef58f 100644 --- a/examples/openai_streaming_audio.py +++ b/examples/openai_streaming_audio.py @@ -1,6 +1,4 @@ - #!/usr/bin/env rye run python - import time from pathlib import Path @@ -18,25 +16,29 @@ def main() -> None: # Create text-to-speech audio file with openai.audio.speech.with_streaming_response.create( model="kokoro", - voice="af", + voice="af_bella", input="the quick brown fox jumped over the lazy dogs", ) as response: response.stream_to_file(speech_file_path) - def stream_to_speakers() -> None: import pyaudio - player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True) + player_stream = pyaudio.PyAudio().open( + format=pyaudio.paInt16, channels=1, rate=24000, output=True + ) start_time = time.time() with openai.audio.speech.with_streaming_response.create( model="kokoro", - voice="af_sky+af_bella+af_nicole+bm_george", + voice="af_bella", response_format="pcm", # similar to WAV, but without a header chunk at the start. - input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earth’s surface""", + input="""I see skies of blue and clouds of white + The bright blessed days, the dark sacred nights + And I think to myself + What a wonderful world""", ) as response: print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms") for chunk in response.iter_bytes(chunk_size=1024): diff --git a/examples/phoneme_examples/generate_phonemes.py b/examples/phoneme_examples/generate_phonemes.py new file mode 100644 index 0000000..6b261a8 --- /dev/null +++ b/examples/phoneme_examples/generate_phonemes.py @@ -0,0 +1,104 @@ +import json +from typing import Tuple, Optional +from pathlib import Path + +import requests + +# Get the directory this script is in +SCRIPT_DIR = Path(__file__).parent.absolute() + + +def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]: + """Get phonemes and tokens for input text. + + Args: + text: Input text to convert to phonemes + language: Language code (defaults to "a" for American English) + + Returns: + Tuple of (phonemes string, token list) + """ + # Create the request payload + payload = {"text": text, "language": language} + + # Make POST request to the phonemize endpoint + response = requests.post("http://localhost:8880/text/phonemize", json=payload) + + # Raise exception for error status codes + response.raise_for_status() + + # Parse the response + result = response.json() + return result["phonemes"], result["tokens"] + + +def generate_audio_from_phonemes( + phonemes: str, voice: str = "af_bella", speed: float = 1.0 +) -> Optional[bytes]: + """Generate audio from phonemes. 
+ + Args: + phonemes: Phoneme string to synthesize + voice: Voice ID to use (defaults to af_bella) + speed: Speed factor (defaults to 1.0) + + Returns: + WAV audio bytes if successful, None if failed + """ + # Create the request payload + payload = {"phonemes": phonemes, "voice": voice, "speed": speed} + + # Make POST request to generate audio + response = requests.post( + "http://localhost:8880/text/generate_from_phonemes", json=payload + ) + + # Raise exception for error status codes + response.raise_for_status() + + return response.content + + +def main(): + # Example texts to convert + examples = [ + "Hello world! Welcome to the phoneme generation system.", + "How are you today? I am doing reasonably well, thank you for asking", + """This is a test of the phoneme generation system. Do not be alarmed. + This is only a test. If this were a real phoneme emergency, ' + you would be instructed to a phoneme shelter in your area.""", + ] + + print("Generating phonemes and audio for example texts...\n") + + # Create output directory in same directory as script + output_dir = SCRIPT_DIR / "output" + output_dir.mkdir(exist_ok=True) + + for i, text in enumerate(examples): + print(f"{len(text)}: Input text: {text}") + try: + # Get phonemes + phonemes, tokens = get_phonemes(text) + print(f"{len(phonemes)} Phonemes: {phonemes}") + print(f"{len(tokens)} Tokens: {tokens}") + + # Generate audio from phonemes + print("Generating audio...") + audio_bytes = generate_audio_from_phonemes(phonemes) + + if audio_bytes: + # Save audio file + output_path = output_dir / f"example_{i+1}.wav" + with output_path.open("wb") as f: + f.write(audio_bytes) + print(f"Audio saved to: {output_path}") + + print() + + except requests.RequestException as e: + print(f"Error: {e}\n") + + +if __name__ == "__main__": + main() diff --git a/examples/stream_tts_playback.py b/examples/stream_tts_playback.py index 70999a8..d231fe7 100644 --- a/examples/stream_tts_playback.py +++ b/examples/stream_tts_playback.py @@ -1,17 +1,19 @@ #!/usr/bin/env python3 -import requests -import numpy as np -import sounddevice as sd -import time import os +import time import wave +import numpy as np +import requests +import sounddevice as sd + + def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"): """Stream TTS audio and play it back in real-time""" - + print("\nStarting TTS stream request...") start_time = time.time() - + # Initialize variables sample_rate = 24000 # Known sample rate for Kokoro audio_started = False @@ -19,17 +21,17 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"): total_bytes = 0 first_chunk_time = None all_audio_data = bytearray() # Raw PCM audio data - + # Start sounddevice stream with buffer stream = sd.OutputStream( samplerate=sample_rate, channels=1, dtype=np.int16, blocksize=1024, # Buffer size in samples - latency='low' # Request low latency + latency="low", # Request low latency ) stream.start() - + # Make streaming request to API try: response = requests.post( @@ -39,39 +41,45 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"): "input": text, "voice": voice, "response_format": "pcm", - "stream": True + "stream": True, }, stream=True, - timeout=1800 + timeout=1800, ) response.raise_for_status() print(f"Request started successfully after {time.time() - start_time:.2f}s") - + # Process streaming response with smaller chunks for lower latency - for chunk in response.iter_content(chunk_size=512): # 512 bytes = 256 samples at 16-bit + for chunk 
in response.iter_content( + chunk_size=512 + ): # 512 bytes = 256 samples at 16-bit if chunk: chunk_count += 1 total_bytes += len(chunk) - + # Handle first chunk if not audio_started: first_chunk_time = time.time() - print(f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s") + print( + f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s" + ) print(f"First chunk size: {len(chunk)} bytes") audio_started = True - + # Convert bytes to numpy array and play audio_chunk = np.frombuffer(chunk, dtype=np.int16) stream.write(audio_chunk) - + # Accumulate raw audio data all_audio_data.extend(chunk) - + # Log progress every 10 chunks - if chunk_count % 10 == 0: + if chunk_count % 100 == 0: elapsed = time.time() - start_time - print(f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed") - + print( + f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed" + ) + # Final stats total_time = time.time() - start_time print(f"\nStream complete:") @@ -79,21 +87,21 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"): print(f"Total data: {total_bytes/1024:.1f}KB") print(f"Total time: {total_time:.2f}s") print(f"Average speed: {(total_bytes/1024)/total_time:.1f}KB/s") - + # Save as WAV file if output_file: print(f"\nWriting audio to {output_file}") - with wave.open(output_file, 'wb') as wav_file: + with wave.open(output_file, "wb") as wav_file: wav_file.setnchannels(1) # Mono wav_file.setsampwidth(2) # 2 bytes per sample (16-bit) wav_file.setframerate(sample_rate) wav_file.writeframes(all_audio_data) print(f"Saved {len(all_audio_data)} bytes of audio data") - + # Clean up stream.stop() stream.close() - + except requests.exceptions.ConnectionError as e: print(f"Connection error - Is the server running? 
Error: {str(e)}") stream.stop() @@ -103,23 +111,27 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"): stream.stop() stream.close() + def main(): # Load sample text from HG Wells script_dir = os.path.dirname(os.path.abspath(__file__)) - wells_path = os.path.join(script_dir, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt") + wells_path = os.path.join( + script_dir, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt" + ) output_path = os.path.join(script_dir, "output.wav") - + with open(wells_path, "r", encoding="utf-8") as f: full_text = f.read() # Take first few paragraphs text = " ".join(full_text.split("\n\n")[:2]) - + print("\nStarting TTS stream playback...") print(f"Text length: {len(text)} characters") print("\nFirst 100 characters:") print(text[:100] + "...") - + play_streaming_tts(text, output_file=output_path) + if __name__ == "__main__": main() diff --git a/requirements.txt b/requirements.txt index ec659ec..cc3e135 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ sqlalchemy==2.0.27 transformers==4.47.1 numpy==2.2.1 scipy==1.14.1 +onnxruntime==1.20.1 # Audio processing soundfile==0.13.0 diff --git a/ui/app.py b/ui/app.py index 96aae35..8920a7a 100644 --- a/ui/app.py +++ b/ui/app.py @@ -1,3 +1,9 @@ +import warnings + +# Filter out Gradio Dropdown warnings about values not in choices +#TODO: Warning continues to be displayed, though it isn't breaking anything +warnings.filterwarnings('ignore', category=UserWarning, module='gradio.components.dropdown') + from lib.interface import create_interface if __name__ == "__main__": diff --git a/ui/lib/components/model.py b/ui/lib/components/model.py index 444d0f8..2046b32 100644 --- a/ui/lib/components/model.py +++ b/ui/lib/components/model.py @@ -21,8 +21,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di voice_input = gr.Dropdown( choices=voice_ids, label="Voice", - value=voice_ids[0] if voice_ids else None, + value=voice_ids[0] if voice_ids else None, # Set default value to first item if available interactive=True, + allow_custom_value=True, # Allow temporary values during updates ) format_input = gr.Dropdown( choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3" diff --git a/ui/lib/components/output.py b/ui/lib/components/output.py index e25601d..640deba 100644 --- a/ui/lib/components/output.py +++ b/ui/lib/components/output.py @@ -12,12 +12,16 @@ def create_output_column() -> Tuple[gr.Column, dict]: audio_output = gr.Audio(label="Generated Speech", type="filepath") gr.Markdown("### Generated Files") + # Initialize dropdown with empty choices first output_files = gr.Dropdown( label="Previous Outputs", - choices=files.list_output_files(), + choices=[], value=None, - allow_custom_value=False, + allow_custom_value=True, + interactive=True, ) + # Then update choices after component creation + output_files.choices = files.list_output_files() play_btn = gr.Button("▶️ Play Selected", size="sm") diff --git a/ui/lib/files.py b/ui/lib/files.py index 867f4f4..5495ea9 100644 --- a/ui/lib/files.py +++ b/ui/lib/files.py @@ -12,9 +12,9 @@ def list_input_files() -> List[str]: def list_output_files() -> List[str]: """List all output audio files.""" + # Just return filenames since paths will be different inside/outside container return [ - os.path.join(OUTPUTS_DIR, f) - for f in os.listdir(OUTPUTS_DIR) + f for f in os.listdir(OUTPUTS_DIR) if any(f.endswith(ext) for ext in AUDIO_FORMATS) ] diff --git a/ui/lib/handlers.py b/ui/lib/handlers.py 
index eba6cda..30062a0 100644 --- a/ui/lib/handlers.py +++ b/ui/lib/handlers.py @@ -1,6 +1,5 @@ import os import shutil - import gradio as gr from . import api, files @@ -97,11 +96,12 @@ def generate_from_text(text, voice, format, speed): gr.Warning("Failed to generate speech. Please try again.") return [None, gr.update(choices=files.list_output_files())] + # Update list and select the newly generated file + output_files = files.list_output_files() + last_file = output_files[-1] if output_files else None return [ result, - gr.update( - choices=files.list_output_files(), value=os.path.basename(result) - ), + gr.update(choices=output_files, value=last_file), ] def generate_from_file(selected_file, voice, format, speed): @@ -121,16 +121,19 @@ def generate_from_file(selected_file, voice, format, speed): gr.Warning("Failed to generate speech. Please try again.") return [None, gr.update(choices=files.list_output_files())] + # Update list and select the newly generated file + output_files = files.list_output_files() + last_file = output_files[-1] if output_files else None return [ result, - gr.update( - choices=files.list_output_files(), value=os.path.basename(result) - ), + gr.update(choices=output_files, value=last_file), ] - def play_selected(file_path): - if file_path and os.path.exists(file_path): - return gr.update(value=file_path, visible=True) + def play_selected(filename): + if filename: + file_path = os.path.join(files.OUTPUTS_DIR, filename) + if os.path.exists(file_path): + return gr.update(value=file_path, visible=True) return gr.update(visible=False) def clear_files(voice, format, speed): diff --git a/ui/tests/test_components.py b/ui/tests/test_components.py index b125cb7..d9576c0 100644 --- a/ui/tests/test_components.py +++ b/ui/tests/test_components.py @@ -54,7 +54,7 @@ def test_model_column_default_values(): def test_model_column_no_voices(): """Test model column creation with no voice IDs""" - _, components = create_model_column() + _, components = create_model_column([]) assert components["voice"].choices == [] assert components["voice"].value is None @@ -96,7 +96,7 @@ def test_output_column_configuration(): # Test output files dropdown assert components["output_files"].label == "Previous Outputs" - assert components["output_files"].allow_custom_value is False + assert components["output_files"].allow_custom_value is True # Test play button assert components["play_btn"].value == "▶️ Play Selected"
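Usage note for the new `examples/phoneme_examples/generate_phonemes.py` script added above: the two endpoints it wraps can be exercised in a handful of lines. The condensed sketch below assumes the same locally running server and the payload/response shapes used by that script (`phonemes`/`tokens` keys from `/text/phonemize`, WAV bytes in the body of `/text/generate_from_phonemes`); the output filename is illustrative.

```python
import requests

BASE_URL = "http://localhost:8880"  # same local server the example script targets

# Step 1: convert text to phonemes and tokens ("a" = American English, per the example script)
phonemize = requests.post(
    f"{BASE_URL}/text/phonemize",
    json={"text": "Hello world!", "language": "a"},
)
phonemize.raise_for_status()
result = phonemize.json()
print(result["phonemes"], result["tokens"])

# Step 2: synthesize audio directly from the phoneme string
audio = requests.post(
    f"{BASE_URL}/text/generate_from_phonemes",
    json={"phonemes": result["phonemes"], "voice": "af_bella", "speed": 1.0},
)
audio.raise_for_status()

# Save the returned WAV bytes (filename is illustrative)
with open("phoneme_demo.wav", "wb") as f:
    f.write(audio.content)
```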