Source code for pipecat.services.mistral.stt

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Mistral Speech-to-Text service implementation.

This module provides a real-time STT service that integrates with Mistral's
Voxtral Realtime transcription API using the Mistral SDK's RealtimeConnection.
"""

from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any

from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
    VADUserStartedSpeakingFrame,
    VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import STTSettings, assert_given
from pipecat.services.stt_latency import MISTRAL_TTFS_P99
from pipecat.services.stt_service import STTService
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

try:
    from mistralai.client import Mistral
    from mistralai.client.models import (
        AudioFormat,
        RealtimeTranscriptionError,
        RealtimeTranscriptionSessionCreated,
        TranscriptionStreamDone,
        TranscriptionStreamLanguage,
        TranscriptionStreamTextDelta,
    )
    from mistralai.extra.realtime import RealtimeConnection, UnknownRealtimeEvent
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Mistral STT, you need to `pip install pipecat-ai[mistral]`.")
    raise Exception(f"Missing module: {e}")


[docs] @dataclass class MistralSTTSettings(STTSettings): """Settings for MistralSTTService. Parameters: model: STT model identifier. language: Language hint for transcription. """ pass
[docs] class MistralSTTService(STTService): """Mistral Speech-to-Text service using the Voxtral Realtime API. This service uses the Mistral SDK's RealtimeConnection to stream audio and receive transcription events over WebSocket. It extends STTService directly (rather than WebsocketSTTService) because the SDK manages the WebSocket connection internally. Event handlers available: - on_connected: Called when a transcription session is created. - on_disconnected: Called when the connection is closed. - on_connection_error: Called when a transcription error occurs. Example:: @stt.event_handler("on_connected") async def on_connected(stt): logger.info("Mistral STT connected") """ Settings = MistralSTTSettings _settings: Settings
[docs] def __init__( self, *, api_key: str | None = None, base_url: str | None = None, sample_rate: int | None = None, target_streaming_delay_ms: int | None = None, ttfs_p99_latency: float | None = MISTRAL_TTFS_P99, settings: Settings | None = None, **kwargs, ): """Initialize Mistral STT service. Args: api_key: Mistral API key for authentication. base_url: Custom API endpoint URL. sample_rate: Audio sample rate in Hz. If None, uses the pipeline sample rate. target_streaming_delay_ms: Streaming delay for accuracy/latency tradeoff. Higher values may improve accuracy at the cost of latency. ttfs_p99_latency: P99 latency from speech end to final transcript in seconds. Override for your deployment. settings: Runtime-updatable settings. **kwargs: Additional keyword arguments passed to STTService. """ default_settings = self.Settings( model="voxtral-mini-transcribe-realtime-2602", language=None, ) if settings is not None: default_settings.apply_update(settings) super().__init__( sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, settings=default_settings, **kwargs, ) self._client = Mistral(api_key=api_key, server_url=base_url) self._target_streaming_delay_ms = target_streaming_delay_ms self._connection: RealtimeConnection | None = None self._receive_task = None self._accumulated_text = "" self._detected_language: str | None = None
[docs] def can_generate_metrics(self) -> bool: """Check if the service can generate processing metrics. Returns: True, indicating metrics are supported. """ return True
[docs] async def start(self, frame: StartFrame): """Start the STT service and establish connection. Args: frame: Frame indicating service should start. """ await super().start(frame) await self._connect()
[docs] async def stop(self, frame: EndFrame): """Stop the STT service and close connection. Args: frame: Frame indicating service should stop. """ await super().stop(frame) await self._disconnect()
[docs] async def cancel(self, frame: CancelFrame): """Cancel the STT service and close connection. Args: frame: Frame indicating service should be cancelled. """ await super().cancel(frame) await self._disconnect()
[docs] async def process_frame(self, frame: Frame, direction: FrameDirection): """Process incoming frames and handle speech events. Args: frame: The frame to process. direction: Direction of frame flow in the pipeline. """ await super().process_frame(frame, direction) if isinstance(frame, VADUserStartedSpeakingFrame): self._accumulated_text = "" await self._start_metrics() elif isinstance(frame, VADUserStoppedSpeakingFrame): if self._connection and not self._connection.is_closed: await self._connection.flush_audio()
[docs] async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]: """Send audio data to Mistral for transcription. Args: audio: Raw audio bytes to transcribe. Yields: None - transcription results arrive via the receive events task. """ if not self._connection or self._connection.is_closed: await self._connect() await self._connection.send_audio(audio) yield None
async def _start_metrics(self): """Start performance metrics collection for transcription processing.""" await self.start_processing_metrics() async def _connect(self): """Establish a connection to the Mistral Realtime API.""" try: logger.debug(f"{self}: Connecting to Mistral STT") audio_format = AudioFormat( encoding="pcm_s16le", sample_rate=self.sample_rate, ) model = assert_given(self._settings.model) if model is None: raise ValueError("Mistral STT model must be specified") self._connection = await self._client.audio.realtime.connect( model=model, audio_format=audio_format, target_streaming_delay_ms=self._target_streaming_delay_ms, ) self._receive_task = self.create_task( self._receive_events(), name="mistral_stt_receive" ) except Exception as e: await self.push_error(error_msg=f"Error connecting to Mistral STT: {e}", exception=e) async def _disconnect(self): """Close the connection and cancel the receive task.""" if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None if self._connection and not self._connection.is_closed: try: logger.debug(f"{self}: Disconnecting from Mistral STT") await self._connection.close() except Exception as e: logger.warning(f"{self}: Error closing connection: {e}") finally: self._connection = None await self._call_event_handler("on_disconnected") async def _receive_events(self): """Background task: iterate connection events and handle them.""" try: async for event in self._connection.events(): if isinstance(event, RealtimeTranscriptionSessionCreated): logger.debug(f"{self}: Session created: {event.session}") await self._call_event_handler("on_connected") elif isinstance(event, TranscriptionStreamTextDelta): self._accumulated_text += event.text await self.push_frame( InterimTranscriptionFrame( self._accumulated_text, self._user_id, time_now_iso8601(), ) ) elif isinstance(event, TranscriptionStreamDone): if event.text: await self.push_frame( TranscriptionFrame( event.text, self._user_id, time_now_iso8601(), language=self._detected_language, ) ) await self._handle_transcription(event.text, True, self._detected_language) await self.stop_processing_metrics() self._accumulated_text = "" elif isinstance(event, TranscriptionStreamLanguage): self._detected_language = event.audio_language elif isinstance(event, RealtimeTranscriptionError): error_msg = event.error.message if event.error else "Unknown error" await self.push_error(error_msg=f"Mistral STT error: {error_msg}") await self._call_event_handler("on_connection_error", error_msg) elif isinstance(event, UnknownRealtimeEvent): logger.warning(f"{self}: Unknown realtime event: {event}") except Exception as e: await self.push_error(error_msg=f"Mistral STT receive error: {e}", exception=e) await self._call_event_handler("on_connection_error", str(e)) finally: self._connection = None @traced_stt async def _handle_transcription( self, transcript: str, is_final: bool, language: str | None = None ): """Handle a transcription result with tracing.""" pass async def _update_settings(self, delta: STTSettings) -> dict[str, Any]: """Apply a settings delta, reconnecting if model or language changes. Args: delta: An STT settings delta. Returns: Dict mapping changed field names to their previous values. """ changed = await super()._update_settings(delta) if changed: await self._disconnect() await self._connect() return changed