Source code for pipecat.services.whisper.stt

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Whisper speech-to-text services with locally-downloaded models.

This module implements Whisper transcription using locally-downloaded models,
supporting both Faster Whisper and MLX Whisper backends for efficient inference.
"""

import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING

import numpy as np
from loguru import logger
from typing_extensions import override

from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven, assert_given
from pipecat.services.stt_service import SegmentedSTTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

# Imports needed only for static analysis. ``typing.TYPE_CHECKING`` is False at
# runtime, so nothing in this block (including the ``raise`` statements) runs
# during a normal import of this module. The services import ``faster_whisper``
# and ``mlx_whisper`` lazily inside their own methods instead.
if TYPE_CHECKING:
    try:
        from faster_whisper import WhisperModel
    except ModuleNotFoundError as e:
        logger.error(f"Exception: {e}")
        logger.error("In order to use Whisper, you need to `pip install pipecat-ai[whisper]`.")
        raise Exception(f"Missing module: {e}")

    try:
        import mlx_whisper  # noqa: F401
    except ModuleNotFoundError as e:
        logger.error(f"Exception: {e}")
        logger.error("In order to use Whisper, you need to `pip install pipecat-ai[mlx-whisper]`.")
        raise Exception(f"Missing module: {e}")


class Model(Enum):
    """Model identifiers selectable for the Faster Whisper backend.

    Each member's value is the string handed to Faster Whisper when the model
    is loaded: either a built-in size name or a Hugging Face repository id.
    The members trade transcription quality against inference speed.

    Parameters:
        TINY: Smallest multilingual model, fastest inference.
        BASE: Basic multilingual model, good speed/quality balance.
        SMALL: Small multilingual model, better speed/quality balance than BASE.
        MEDIUM: Medium-sized multilingual model, better quality.
        LARGE: Best quality multilingual model, slower inference.
        LARGE_V3_TURBO: Fast multilingual model, slightly lower quality than LARGE.
        DISTIL_LARGE_V2: Fast multilingual distilled model.
        DISTIL_MEDIUM_EN: Fast English-only distilled model.
    """

    # --- Multilingual ---
    TINY = "tiny"
    BASE = "base"
    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large-v3"
    LARGE_V3_TURBO = "deepdml/faster-whisper-large-v3-turbo-ct2"
    DISTIL_LARGE_V2 = "Systran/faster-distil-whisper-large-v2"

    # --- English only ---
    DISTIL_MEDIUM_EN = "Systran/faster-distil-whisper-medium.en"
class MLXModel(Enum):
    """Model identifiers selectable for the MLX Whisper backend (Apple Silicon).

    Each member's value is the Hugging Face repository id that is passed to
    MLX Whisper as the model location. Quantized variants are included for
    lower memory usage.

    Parameters:
        TINY: Smallest multilingual model for MLX.
        MEDIUM: Medium-sized multilingual model for MLX.
        LARGE_V3: Best quality multilingual model for MLX.
        LARGE_V3_TURBO: Finetuned, pruned Whisper large-v3, much faster with slightly
            lower quality.
        DISTIL_LARGE_V3: Fast multilingual distilled model for MLX.
        LARGE_V3_TURBO_Q4: LARGE_V3_TURBO quantized to Q4 for reduced memory usage.
    """

    # --- Multilingual ---
    TINY = "mlx-community/whisper-tiny"
    MEDIUM = "mlx-community/whisper-medium-mlx"
    LARGE_V3 = "mlx-community/whisper-large-v3-mlx"
    LARGE_V3_TURBO = "mlx-community/whisper-large-v3-turbo"
    DISTIL_LARGE_V3 = "mlx-community/distil-whisper-large-v3"
    LARGE_V3_TURBO_Q4 = "mlx-community/whisper-large-v3-turbo-q4"
def language_to_whisper_language(language: Language) -> str | None:
    """Translate a pipecat ``Language`` into the code Whisper expects.

    The lookup is delegated to ``resolve_language`` with ``use_base_code=True``,
    using the table below as the set of known mappings.

    Args:
        language: A Language enum value representing the input language.

    Returns:
        str or None: The corresponding Whisper language code, or None if not supported.

    Note:
        Only includes languages officially supported by Whisper.
    """
    whisper_codes = {
        Language.AR: "ar",  # Arabic
        Language.BN: "bn",  # Bengali
        Language.CS: "cs",  # Czech
        Language.DA: "da",  # Danish
        Language.DE: "de",  # German
        Language.EL: "el",  # Greek
        Language.EN: "en",  # English
        Language.ES: "es",  # Spanish
        Language.FA: "fa",  # Persian
        Language.FI: "fi",  # Finnish
        Language.FR: "fr",  # French
        Language.HI: "hi",  # Hindi
        Language.HU: "hu",  # Hungarian
        Language.ID: "id",  # Indonesian
        Language.IT: "it",  # Italian
        Language.JA: "ja",  # Japanese
        Language.KO: "ko",  # Korean
        Language.NL: "nl",  # Dutch
        Language.PL: "pl",  # Polish
        Language.PT: "pt",  # Portuguese
        Language.RO: "ro",  # Romanian
        Language.RU: "ru",  # Russian
        Language.SK: "sk",  # Slovak
        Language.SV: "sv",  # Swedish
        Language.TH: "th",  # Thai
        Language.TR: "tr",  # Turkish
        Language.UK: "uk",  # Ukrainian
        Language.UR: "ur",  # Urdu
        Language.VI: "vi",  # Vietnamese
        Language.ZH: "zh",  # Chinese
    }

    return resolve_language(language, whisper_codes, use_base_code=True)
@dataclass
class WhisperSTTSettings(STTSettings):
    """Runtime settings used by ``WhisperSTTService``.

    Extends ``STTSettings`` with one Whisper-specific field. The field defaults
    to the ``NOT_GIVEN`` sentinel.

    Parameters:
        no_speech_prob: Probability threshold for filtering non-speech segments.
    """

    # A segment is kept only when its no-speech probability is below this value.
    no_speech_prob: float | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
@dataclass
class WhisperMLXSTTSettings(STTSettings):
    """Runtime settings used by ``WhisperSTTServiceMLX``.

    Extends ``STTSettings`` with MLX-specific fields. Every field defaults to
    the ``NOT_GIVEN`` sentinel.

    Parameters:
        no_speech_prob: Probability threshold for filtering non-speech segments.
        temperature: Sampling temperature (0.0-1.0).
        engine: Whisper engine identifier.
    """

    # A segment is kept only when its no-speech probability is below this value.
    no_speech_prob: float | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    # Forwarded to MLX Whisper as the ``temperature`` argument of transcribe.
    temperature: float | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    # Engine identifier; the MLX service initializes it to "mlx".
    engine: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class WhisperSTTService(SegmentedSTTService):
    """Class to transcribe audio with a locally-downloaded Whisper model.

    This service uses Faster Whisper to perform speech-to-text transcription on audio
    segments. It supports multiple languages and various model sizes.
    """

    Settings = WhisperSTTSettings
    _settings: Settings

    def __init__(
        self,
        *,
        model: str | Model | None = None,
        device: str = "auto",
        compute_type: str = "default",
        no_speech_prob: float | None = None,
        language: Language | None = None,
        settings: Settings | None = None,
        **kwargs,
    ):
        """Initialize the Whisper STT service.

        Args:
            model: The Whisper model to use for transcription. Can be a Model enum or string.

                .. deprecated:: 0.0.105
                    Use ``settings=WhisperSTTService.Settings(model=...)`` instead.

            device: The device to run inference on ('cpu', 'cuda', or 'auto').
                Defaults to ``"auto"``.
            compute_type: The compute type for inference ('default', 'int8',
                'int8_float16', etc.). Defaults to ``"default"``.
            no_speech_prob: Probability threshold for filtering out non-speech segments.

                .. deprecated:: 0.0.105
                    Use ``settings=WhisperSTTService.Settings(no_speech_prob=...)`` instead.

            language: The default language for transcription.

                .. deprecated:: 0.0.105
                    Use ``settings=WhisperSTTService.Settings(language=...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to SegmentedSTTService.
        """
        # --- 1. Hardcoded defaults ---
        default_settings = self.Settings(
            model=Model.DISTIL_MEDIUM_EN.value,
            language=Language.EN,
            no_speech_prob=0.4,
        )

        # --- 2. Deprecated direct-arg overrides ---
        if model is not None:
            self._warn_init_param_moved_to_settings("model", "model")
            default_settings.model = model if isinstance(model, str) else model.value
        if no_speech_prob is not None:
            self._warn_init_param_moved_to_settings("no_speech_prob", "no_speech_prob")
            default_settings.no_speech_prob = no_speech_prob
        if language is not None:
            self._warn_init_param_moved_to_settings("language", "language")
            default_settings.language = language

        # --- 3. (no params object for this service) ---

        # --- 4. Settings delta (canonical API, always wins) ---
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            settings=default_settings,
            **kwargs,
        )

        # Init-only inference config
        self._device = device
        self._compute_type = compute_type
        self._model: WhisperModel | None = None
        self._load()

    def can_generate_metrics(self) -> bool:
        """Indicates whether this service can generate metrics.

        Returns:
            bool: True, as this service supports metric generation.
        """
        return True

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert from pipecat Language to Whisper language code.

        Args:
            language: The Language enum value to convert.

        Returns:
            str or None: The corresponding Whisper language code, or None if not supported.
        """
        return language_to_whisper_language(language)

    def _load(self):
        """Loads the Whisper model.

        If ``faster_whisper`` is not installed, the error is logged and
        ``self._model`` is left as None; ``run_stt`` then reports the model as
        unavailable instead of raising.

        Raises:
            ValueError: If the configured model name is None.

        Note:
            If this is the first time this model is being run, it will take time
            to download from the Hugging Face model hub.
        """
        try:
            # Imported lazily so the module can be imported without the optional
            # ``whisper`` extra installed.
            from faster_whisper import WhisperModel

            logger.debug("Loading Whisper model...")
            model_name = assert_given(self._settings.model)
            if model_name is None:
                raise ValueError("Whisper model must be specified")
            self._model = WhisperModel(
                model_name, device=self._device, compute_type=self._compute_type
            )
            logger.debug("Loaded Whisper model")
        except ModuleNotFoundError as e:
            logger.error(f"Exception: {e}")
            logger.error("In order to use Whisper, you need to `pip install pipecat-ai[whisper]`.")
            self._model = None

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Language | None = None
    ):
        """Handle a transcription result with tracing."""
        pass

    def _transcribe_blocking(self, audio_float: np.ndarray, language) -> list:
        """Run Faster Whisper and fully decode the result (blocking).

        Faster Whisper returns its segments as a lazy generator: decoding only
        happens while that generator is iterated. Materializing it here keeps
        all of the inference inside whichever thread calls this helper.

        Args:
            audio_float: Audio samples as float32 in the range [-1, 1).
            language: Language value forwarded to ``WhisperModel.transcribe``.

        Returns:
            list: The decoded segments, in order.
        """
        segments, _ = self._model.transcribe(audio_float, language=language)
        return list(segments)

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribe audio data using Whisper.

        Args:
            audio: Raw audio bytes in 16-bit PCM format.

        Yields:
            Frame: An ErrorFrame if the model is not loaded; otherwise a
                TranscriptionFrame with the transcribed text, or nothing when no
                speech survives the no-speech filter. Errors raised by the
                transcription itself are not caught here.

        Note:
            The audio is expected to be 16-bit signed PCM data.
            The service will normalize it to float32 in the range [-1, 1].
        """
        if not self._model:
            yield ErrorFrame("Whisper model not available")
            return

        await self.start_processing_metrics()

        # Divide by 32768 because we have signed 16-bit data.
        audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0

        language = assert_given(self._settings.language)

        # Decode in a worker thread. The segments generator must be consumed
        # there too, otherwise the inference would run on the event loop.
        segments = await asyncio.to_thread(self._transcribe_blocking, audio_float, language)

        no_speech_prob_threshold = assert_given(self._settings.no_speech_prob)
        text = "".join(
            f"{segment.text} "
            for segment in segments
            if no_speech_prob_threshold is not None
            and segment.no_speech_prob < no_speech_prob_threshold
        )

        await self.stop_processing_metrics()

        # Skip results that are empty or only whitespace so that no blank
        # transcription is emitted.
        if text.strip():
            await self._handle_transcription(text, True, language)
            logger.debug(f"Transcription: [{text}]")
            yield TranscriptionFrame(
                text,
                self._user_id,
                time_now_iso8601(),
                language,
            )
class WhisperSTTServiceMLX(WhisperSTTService):
    """Subclass of `WhisperSTTService` with MLX Whisper model support.

    This service uses MLX Whisper to perform speech-to-text transcription on audio
    segments. It's optimized for Apple Silicon and supports multiple languages and quantizations.
    """

    Settings = WhisperMLXSTTSettings
    _settings: Settings

    def __init__(
        self,
        *,
        model: str | MLXModel | None = None,
        no_speech_prob: float | None = None,
        language: Language | None = None,
        temperature: float | None = None,
        settings: Settings | None = None,
        **kwargs,
    ):
        """Initialize the MLX Whisper STT service.

        Args:
            model: The MLX Whisper model to use for transcription. Can be an MLXModel enum or string.

                .. deprecated:: 0.0.105
                    Use ``settings=WhisperSTTServiceMLX.Settings(model=...)`` instead.

            no_speech_prob: Probability threshold for filtering out non-speech segments.

                .. deprecated:: 0.0.105
                    Use ``settings=WhisperSTTServiceMLX.Settings(no_speech_prob=...)`` instead.

            language: The default language for transcription.

                .. deprecated:: 0.0.105
                    Use ``settings=WhisperSTTServiceMLX.Settings(language=...)`` instead.

            temperature: Temperature for sampling. Can be a float or tuple of floats.

                .. deprecated:: 0.0.105
                    Use ``settings=WhisperSTTServiceMLX.Settings(temperature=...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to SegmentedSTTService.
        """
        # --- 1. Hardcoded defaults ---
        default_settings = self.Settings(
            model=MLXModel.TINY.value,
            language=Language.EN,
            no_speech_prob=0.6,
            temperature=0.0,
            engine="mlx",
        )

        # --- 2. Deprecated direct-arg overrides ---
        if model is not None:
            self._warn_init_param_moved_to_settings("model", "model")
            default_settings.model = model if isinstance(model, str) else model.value
        if no_speech_prob is not None:
            self._warn_init_param_moved_to_settings("no_speech_prob", "no_speech_prob")
            default_settings.no_speech_prob = no_speech_prob
        if language is not None:
            self._warn_init_param_moved_to_settings("language", "language")
            default_settings.language = language
        if temperature is not None:
            self._warn_init_param_moved_to_settings("temperature", "temperature")
            default_settings.temperature = temperature

        # --- 3. (no params object for this service) ---

        # --- 4. Settings delta (canonical API, always wins) ---
        if settings is not None:
            default_settings.apply_update(settings)

        # Skip WhisperSTTService.__init__ and call its parent directly
        SegmentedSTTService.__init__(
            self,
            settings=default_settings,
            **kwargs,
        )

        # No need to call _load() as MLX Whisper loads models on demand

    @override
    def _load(self):
        """MLX Whisper loads models on demand, so this is a no-op."""
        pass

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Language | None = None
    ):
        """Handle a transcription result with tracing."""
        pass

    @override
    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribe audio data using MLX Whisper.

        The audio is expected to be 16-bit signed PCM data. It is normalized to
        float32 in the range [-1, 1) here, before being handed to MLX Whisper.

        Args:
            audio: Raw audio bytes in 16-bit PCM format.

        Yields:
            Frame: Either a TranscriptionFrame containing the transcribed text
                or an ErrorFrame if transcription fails. Nothing is yielded when
                no speech survives the filters.
        """
        try:
            # Imported lazily so the module can be imported without the optional
            # ``mlx-whisper`` extra installed.
            import mlx_whisper

            await self.start_processing_metrics()

            # Divide by 32768 because we have signed 16-bit data.
            audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0

            model_path = assert_given(self._settings.model)
            if model_path is None:
                raise ValueError("Whisper model must be specified")

            temperature = assert_given(self._settings.temperature)
            language = assert_given(self._settings.language)

            # mlx_whisper.transcribe is a blocking call; run it off the event loop.
            chunk = await asyncio.to_thread(
                mlx_whisper.transcribe,
                audio_float,
                path_or_hf_repo=model_path,
                temperature=temperature,
                language=language,
            )

            text = ""
            no_speech_prob_threshold = assert_given(self._settings.no_speech_prob)
            for segment in chunk.get("segments", []):
                # Drop likely hallucinations
                # NOTE(review): exact float match on a magic compression ratio;
                # presumably a known hallucination signature - confirm.
                if segment.get("compression_ratio", None) == 0.5555555555555556:
                    continue

                if (
                    no_speech_prob_threshold is not None
                    and segment.get("no_speech_prob", 0.0) < no_speech_prob_threshold
                ):
                    text += f"{segment.get('text', '')} "

            await self.stop_processing_metrics()

            # Whitespace-only results are treated as "nothing transcribed".
            if text.strip():
                await self._handle_transcription(text, True, language)
                logger.debug(f"Transcription: [{text}]")
                yield TranscriptionFrame(
                    text,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                )

        except ModuleNotFoundError as e:
            # Missing optional dependency: log the install hint, and report the
            # failure downstream with the same ErrorFrame payload as before.
            logger.error(f"Exception: {e}")
            logger.error(
                "In order to use Whisper, you need to `pip install pipecat-ai[mlx-whisper]`."
            )
            yield ErrorFrame(error=f"Unknown error occurred: {e}")
        except Exception as e:
            # Best-effort: turn the failure into an ErrorFrame, but log it with
            # the traceback first so it is not lost.
            logger.exception(f"MLX Whisper transcription failed: {e}")
            yield ErrorFrame(error=f"Unknown error occurred: {e}")