Skip to content

Commit

Permalink
Add model listing and retrieval endpoints with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
remsky committed Feb 10, 2025
1 parent d73ed87 commit 8ed2f2a
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
</p>

# <sub><sub>_`FastKoko`_ </sub></sub>
[![Tests](https://img.shields.io/badge/tests-66%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-54%25-tan)]()
[![Tests](https://img.shields.io/badge/tests-69%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-51%25-tan)]()
[![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)

[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
Expand Down
93 changes: 93 additions & 0 deletions api/src/routers/openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,99 @@ async def download_audio_file(filename: str):
)


@router.get("/models")
async def list_models():
    """List all available models"""
    try:
        # Fixed catalog of models exposed in OpenAI-compatible format.
        # All entries share the same created timestamp and owner.
        catalog = [
            {
                "id": model_id,
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            }
            for model_id in ("tts-1", "tts-1-hd", "kokoro")
        ]

        return {"object": "list", "data": catalog}
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to retrieve model list",
                "type": "server_error",
            },
        )

@router.get("/models/{model}")
async def retrieve_model(model: str):
    """Retrieve a specific model"""
    try:
        # Catalog of served models, keyed by model id; all entries share
        # the same created timestamp and owner.
        known_models = {
            name: {
                "id": name,
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            }
            for name in ("tts-1", "tts-1-hd", "kokoro")
        }

        model_info = known_models.get(model)

        # Unknown model id -> OpenAI-style 404 error payload
        if model_info is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": "model_not_found",
                    "message": f"Model '{model}' not found",
                    "type": "invalid_request_error"
                }
            )

        return model_info
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Error retrieving model {model}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to retrieve model information",
                "type": "server_error",
            },
        )

@router.get("/audio/voices")
async def list_voices():
"""List all available voices for text-to-speech"""
Expand Down
43 changes: 43 additions & 0 deletions api/tests/test_openai_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,49 @@ def test_load_openai_mappings_file_not_found():
assert mappings == {"models": {}, "voices": {}}


def test_list_models(mock_openai_mappings):
    """Test listing available models endpoint"""
    response = client.get("/v1/models")
    assert response.status_code == 200
    payload = response.json()
    assert payload["object"] == "list"
    assert isinstance(payload["data"], list)
    assert len(payload["data"]) == 3  # tts-1, tts-1-hd, and kokoro

    # Every expected model id must appear in the listing
    listed_ids = {entry["id"] for entry in payload["data"]}
    for expected_id in ("tts-1", "tts-1-hd", "kokoro"):
        assert expected_id in listed_ids

    # Each entry must follow the OpenAI model-object shape
    for entry in payload["data"]:
        assert entry["object"] == "model"
        assert "created" in entry
        assert entry["owned_by"] == "kokoro"


def test_retrieve_model(mock_openai_mappings):
    """Test retrieving a specific model endpoint"""
    # Known model id returns a single OpenAI-style model object
    ok_response = client.get("/v1/models/tts-1")
    assert ok_response.status_code == 200
    body = ok_response.json()
    assert body["id"] == "tts-1"
    assert body["object"] == "model"
    assert body["owned_by"] == "kokoro"
    assert "created" in body

    # Unknown model id yields a 404 with the structured error payload
    missing_response = client.get("/v1/models/nonexistent-model")
    assert missing_response.status_code == 404
    detail = missing_response.json()["detail"]
    assert detail["error"] == "model_not_found"
    assert "not found" in detail["message"]
    assert detail["type"] == "invalid_request_error"



@pytest.mark.asyncio
async def test_get_tts_service_initialization():
"""Test TTSService initialization"""
Expand Down
12 changes: 12 additions & 0 deletions debug.http
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,16 @@ Accept: application/json
# Shows active ONNX sessions, CUDA stream usage, and session ages
# Useful for debugging resource exhaustion issues
GET http://localhost:8880/debug/session_pools
Accept: application/json

### List Available Models
# Returns list of all available models in OpenAI format
# Response includes tts-1, tts-1-hd, and kokoro models
GET http://localhost:8880/v1/models
Accept: application/json

### Get Specific Model
# Returns a single model object in OpenAI format
# Valid names: tts-1, tts-1-hd, kokoro; unknown names return 404
GET http://localhost:8880/v1/models/tts-1
Accept: application/json

0 comments on commit 8ed2f2a

Please sign in to comment.