config: remove whisper.model option (#312 and #313)
Fedir Zadniprovskyi committed Feb 14, 2025
1 parent 6e936ef commit 4afe5f9
Showing 11 changed files with 27 additions and 63 deletions.
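
The practical effect of this commit: there is no longer a server-side default transcription model, so every request has to name one. Below is a minimal sketch of the new calling convention (not part of the commit), assuming a speaches server on localhost:8000 and a model ID taken from the tests further down (`Systran/faster-whisper-tiny.en`):

import httpx

# Rough sketch: transcribe a local file against a running speaches server.
# The `model` form field is now required, since the server no longer falls
# back to a configured default.
with open("audio.wav", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": ("audio.wav", f, "audio/wav")},
        data={"model": "Systran/faster-whisper-tiny.en", "response_format": "text"},
        timeout=None,  # long files can take a while to transcribe
    )
resp.raise_for_status()
print(resp.text)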
1 change: 0 additions & 1 deletion Dockerfile
@@ -35,7 +35,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# PermissionError: [Errno 13] Permission denied: '/home/ubuntu/.cache/huggingface/hub'
# This error occurs because the volume is mounted as root and the `ubuntu` user doesn't have permission to write to it. Pre-creating the directory solves this issue.
RUN mkdir -p $HOME/.cache/huggingface/hub
ENV WHISPER__MODEL=Systran/faster-whisper-large-v3
ENV UVICORN_HOST=0.0.0.0
ENV UVICORN_PORT=8000
ENV PATH="$HOME/speaches/.venv/bin:$PATH"
2 changes: 0 additions & 2 deletions compose.cpu.yaml
@@ -9,8 +9,6 @@ services:
build:
args:
BASE_IMAGE: ubuntu:24.04
environment:
- WHISPER__MODEL=Systran/faster-whisper-small
volumes:
- hf-hub-cache:/home/ubuntu/.cache/huggingface/hub
volumes:
2 changes: 0 additions & 2 deletions compose.cuda.yaml
@@ -10,8 +10,6 @@ services:
build:
args:
BASE_IMAGE: nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04
environment:
- WHISPER__MODEL=Systran/faster-whisper-large-v3
volumes:
- hf-hub-cache:/home/ubuntu/.cache/huggingface/hub
deploy:
8 changes: 4 additions & 4 deletions examples/youtube/script.sh
@@ -3,18 +3,18 @@
set -e

# NOTE: do not use any distil-* model other than the large ones as they don't work on long audio files for some reason.
export WHISPER__MODEL=Systran/faster-distil-whisper-large-v3 # or Systran/faster-whisper-tiny.en if you are running on a CPU for a faster inference.
export TRANSCRIPTION_MODEL=Systran/faster-distil-whisper-large-v3 # or Systran/faster-whisper-tiny.en if you are running on a CPU for faster inference.

# Ensure you have `speaches` running. If this is your first time running it, expect to wait up to a minute for the model to be downloaded and loaded into memory. You can run `curl localhost:8000/health` to check if the server is ready or watch the logs with `docker logs -f <container_id>`.
docker run --detach --gpus=all --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub --env WHISPER__MODEL=$WHISPER__MODEL ghcr.io/speaches-ai/speaches:latest-cuda
docker run --detach --gpus=all --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub ghcr.io/speaches-ai/speaches:latest-cuda
# or you can run it on a CPU
# docker run --detach --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub --env WHISPER__MODEL=$WHISPER__MODEL ghcr.io/speaches-ai/speaches:latest-cpu
# docker run --detach --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub ghcr.io/speaches-ai/speaches:latest-cpu

# Download the audio from a YouTube video. In this example I'm downloading "The Evolution of the Operating System" by the Asionometry YouTube channel. I highly recommend checking this channel out, the guy produces very high quality content. If you don't have `youtube-dl`, you'll have to install it. https://github.com/ytdl-org/youtube-dl
youtube-dl --extract-audio --audio-format mp3 -o the-evolution-of-the-operating-system.mp3 'https://www.youtube.com/watch?v=1lG7lFLXBIs'

# Make a request to the API to transcribe the audio. The response will be streamed to the terminal and saved to a file. The video is 30 minutes long, so it might take a while to transcribe, especially if you are running this on a CPU. `Systran/faster-distil-whisper-large-v3` takes ~30 seconds on Nvidia L4. `Systran/faster-whisper-tiny.en` takes ~1 minute on Ryzen 7 7700X. The .txt file in the example was transcribed using `Systran/faster-distil-whisper-large-v3`.
curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt
curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "model=$TRANSCRIPTION_MODEL" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt

# Here I'm using `aichat` which is a CLI LLM client. You could use any other client that supports attaching/uploading files. https://github.com/sigoden/aichat
aichat -m openai:gpt-4o -f the-evolution-of-the-operating-system.txt 'What companies are mentioned in the following YouTube video transcription? Respond with just a list of names'
10 changes: 1 addition & 9 deletions src/speaches/config.py
@@ -21,14 +21,6 @@
class WhisperConfig(BaseModel):
"""See https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/transcribe.py#L599."""

model: str = Field(default="Systran/faster-whisper-small")
"""
Default HuggingFace model to use for transcription. Note, the model must support being ran using CTranslate2.
This model will be used if no model is specified in the request.
Models created by authors of `faster-whisper` can be found at https://huggingface.co/Systran
You can find other supported models at https://huggingface.co/models?p=2&sort=trending&search=ctranslate2 and https://huggingface.co/models?sort=trending&search=ct2
"""
inference_device: Device = "auto"
device_index: int | list[int] = 0
compute_type: Quantization = "default" # TODO: should this even be a configuration option?
@@ -52,7 +44,7 @@ class Config(BaseSettings):
Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
To populate nested fields, the environment variable should be prefixed with the nested field name and an underscore. For example,
the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER__MODEL`(note the double underscore) to `whisper.model`, to set quantization to int8, use `WHISPER__COMPUTE_TYPE=int8`, etc.
the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER__INFERENCE_DEVICE` (note the double underscore) to `whisper.inference_device`; to set quantization to int8, use `WHISPER__COMPUTE_TYPE=int8`, etc.
"""

model_config = SettingsConfigDict(env_nested_delimiter="__")
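
The updated docstring describes the env-var mapping in prose; a small sketch of what it means in practice, assuming `Config` can be instantiated directly with defaults (the tests below construct it explicitly):

import os

# Rough sketch of the nested env-var mapping (env_nested_delimiter="__"):
# a double underscore maps WHISPER__COMPUTE_TYPE to config.whisper.compute_type.
os.environ["LOG_LEVEL"] = "debug"
os.environ["WHISPER__INFERENCE_DEVICE"] = "cpu"
os.environ["WHISPER__COMPUTE_TYPE"] = "int8"

from speaches.config import Config

config = Config()
print(config.log_level)                 # "debug"
print(config.whisper.inference_device)  # "cpu"
print(config.whisper.compute_type)      # "int8"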
2 changes: 1 addition & 1 deletion src/speaches/realtime/session.py
@@ -36,7 +36,7 @@
voice="alloy",
input_audio_format="pcm16",
output_audio_format="pcm16",
input_audio_transcription=InputAudioTranscription(model="Systran/faster-whisper-small.en"), # changed
input_audio_transcription=InputAudioTranscription(model="Systran/faster-whisper-small"), # changed
turn_detection=DEFAULT_TURN_DETECTION,
temperature=0.8,
tools=[],
25 changes: 4 additions & 21 deletions src/speaches/routers/stt.py
@@ -11,7 +11,7 @@
)
from fastapi.responses import StreamingResponse
from faster_whisper.transcribe import BatchedInferencePipeline, TranscriptionInfo
from pydantic import AfterValidator, Field
from pydantic import Field

from speaches.api_types import (
DEFAULT_TIMESTAMP_GRANULARITIES,
@@ -21,7 +21,7 @@
TimestampGranularities,
TranscriptionSegment,
)
from speaches.dependencies import AudioFileDependency, ConfigDependency, ModelManagerDependency, get_config
from speaches.dependencies import AudioFileDependency, ConfigDependency, ModelManagerDependency
from speaches.text_utils import segments_to_srt, segments_to_text, segments_to_vtt

logger = logging.getLogger(__name__)
@@ -91,21 +91,8 @@ def segment_responses() -> Generator[str, None, None]:
return StreamingResponse(segment_responses(), media_type="text/event-stream")


def handle_default_openai_model(model_name: str) -> str:
"""Exists because some callers may not be able override the default("whisper-1") model name.
For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
"""
config = get_config() # HACK
if model_name == "whisper-1":
logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
return config.whisper.model
return model_name


ModelName = Annotated[
str,
AfterValidator(handle_default_openai_model),
Field(
description="The ID of the model. You can get a list of available models by calling `/v1/models`.",
examples=[
@@ -124,15 +111,13 @@ def translate_file(
config: ConfigDependency,
model_manager: ModelManagerDependency,
audio: AudioFileDependency,
model: Annotated[ModelName | None, Form()] = None,
model: Annotated[ModelName, Form()],
prompt: Annotated[str | None, Form()] = None,
response_format: Annotated[ResponseFormat, Form()] = DEFAULT_RESPONSE_FORMAT,
temperature: Annotated[float, Form()] = 0.0,
stream: Annotated[bool, Form()] = False,
vad_filter: Annotated[bool, Form()] = False,
) -> Response | StreamingResponse:
if model is None:
model = config.whisper.model
with model_manager.load_model(model) as whisper:
whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
segments, transcription_info = whisper_model.transcribe(
@@ -173,7 +158,7 @@ def transcribe_file(
model_manager: ModelManagerDependency,
request: Request,
audio: AudioFileDependency,
model: Annotated[ModelName | None, Form()] = None,
model: Annotated[ModelName, Form()],
language: Annotated[str | None, Form()] = None,
prompt: Annotated[str | None, Form()] = None,
response_format: Annotated[ResponseFormat, Form()] = DEFAULT_RESPONSE_FORMAT,
@@ -187,8 +172,6 @@ def transcribe_file(
hotwords: Annotated[str | None, Form()] = None,
vad_filter: Annotated[bool, Form()] = False,
) -> Response | StreamingResponse:
if model is None:
model = config.whisper.model
timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != "verbose_json":
logger.warning(
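
With both the server-side default and the `whisper-1` remapping gone, callers that used to rely on OpenAI's default model name now need a real model ID. The field description above points at `/v1/models` for discovery; a hedged sketch of using it (the OpenAI-style response shape is an assumption, based on the UI code below that reads the same endpoint through an OpenAI client):

import httpx

# Rough sketch: list the models the server knows about and pick a valid ID
# to pass in the `model` form field of /v1/audio/transcriptions.
resp = httpx.get("http://localhost:8000/v1/models")
resp.raise_for_status()
model_ids = [m["id"] for m in resp.json()["data"]]  # assumed {"data": [{"id": ...}, ...]} shape
print(model_ids)  # e.g. ["Systran/faster-whisper-tiny.en", ...]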
11 changes: 3 additions & 8 deletions src/speaches/ui/tabs/stt.py
@@ -20,15 +20,10 @@ async def update_whisper_model_dropdown(request: gr.Request) -> gr.Dropdown:
openai_client = openai_client_from_gradio_req(request, config)
models = (await openai_client.models.list()).data
model_names: list[str] = [model.id for model in models]
assert config.whisper.model in model_names
recommended_models = {model for model in model_names if model.startswith("Systran")}
other_models = [model for model in model_names if model not in recommended_models]
model_names = list(recommended_models) + other_models
return gr.Dropdown(
choices=model_names,
label="Model",
value=config.whisper.model,
)
return gr.Dropdown(choices=model_names, label="Model", value="Systran/faster-whisper-small")

async def audio_task(
http_client: httpx.AsyncClient, file_path: str, endpoint: str, temperature: float, model: str
@@ -84,9 +79,9 @@ async def whisper_handler(
with gr.Tab(label="Speech-to-Text") as tab:
audio = gr.Audio(type="filepath")
whisper_model_dropdown = gr.Dropdown(
choices=[config.whisper.model],
choices=["Systran/faster-whisper-small"], # TODO: does this need to be non-empty
label="Model",
value=config.whisper.model,
value="Systran/faster-whisper-small",
)
task_dropdown = gr.Dropdown(
choices=[task.value for task in Task],
4 changes: 1 addition & 3 deletions tests/conftest.py
@@ -19,9 +19,8 @@

DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
OPENAI_BASE_URL = "https://api.openai.com/v1"
DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
# TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests
DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
DEFAULT_WHISPER_CONFIG = WhisperConfig(ttl=0)
DEFAULT_CONFIG = Config(
whisper=DEFAULT_WHISPER_CONFIG,
# disable the UI as it slightly increases the app startup time due to the imports it's doing
@@ -43,7 +42,6 @@ def pytest_configure() -> None:
# NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
@pytest.fixture
def client() -> Generator[TestClient, None, None]:
os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
with TestClient(create_app()) as client:
yield client

14 changes: 7 additions & 7 deletions tests/model_manager_test.py
@@ -4,15 +4,15 @@
import pytest

from speaches.config import Config, WhisperConfig
from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
from tests.conftest import AclientFactory

MODEL = DEFAULT_WHISPER_MODEL # just to make the test more readable
MODEL = "Systran/faster-whisper-tiny.en"


@pytest.mark.asyncio
async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
ttl = 5
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
async with aclient_factory(config) as aclient:
res = (await aclient.get("/api/ps")).json()
assert len(res["models"]) == 0
@@ -27,7 +27,7 @@ async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None
@pytest.mark.asyncio
async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
ttl = 5
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
async with aclient_factory(config) as aclient:
await aclient.post(f"/api/ps/{MODEL}")
res = (await aclient.get("/api/ps")).json()
@@ -69,7 +69,7 @@ async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
@pytest.mark.asyncio
async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
ttl = 0
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
async with aclient_factory(config) as aclient:
async with await anyio.open_file("audio.wav", "rb") as f:
data = await f.read()
@@ -91,7 +91,7 @@ async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory)
@pytest.mark.asyncio
async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
ttl = -1
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
async with aclient_factory(config) as aclient:
res = await aclient.post(f"/api/ps/{MODEL}")
assert res.status_code == 201
@@ -104,7 +104,7 @@ async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> No
@pytest.mark.asyncio
async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
ttl = 0
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
async with aclient_factory(config) as aclient:
async with await anyio.open_file("audio.wav", "rb") as f:
data = await f.read()
11 changes: 6 additions & 5 deletions tests/sse_test.py
@@ -14,6 +14,7 @@
CreateTranscriptionResponseVerboseJson,
)

MODEL = "Systran/faster-whisper-tiny.en"
FILE_PATHS = ["audio.wav"] # HACK
ENDPOINTS = [
"/v1/audio/transcriptions",
@@ -32,7 +33,7 @@ async def test_streaming_transcription_text(aclient: AsyncClient, file_path: str
data = await f.read()
kwargs = {
"files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
"data": {"response_format": "text", "stream": True},
"data": {"model": MODEL, "response_format": "text", "stream": True},
}
async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
async for event in event_source.aiter_sse():
@@ -48,7 +49,7 @@ async def test_streaming_transcription_json(aclient: AsyncClient, file_path: str
data = await f.read()
kwargs = {
"files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
"data": {"response_format": "json", "stream": True},
"data": {"model": MODEL, "response_format": "json", "stream": True},
}
async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
async for event in event_source.aiter_sse():
@@ -63,7 +64,7 @@ async def test_streaming_transcription_verbose_json(aclient: AsyncClient, file_p
data = await f.read()
kwargs = {
"files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
"data": {"response_format": "verbose_json", "stream": True},
"data": {"model": MODEL, "response_format": "verbose_json", "stream": True},
}
async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
async for event in event_source.aiter_sse():
@@ -76,7 +77,7 @@ async def test_transcription_vtt(aclient: AsyncClient) -> None:
data = await f.read()
kwargs = {
"files": {"file": ("audio.wav", data, "audio/wav")},
"data": {"response_format": "vtt", "stream": False},
"data": {"model": MODEL, "response_format": "vtt", "stream": False},
}
response = await aclient.post("/v1/audio/transcriptions", **kwargs)
assert response.status_code == 200
@@ -94,7 +95,7 @@ async def test_transcription_srt(aclient: AsyncClient) -> None:
data = await f.read()
kwargs = {
"files": {"file": ("audio.wav", data, "audio/wav")},
"data": {"response_format": "srt", "stream": False},
"data": {"model": MODEL, "response_format": "srt", "stream": False},
}
response = await aclient.post("/v1/audio/transcriptions", **kwargs)
assert response.status_code == 200