Source code for pipecat.services.whisper.base_stt

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Base class for Whisper-based speech-to-text services.

This module provides common functionality for services implementing the Whisper API
interface, including language mapping, metrics generation, and error handling.
"""

from collections.abc import AsyncGenerator
from dataclasses import dataclass, field

from loguru import logger
from openai import AsyncOpenAI
from openai.types.audio import Transcription

from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven
from pipecat.services.stt_latency import WHISPER_TTFS_P99
from pipecat.services.stt_service import SegmentedSTTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt


@dataclass
class BaseWhisperSTTSettings(STTSettings):
    """Settings container used by BaseWhisperSTTService.

    Extends ``STTSettings`` with the two Whisper-specific request options
    below. Each field defaults to the ``NOT_GIVEN`` sentinel, which is
    distinct from an explicit ``None``.

    Parameters:
        prompt: Optional text to guide the model's style or continue a
            previous segment.
        temperature: Sampling temperature between 0 and 1.
    """

    prompt: str | None | _NotGiven = field(
        default_factory=lambda: NOT_GIVEN,
    )
    temperature: float | None | _NotGiven = field(
        default_factory=lambda: NOT_GIVEN,
    )
def language_to_whisper_language(language: Language) -> str | None:
    """Translate a pipecat ``Language`` into a Whisper API language code.

    The table below lists the languages the Whisper API supports. Docs:
    https://platform.openai.com/docs/guides/speech-to-text#supported-languages

    Args:
        language: The pipecat ``Language`` enum member describing the input audio.

    Returns:
        str or None: Whatever ``resolve_language`` produces for ``language``
        against the table below (called with ``use_base_code=True``).
    """
    # Explicit table of Language members to the code sent to the API.
    whisper_codes = {
        Language.AF: "af",
        Language.AR: "ar",
        Language.HY: "hy",
        Language.AZ: "az",
        Language.BE: "be",
        Language.BS: "bs",
        Language.BG: "bg",
        Language.CA: "ca",
        Language.ZH: "zh",
        Language.HR: "hr",
        Language.CS: "cs",
        Language.DA: "da",
        Language.NL: "nl",
        Language.EN: "en",
        Language.ET: "et",
        Language.FI: "fi",
        Language.FR: "fr",
        Language.GL: "gl",
        Language.DE: "de",
        Language.EL: "el",
        Language.HE: "he",
        Language.HI: "hi",
        Language.HU: "hu",
        Language.IS: "is",
        Language.ID: "id",
        Language.IT: "it",
        Language.JA: "ja",
        Language.KN: "kn",
        Language.KK: "kk",
        Language.KO: "ko",
        Language.LV: "lv",
        Language.LT: "lt",
        Language.MK: "mk",
        Language.MS: "ms",
        Language.MR: "mr",
        Language.MI: "mi",
        Language.NE: "ne",
        Language.NO: "no",
        Language.FA: "fa",
        Language.PL: "pl",
        Language.PT: "pt",
        Language.RO: "ro",
        Language.RU: "ru",
        Language.SR: "sr",
        Language.SK: "sk",
        Language.SL: "sl",
        Language.ES: "es",
        Language.SW: "sw",
        Language.SV: "sv",
        Language.TL: "tl",
        Language.TA: "ta",
        Language.TH: "th",
        Language.TR: "tr",
        Language.UK: "uk",
        Language.UR: "ur",
        Language.VI: "vi",
        Language.CY: "cy",
    }

    return resolve_language(language, whisper_codes, use_base_code=True)
class BaseWhisperSTTService(SegmentedSTTService):
    """Base class for Whisper-based speech-to-text services.

    Provides common functionality for services implementing the Whisper API
    interface, including metrics generation and error handling. Subclasses
    must implement ``_transcribe``; the base implementation raises
    ``NotImplementedError``.
    """

    Settings = BaseWhisperSTTSettings
    _settings: Settings

    def __init__(
        self,
        *,
        model: str | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        language: Language | None = None,
        prompt: str | None = None,
        temperature: float | None = None,
        include_prob_metrics: bool = False,
        push_empty_transcripts: bool = False,
        settings: Settings | None = None,
        ttfs_p99_latency: float | None = WHISPER_TTFS_P99,
        **kwargs,
    ):
        """Initialize the Whisper STT service.

        Args:
            model: Name of the Whisper model to use.

                .. deprecated:: 0.0.105
                    Use ``settings=BaseWhisperSTTService.Settings(model=...)`` instead.

            api_key: Service API key. Defaults to None.
            base_url: Service API base URL. Defaults to None.
            language: Language of the audio input.

                .. deprecated:: 0.0.105
                    Use ``settings=BaseWhisperSTTService.Settings(language=...)`` instead.

            prompt: Optional text to guide the model's style or continue a previous segment.

                .. deprecated:: 0.0.105
                    Use ``settings=BaseWhisperSTTService.Settings(prompt=...)`` instead.

            temperature: Sampling temperature between 0 and 1.

                .. deprecated:: 0.0.105
                    Use ``settings=BaseWhisperSTTService.Settings(temperature=...)`` instead.

            include_prob_metrics: If True, enables probability metrics in API response.
                Each service implements this differently (see child classes).
                Defaults to False.
            push_empty_transcripts: If true, allow empty `TranscriptionFrame` frames to be
                pushed downstream instead of discarding them. This is intended for situations
                where VAD fires even though the user did not speak. In these cases, it is
                useful to know that nothing was transcribed so that the agent can resume
                speaking, instead of waiting longer for a transcription.
                Defaults to False.
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
                Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
            **kwargs: Additional arguments passed to SegmentedSTTService.
        """
        # --- 1. Hardcoded defaults ---
        default_settings = self.Settings(
            model=None,
            language=None,
            prompt=None,
            temperature=None,
        )

        # --- 2. Deprecated direct-arg overrides ---
        if model is not None:
            self._warn_init_param_moved_to_settings("model", "model")
            default_settings.model = model
        if language is not None:
            self._warn_init_param_moved_to_settings("language", "language")
            default_settings.language = language
        if prompt is not None:
            self._warn_init_param_moved_to_settings("prompt", "prompt")
            default_settings.prompt = prompt
        if temperature is not None:
            self._warn_init_param_moved_to_settings("temperature", "temperature")
            default_settings.temperature = temperature

        # --- 3. (no params object for this service) ---

        # --- 4. Settings delta (canonical API, always wins) ---
        # Applied last so values in ``settings`` override the deprecated arguments.
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            ttfs_p99_latency=ttfs_p99_latency,
            settings=default_settings,
            **kwargs,
        )

        self._client = self._create_client(api_key, base_url)
        self._include_prob_metrics = include_prob_metrics
        self._push_empty_transcripts = push_empty_transcripts

    def _create_client(self, api_key: str | None, base_url: str | None):
        """Build the API client stored on ``self._client``.

        NOTE(review): kept as a separate method, presumably so subclasses can
        supply a differently configured client -- confirm against child classes.

        Args:
            api_key: Service API key, forwarded to ``AsyncOpenAI``.
            base_url: Service API base URL, forwarded to ``AsyncOpenAI``.

        Returns:
            An ``AsyncOpenAI`` client constructed from the given arguments.
        """
        return AsyncOpenAI(api_key=api_key, base_url=base_url)

    def can_generate_metrics(self) -> bool:
        """Whether this service can generate processing metrics.

        Returns:
            bool: True, as this service supports metric generation.
        """
        return True

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert from pipecat Language to service language code.

        Args:
            language: The Language enum value to convert.

        Returns:
            str or None: The value returned by ``language_to_whisper_language``
            for ``language``.
        """
        return language_to_whisper_language(language)

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Language | None = None
    ):
        """Handle a transcription result with tracing."""
        # Intentionally empty: the method only exists so ``traced_stt`` wraps the call.
        pass

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribe audio data to text.

        Args:
            audio: Raw audio data to transcribe.

        Yields:
            Frame: Either a TranscriptionFrame containing the transcribed text
            or an ErrorFrame if transcription fails. Nothing is yielded for an
            empty transcription unless ``push_empty_transcripts`` was enabled.
        """
        try:
            await self.start_processing_metrics()
            response = await self._transcribe(audio)
            # NOTE(review): if ``_transcribe`` raises, this stop call is never
            # reached -- confirm whether failed requests should report a metric.
            await self.stop_processing_metrics()

            text = response.text.strip()

            if not text:
                logger.warning("Received empty transcription from API")

            if text or self._push_empty_transcripts:
                await self._handle_transcription(text, True, self._settings.language)
                logger.debug(f"Transcription: [{text}]")
                yield TranscriptionFrame(
                    text,
                    self._user_id,
                    time_now_iso8601(),
                    result=response,
                )
        except Exception as e:
            # The ErrorFrame only carries ``str(e)``, which can be empty; log the
            # traceback and exception type here so the failure stays diagnosable.
            logger.exception(f"Exception during transcription: {e}")
            yield ErrorFrame(error=f"Unknown error occurred: {e}")

    async def _transcribe(self, audio: bytes) -> Transcription:
        """Transcribe audio data to text.

        Args:
            audio: Raw audio data in WAV format.

        Returns:
            Transcription: Object containing the transcribed text.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError