diff --git a/Dockerfile b/Dockerfile
index 831bd307..de2a0585 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -35,7 +35,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 # PermissionError: [Errno 13] Permission denied: '/home/ubuntu/.cache/huggingface/hub'
 # This error occurs because the volume is mounted as root and the `ubuntu` user doesn't have permission to write to it. Pre-creating the directory solves this issue.
 RUN mkdir -p $HOME/.cache/huggingface/hub
-ENV WHISPER__MODEL=Systran/faster-whisper-large-v3
 ENV UVICORN_HOST=0.0.0.0
 ENV UVICORN_PORT=8000
 ENV PATH="$HOME/speaches/.venv/bin:$PATH"
diff --git a/compose.cpu.yaml b/compose.cpu.yaml
index 25beaedd..bee54c5b 100644
--- a/compose.cpu.yaml
+++ b/compose.cpu.yaml
@@ -9,8 +9,6 @@ services:
     build:
       args:
         BASE_IMAGE: ubuntu:24.04
-    environment:
-      - WHISPER__MODEL=Systran/faster-whisper-small
     volumes:
       - hf-hub-cache:/home/ubuntu/.cache/huggingface/hub
 volumes:
diff --git a/compose.cuda.yaml b/compose.cuda.yaml
index 0d3b8482..e8b09ef4 100644
--- a/compose.cuda.yaml
+++ b/compose.cuda.yaml
@@ -10,8 +10,6 @@ services:
     build:
       args:
         BASE_IMAGE: nvidia/cuda:12.6.3-cudnn-runtime-ubuntu24.04
-    environment:
-      - WHISPER__MODEL=Systran/faster-whisper-large-v3
     volumes:
       - hf-hub-cache:/home/ubuntu/.cache/huggingface/hub
     deploy:
diff --git a/examples/youtube/script.sh b/examples/youtube/script.sh
index df0908ff..1ce01a16 100755
--- a/examples/youtube/script.sh
+++ b/examples/youtube/script.sh
@@ -3,18 +3,18 @@
 set -e
 
 # NOTE: do not use any distil-* model other than the large ones as they don't work on long audio files for some reason.
-export WHISPER__MODEL=Systran/faster-distil-whisper-large-v3 # or Systran/faster-whisper-tiny.en if you are running on a CPU for a faster inference.
+export TRANSCRIPTION_MODEL=Systran/faster-distil-whisper-large-v3 # or Systran/faster-whisper-tiny.en if you are running on a CPU for a faster inference.
 
 # Ensure you have `speaches` running. If this is your first time running it, expect to wait up to a minute for the model to be downloaded and loaded into memory. You can run `curl localhost:8000/health` to check if the server is ready, or watch the logs with `docker logs -f <container_name>`.
-docker run --detach --gpus=all --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub --env WHISPER__MODEL=$WHISPER__MODEL ghcr.io/speaches-ai/speaches:latest-cuda
+docker run --detach --gpus=all --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub ghcr.io/speaches-ai/speaches:latest-cuda
 # or you can run it on a CPU
-# docker run --detach --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub --env WHISPER__MODEL=$WHISPER__MODEL ghcr.io/speaches-ai/speaches:latest-cpu
+# docker run --detach --publish 8000:8000 --volume hf-hub-cache:/home/ubuntu/.cache/huggingface/hub ghcr.io/speaches-ai/speaches:latest-cpu
 
 # Download the audio from a YouTube video. In this example I'm downloading "The Evolution of the Operating System" by the Asianometry YouTube channel. I highly recommend checking this channel out; the guy produces very high-quality content. If you don't have `youtube-dl`, you'll have to install it. https://github.com/ytdl-org/youtube-dl
 youtube-dl --extract-audio --audio-format mp3 -o the-evolution-of-the-operating-system.mp3 'https://www.youtube.com/watch?v=1lG7lFLXBIs'
 
 # Make a request to the API to transcribe the audio. The response will be streamed to the terminal and saved to a file. The video is 30 minutes long, so it might take a while to transcribe, especially if you are running this on a CPU. `Systran/faster-distil-whisper-large-v3` takes ~30 seconds on an Nvidia L4. `Systran/faster-whisper-tiny.en` takes ~1 minute on a Ryzen 7 7700X. The .txt file in the example was transcribed using `Systran/faster-distil-whisper-large-v3`.
-curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt
+curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "model=$TRANSCRIPTION_MODEL" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt
 
 # Here I'm using `aichat`, which is a CLI LLM client. You could use any other client that supports attaching/uploading files. https://github.com/sigoden/aichat
 aichat -m openai:gpt-4o -f the-evolution-of-the-operating-system.txt 'What companies are mentioned in the following YouTube video transcription? Respond with just a list of names'
diff --git a/src/speaches/config.py b/src/speaches/config.py
index 5c367bbd..33837843 100644
--- a/src/speaches/config.py
+++ b/src/speaches/config.py
@@ -21,14 +21,6 @@ class WhisperConfig(BaseModel):
     """See https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/transcribe.py#L599."""
 
-    model: str = Field(default="Systran/faster-whisper-small")
-    """
-    Default HuggingFace model to use for transcription. Note, the model must support being ran using CTranslate2.
-    This model will be used if no model is specified in the request.
-
-    Models created by authors of `faster-whisper` can be found at https://huggingface.co/Systran
-    You can find other supported models at https://huggingface.co/models?p=2&sort=trending&search=ctranslate2 and https://huggingface.co/models?sort=trending&search=ct2
-    """
     inference_device: Device = "auto"
     device_index: int | list[int] = 0
     compute_type: Quantization = "default"  # TODO: should this even be a configuration option?
@@ -52,7 +44,7 @@ class Config(BaseSettings):
     Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
     To populate nested, the environment should be prefixed with the nested field name and an underscore. For example,
-    the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER__MODEL`(note the double underscore) to `whisper.model`, to set quantization to int8, use `WHISPER__COMPUTE_TYPE=int8`, etc.
+    the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER__INFERENCE_DEVICE`(note the double underscore) to `whisper.inference_device`, to set quantization to int8, use `WHISPER__COMPUTE_TYPE=int8`, etc.
     """
 
     model_config = SettingsConfigDict(env_nested_delimiter="__")
diff --git a/src/speaches/realtime/session.py b/src/speaches/realtime/session.py
index 284c2d50..c789d76b 100644
--- a/src/speaches/realtime/session.py
+++ b/src/speaches/realtime/session.py
@@ -36,7 +36,7 @@
     voice="alloy",
     input_audio_format="pcm16",
     output_audio_format="pcm16",
-    input_audio_transcription=InputAudioTranscription(model="Systran/faster-whisper-small.en"),  # changed
+    input_audio_transcription=InputAudioTranscription(model="Systran/faster-whisper-small"),  # changed
     turn_detection=DEFAULT_TURN_DETECTION,
     temperature=0.8,
     tools=[],
diff --git a/src/speaches/routers/stt.py b/src/speaches/routers/stt.py
index d9131296..5e59c900 100644
--- a/src/speaches/routers/stt.py
+++ b/src/speaches/routers/stt.py
@@ -11,7 +11,7 @@
 )
 from fastapi.responses import StreamingResponse
 from faster_whisper.transcribe import BatchedInferencePipeline, TranscriptionInfo
-from pydantic import AfterValidator, Field
+from pydantic import Field
 
 from speaches.api_types import (
     DEFAULT_TIMESTAMP_GRANULARITIES,
@@ -21,7 +21,7 @@
     TimestampGranularities,
     TranscriptionSegment,
 )
-from speaches.dependencies import AudioFileDependency, ConfigDependency, ModelManagerDependency, get_config
+from speaches.dependencies import AudioFileDependency, ConfigDependency, ModelManagerDependency
 from speaches.text_utils import segments_to_srt, segments_to_text, segments_to_vtt
 
 logger = logging.getLogger(__name__)
@@ -91,21 +91,8 @@ def segment_responses() -> Generator[str, None, None]:
     return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
-def handle_default_openai_model(model_name: str) -> str:
-    """Exists because some callers may not be able override the default("whisper-1") model name.
-
-    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
-    """
-    config = get_config()  # HACK
-    if model_name == "whisper-1":
-        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
-        return config.whisper.model
-    return model_name
-
-
 ModelName = Annotated[
     str,
-    AfterValidator(handle_default_openai_model),
     Field(
         description="The ID of the model. You can get a list of available models by calling `/v1/models`.",
         examples=[
@@ -124,15 +111,13 @@ def translate_file(
     config: ConfigDependency,
     model_manager: ModelManagerDependency,
     audio: AudioFileDependency,
-    model: Annotated[ModelName | None, Form()] = None,
+    model: Annotated[ModelName, Form()],
     prompt: Annotated[str | None, Form()] = None,
     response_format: Annotated[ResponseFormat, Form()] = DEFAULT_RESPONSE_FORMAT,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
     vad_filter: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
-    if model is None:
-        model = config.whisper.model
     with model_manager.load_model(model) as whisper:
         whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
         segments, transcription_info = whisper_model.transcribe(
@@ -173,7 +158,7 @@ def transcribe_file(
     model_manager: ModelManagerDependency,
     request: Request,
     audio: AudioFileDependency,
-    model: Annotated[ModelName | None, Form()] = None,
+    model: Annotated[ModelName, Form()],
     language: Annotated[str | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
     response_format: Annotated[ResponseFormat, Form()] = DEFAULT_RESPONSE_FORMAT,
@@ -187,8 +172,6 @@ def transcribe_file(
     hotwords: Annotated[str | None, Form()] = None,
     vad_filter: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
-    if model is None:
-        model = config.whisper.model
     timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
     if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != "verbose_json":
         logger.warning(
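With the server-side default model and the `whisper-1` fallback removed, `model` is now a required form field on `/v1/audio/transcriptions` and `/v1/audio/translations`. A minimal sketch of a client call against a locally running server, assuming the OpenAI Python SDK, a server on localhost:8000, and an `audio.wav` file on disk; the model ID and API key value are placeholders (the SDK requires a key even if your server does not enforce one):

from pathlib import Path

from openai import OpenAI

# Point the OpenAI client at the local speaches server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder-key")

with Path("audio.wav").open("rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="Systran/faster-whisper-tiny.en",  # must now be passed explicitly; there is no server-side default
        file=audio_file,
        language="en",
        response_format="text",
    )

print(transcript)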
diff --git a/src/speaches/ui/tabs/stt.py b/src/speaches/ui/tabs/stt.py
index d3e7a3e1..46ae8e3f 100644
--- a/src/speaches/ui/tabs/stt.py
+++ b/src/speaches/ui/tabs/stt.py
@@ -20,15 +20,10 @@ async def update_whisper_model_dropdown(request: gr.Request) -> gr.Dropdown:
         openai_client = openai_client_from_gradio_req(request, config)
         models = (await openai_client.models.list()).data
         model_names: list[str] = [model.id for model in models]
-        assert config.whisper.model in model_names
         recommended_models = {model for model in model_names if model.startswith("Systran")}
         other_models = [model for model in model_names if model not in recommended_models]
         model_names = list(recommended_models) + other_models
-        return gr.Dropdown(
-            choices=model_names,
-            label="Model",
-            value=config.whisper.model,
-        )
+        return gr.Dropdown(choices=model_names, label="Model", value="Systran/faster-whisper-small")
 
     async def audio_task(
         http_client: httpx.AsyncClient, file_path: str, endpoint: str, temperature: float, model: str
@@ -84,9 +79,9 @@ async def whisper_handler(
     with gr.Tab(label="Speech-to-Text") as tab:
         audio = gr.Audio(type="filepath")
         whisper_model_dropdown = gr.Dropdown(
-            choices=[config.whisper.model],
+            choices=["Systran/faster-whisper-small"],  # TODO: does this need to be non-empty
             label="Model",
-            value=config.whisper.model,
+            value="Systran/faster-whisper-small",
         )
         task_dropdown = gr.Dropdown(
             choices=[task.value for task in Task],
diff --git a/tests/conftest.py b/tests/conftest.py
index cb1443cd..b33bd229 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,9 +19,8 @@
 DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
 OPENAI_BASE_URL = "https://api.openai.com/v1"
-DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
 # TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests
-DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
+DEFAULT_WHISPER_CONFIG = WhisperConfig(ttl=0)
 DEFAULT_CONFIG = Config(
     whisper=DEFAULT_WHISPER_CONFIG,
     # disable the UI as it slightly increases the app startup time due to the imports it's doing
@@ -43,7 +42,6 @@ def pytest_configure() -> None:
 
 # NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
 @pytest.fixture
 def client() -> Generator[TestClient, None, None]:
-    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
     with TestClient(create_app()) as client:
         yield client
diff --git a/tests/model_manager_test.py b/tests/model_manager_test.py
index 6b84afff..829897b9 100644
--- a/tests/model_manager_test.py
+++ b/tests/model_manager_test.py
@@ -4,15 +4,15 @@
 import pytest
 
 from speaches.config import Config, WhisperConfig
-from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
+from tests.conftest import AclientFactory
 
-MODEL = DEFAULT_WHISPER_MODEL  # just to make the test more readable
+MODEL = "Systran/faster-whisper-tiny.en"
 
 
 @pytest.mark.asyncio
 async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
     async with aclient_factory(config) as aclient:
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 0
@@ -27,7 +27,7 @@ async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
 @pytest.mark.asyncio
 async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
     async with aclient_factory(config) as aclient:
         await aclient.post(f"/api/ps/{MODEL}")
         res = (await aclient.get("/api/ps")).json()
@@ -69,7 +69,7 @@ async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
 @pytest.mark.asyncio
 async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
     async with aclient_factory(config) as aclient:
         async with await anyio.open_file("audio.wav", "rb") as f:
             data = await f.read()
@@ -91,7 +91,7 @@ async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory)
 @pytest.mark.asyncio
 async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
     ttl = -1
-    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
     async with aclient_factory(config) as aclient:
         res = await aclient.post(f"/api/ps/{MODEL}")
         assert res.status_code == 201
@@ -104,7 +104,7 @@ async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> No
 @pytest.mark.asyncio
 async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    config = Config(whisper=WhisperConfig(ttl=ttl), enable_ui=False)
     async with aclient_factory(config) as aclient:
         async with await anyio.open_file("audio.wav", "rb") as f:
             data = await f.read()
diff --git a/tests/sse_test.py b/tests/sse_test.py
index 7371583b..8e4fd340 100644
--- a/tests/sse_test.py
+++ b/tests/sse_test.py
@@ -14,6 +14,7 @@
     CreateTranscriptionResponseVerboseJson,
 )
 
+MODEL = "Systran/faster-whisper-tiny.en"
 FILE_PATHS = ["audio.wav"]  # HACK
 ENDPOINTS = [
     "/v1/audio/transcriptions",
@@ -32,7 +33,7 @@ async def test_streaming_transcription_text(aclient: AsyncClient, file_path: str
         data = await f.read()
     kwargs = {
         "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
-        "data": {"response_format": "text", "stream": True},
+        "data": {"model": MODEL, "response_format": "text", "stream": True},
     }
     async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
         async for event in event_source.aiter_sse():
@@ -48,7 +49,7 @@ async def test_streaming_transcription_json(aclient: AsyncClient, file_path: str
         data = await f.read()
     kwargs = {
         "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
-        "data": {"response_format": "json", "stream": True},
+        "data": {"model": MODEL, "response_format": "json", "stream": True},
     }
     async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
         async for event in event_source.aiter_sse():
@@ -63,7 +64,7 @@ async def test_streaming_transcription_verbose_json(aclient: AsyncClient, file_p
         data = await f.read()
     kwargs = {
         "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
-        "data": {"response_format": "verbose_json", "stream": True},
+        "data": {"model": MODEL, "response_format": "verbose_json", "stream": True},
     }
     async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source:
         async for event in event_source.aiter_sse():
@@ -76,7 +77,7 @@ async def test_transcription_vtt(aclient: AsyncClient) -> None:
         data = await f.read()
     kwargs = {
         "files": {"file": ("audio.wav", data, "audio/wav")},
-        "data": {"response_format": "vtt", "stream": False},
+        "data": {"model": MODEL, "response_format": "vtt", "stream": False},
     }
     response = await aclient.post("/v1/audio/transcriptions", **kwargs)
     assert response.status_code == 200
@@ -94,7 +95,7 @@ async def test_transcription_srt(aclient: AsyncClient) -> None:
         data = await f.read()
     kwargs = {
         "files": {"file": ("audio.wav", data, "audio/wav")},
-        "data": {"response_format": "srt", "stream": False},
+        "data": {"model": MODEL, "response_format": "srt", "stream": False},
     }
     response = await aclient.post("/v1/audio/transcriptions", **kwargs)
    assert response.status_code == 200
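The updated `Config` docstring in src/speaches/config.py above relies on pydantic-settings mapping double-underscore environment variables onto nested config fields. A minimal sketch of that mapping, assuming pydantic-settings v2; the field names mirror the ones in this diff, while the types and defaults are simplified for illustration:

import os

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class WhisperConfig(BaseModel):
    inference_device: str = "auto"
    compute_type: str = "default"


class Config(BaseSettings):
    model_config = SettingsConfigDict(env_nested_delimiter="__")

    log_level: str = "debug"
    whisper: WhisperConfig = WhisperConfig()


# `LOG_LEVEL` populates the top-level `log_level` field, while the double underscore
# in `WHISPER__INFERENCE_DEVICE` and `WHISPER__COMPUTE_TYPE` selects the nested `whisper.*` fields.
os.environ["LOG_LEVEL"] = "info"
os.environ["WHISPER__INFERENCE_DEVICE"] = "cpu"
os.environ["WHISPER__COMPUTE_TYPE"] = "int8"

config = Config()
assert config.log_level == "info"
assert config.whisper.inference_device == "cpu"
assert config.whisper.compute_type == "int8"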