Source code for pipecat.services.deepgram.sagemaker.stt

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Deepgram speech-to-text service for AWS SageMaker.

This module provides a Pipecat STT service that connects to Deepgram models
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
low-latency real-time transcription with support for interim results, multiple
languages, and various Deepgram features.
"""

import asyncio
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass, fields
from typing import Any
from urllib.parse import quote, urlencode

from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
    VADUserStartedSpeakingFrame,
    VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.deepgram.stt import DeepgramSTTService, LiveOptions
from pipecat.services.settings import STTSettings, is_given
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt


@dataclass
class DeepgramSageMakerSTTSettings(DeepgramSTTService.Settings):
    """Runtime settings accepted by the Deepgram SageMaker STT service.

    Adds no fields of its own; every field is inherited from
    :class:`DeepgramSTTService.Settings`.
    """
class DeepgramSageMakerSTTService(STTService):
    """Deepgram speech-to-text service for AWS SageMaker.

    Provides real-time speech recognition using Deepgram models deployed on
    AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
    transcription with support for interim results, speaker diarization, and
    multiple languages.

    Requirements:

    - AWS credentials configured (via environment variables, AWS CLI, or
      instance metadata)
    - A deployed SageMaker endpoint with Deepgram model:
      https://developers.deepgram.com/docs/deploy-amazon-sagemaker

    Example::

        stt = DeepgramSageMakerSTTService(
            endpoint_name="my-deepgram-endpoint",
            region="us-east-2",
            settings=DeepgramSageMakerSTTService.Settings(
                model="nova-3",
                language="en",
                interim_results=True,
                punctuate=True,
            ),
        )
    """

    Settings = DeepgramSageMakerSTTSettings
    _settings: Settings

    def __init__(
        self,
        *,
        endpoint_name: str,
        region: str,
        encoding: str = "linear16",
        channels: int = 1,
        multichannel: bool = False,
        sample_rate: int | None = None,
        mip_opt_out: bool | None = None,
        live_options: LiveOptions | None = None,
        settings: Settings | None = None,
        ttfs_p99_latency: float | None = DEEPGRAM_SAGEMAKER_TTFS_P99,
        **kwargs,
    ):
        """Initialize the Deepgram SageMaker STT service.

        Args:
            endpoint_name: Name of the SageMaker endpoint with Deepgram model
                deployed (e.g., "my-deepgram-nova-3-endpoint").
            region: AWS region where the endpoint is deployed (e.g., "us-east-2").
            encoding: Audio encoding format. Defaults to "linear16".
            channels: Number of audio channels. Defaults to 1.
            multichannel: Transcribe each audio channel independently.
                Defaults to False.
            sample_rate: Audio sample rate in Hz. If None, uses the pipeline
                sample rate.
            mip_opt_out: Opt out of Deepgram model improvement program.
            live_options: Legacy configuration options.

                .. deprecated:: 0.0.105
                    Use ``settings=DeepgramSageMakerSTTService.Settings(...)``
                    for runtime-updatable fields and direct init parameters for
                    connection-level config.

            settings: Runtime-updatable settings, merged over the built-in
                defaults. When ``settings`` is provided, ``live_options`` is
                not applied at all (neither its runtime fields nor its
                connection-level fields); only the moved-to-settings warning
                is emitted for it.
            ttfs_p99_latency: P99 latency from speech end to final transcript in
                seconds. Override for your deployment. See
                https://github.com/pipecat-ai/stt-benchmark
            **kwargs: Additional arguments passed to the parent STTService.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = self.Settings(
            model="nova-3",
            language=Language.EN,
            detect_entities=False,
            diarize=False,
            dictation=False,
            endpointing=None,
            interim_results=True,
            keyterm=None,
            keywords=None,
            numerals=False,
            profanity_filter=True,
            punctuate=True,
            redact=None,
            replace=None,
            search=None,
            smart_format=False,
            utterance_end_ms=None,
            vad_events=False,
        )

        # 2. Apply live_options overrides — only if settings not provided
        if live_options is not None:
            self._warn_init_param_moved_to_settings("live_options")
            # Explicit None check (rather than truthiness) so this mirrors the
            # `settings is not None` test used in step 3.
            if settings is None:
                # Extract init-only fields from live_options
                if live_options.sample_rate is not None and sample_rate is None:
                    sample_rate = live_options.sample_rate
                if live_options.encoding is not None:
                    encoding = live_options.encoding
                if live_options.channels is not None:
                    channels = live_options.channels
                if live_options.multichannel is not None:
                    multichannel = live_options.multichannel
                if live_options.mip_opt_out is not None:
                    mip_opt_out = live_options.mip_opt_out

                # Build settings delta from remaining fields
                init_only = {
                    "sample_rate",
                    "encoding",
                    "channels",
                    "multichannel",
                    "mip_opt_out",
                }
                lo_dict = {k: v for k, v in live_options.to_dict().items() if k not in init_only}
                delta = self.Settings.from_mapping(lo_dict)
                default_settings.apply_update(delta)

        # 3. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        # Sync extra to top-level fields so self._settings is unambiguous
        default_settings._sync_extra_to_fields()

        super().__init__(
            sample_rate=sample_rate,
            ttfs_p99_latency=ttfs_p99_latency,
            settings=default_settings,
            **kwargs,
        )

        self._endpoint_name = endpoint_name
        self._region = region

        # Init-only connection config (not runtime-updatable).
        self._encoding = encoding
        self._channels = channels
        self._multichannel = multichannel
        self._mip_opt_out = mip_opt_out

        self._client: SageMakerBidiClient | None = None
        self._response_task: asyncio.Task | None = None
        self._keepalive_task: asyncio.Task | None = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Deepgram SageMaker service supports metrics generation.
        """
        return True

    async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
        """Apply a settings delta and warn about unhandled changes.

        Args:
            delta: The settings changes to apply.

        Returns:
            Whatever the parent ``_update_settings`` reports as changed. When
            that result is falsy it is returned immediately and no warning is
            issued.
        """
        changed = await super()._update_settings(delta)

        if not changed:
            return changed

        # Sync extra to fields after the update so self._settings stays unambiguous
        if isinstance(self._settings, self.Settings):
            self._settings._sync_extra_to_fields()

        # TODO: someday we could reconnect here to apply updated settings.
        # Code might look something like the below:
        # await self._disconnect()
        # await self._connect()

        # The open session is not reconnected here, so every changed setting
        # is reported as unhandled.
        self._warn_unhandled_updated_settings(changed)

        return changed

    async def start(self, frame: StartFrame):
        """Start the Deepgram SageMaker STT service.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the Deepgram SageMaker STT service.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the Deepgram SageMaker STT service.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
        """Send audio data to Deepgram for transcription.

        Audio is silently dropped when there is no active client session.

        Args:
            audio: Raw audio bytes to transcribe.

        Yields:
            Frame: An ErrorFrame if sending the chunk fails, then None
            (transcription results come via BiDi stream callbacks).
        """
        if self._client and self._client.is_active:
            try:
                await self._client.send_audio_chunk(audio)
            except Exception as e:
                yield ErrorFrame(error=f"Unknown error occurred: {e}")
        yield None

    @staticmethod
    def _query_value(value: Any) -> str | list[str]:
        """Convert a setting value into its query-string representation.

        Booleans become lowercase ``true``/``false``. Lists and tuples become a
        list of strings (``None`` items dropped) so the caller can emit one
        ``key=value`` pair per item. Everything else is passed through ``str``.
        """

        def scalar(item: Any) -> str:
            return str(item).lower() if isinstance(item, bool) else str(item)

        if isinstance(value, (list, tuple)):
            return [scalar(item) for item in value if item is not None]
        return scalar(value)

    def _build_query_string(self) -> str:
        """Build query string from current settings and init-only connection config.

        Later sources override earlier ones for the same key, in this order:
        declared settings fields, ``model``/``language``, init-only connection
        config, then anything left in ``settings.extra``.

        Returns:
            A percent-encoded query string. List-valued settings are emitted as
            repeated ``key=value`` pairs rather than as a Python list repr.
        """
        params: dict[str, str | list[str]] = {}
        s = self._settings

        # Declared Deepgram-specific fields from settings
        for f in fields(s):
            if f.name in ("model", "language", "extra") or f.name.startswith("_"):
                continue
            value = getattr(s, f.name)
            if not is_given(value) or value is None:
                continue
            params[f.name] = self._query_value(value)

        # model and language
        if is_given(s.model) and s.model is not None:
            params["model"] = str(s.model)
        if is_given(s.language) and s.language is not None:
            params["language"] = str(s.language)

        # Init-only connection config
        params["encoding"] = self._encoding
        params["channels"] = str(self._channels)
        params["multichannel"] = str(self._multichannel).lower()
        params["sample_rate"] = str(self.sample_rate)
        if self._mip_opt_out is not None:
            params["mip_opt_out"] = str(self._mip_opt_out).lower()

        # Any remaining values in extra
        if s.extra:
            for key, value in s.extra.items():
                if value is not None:
                    params[key] = self._query_value(value)

        # doseq=True expands list values into repeated keys; quote (not
        # quote_plus) encodes spaces as %20 instead of "+".
        return urlencode(params, doseq=True, quote_via=quote)

    async def _connect(self):
        """Connect to the SageMaker endpoint and start the BiDi session.

        Builds the Deepgram query string from settings, creates the BiDi client,
        starts the streaming session, and launches background tasks for processing
        responses and sending KeepAlive messages.

        Failures are not raised: they are pushed as an error and reported through
        the ``on_connection_error`` event handler.
        """
        logger.debug("Connecting to Deepgram on SageMaker...")

        query_string = self._build_query_string()

        # Create BiDi client
        self._client = SageMakerBidiClient(
            endpoint_name=self._endpoint_name,
            region=self._region,
            model_invocation_path="v1/listen",
            model_query_string=query_string,
        )

        try:
            # Start the session
            await self._client.start_session()

            # Start processing responses in the background
            self._response_task = self.create_task(self._process_responses())

            # Start keepalive task to maintain connection
            self._keepalive_task = self.create_task(self._send_keepalive())

            logger.debug("Connected to Deepgram on SageMaker")
            await self._call_event_handler("on_connected")

        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            await self._call_event_handler("on_connection_error", str(e))

    async def _disconnect(self):
        """Disconnect from the SageMaker endpoint.

        Sends a CloseStream message to Deepgram, cancels background tasks
        (KeepAlive and response processing), and closes the BiDi session. Safe
        to call multiple times.

        Background tasks are cancelled even when the session is already
        inactive, so a connection that dropped on its own does not leave them
        running. The ``on_disconnected`` handler fires only when an active
        session was actually closed.
        """
        client = self._client
        # Captured once up front: sending CloseStream may flip is_active before
        # we reach close_session(), which must still run.
        was_active = bool(client and client.is_active)

        if was_active:
            logger.debug("Disconnecting from Deepgram on SageMaker...")

            # Send CloseStream message to Deepgram
            try:
                await client.send_json({"type": "CloseStream"})
            except Exception as e:
                logger.warning(f"Failed to send CloseStream message: {e}")

        # Cancel keepalive task
        if self._keepalive_task and not self._keepalive_task.done():
            await self.cancel_task(self._keepalive_task)
        self._keepalive_task = None

        # Cancel response processing task
        if self._response_task and not self._response_task.done():
            await self.cancel_task(self._response_task)
        self._response_task = None

        if was_active:
            # Close the BiDi session
            await client.close_session()

            logger.debug("Disconnected from Deepgram on SageMaker")
            await self._call_event_handler("on_disconnected")

    async def _send_keepalive(self):
        """Send periodic KeepAlive messages to maintain the connection.

        Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
        connection is active. This prevents the connection from timing out during
        periods of silence.
        """
        while self._client and self._client.is_active:
            await asyncio.sleep(5)
            # Re-check after sleeping: the session may have closed meanwhile.
            if self._client and self._client.is_active:
                try:
                    await self._client.send_json({"type": "KeepAlive"})
                except Exception as e:
                    logger.warning(f"Failed to send KeepAlive: {e}")

    async def _process_responses(self):
        """Process streaming responses from Deepgram on SageMaker.

        Continuously receives responses from the BiDi stream, decodes the payload,
        parses JSON responses from Deepgram, and processes transcription results.
        Runs as a background task until the connection is closed or cancelled.

        A payload that cannot be decoded or parsed is logged and skipped rather
        than ending the loop. Any other exception is pushed as an error and ends
        the loop.

        Raises:
            asyncio.CancelledError: Re-raised after logging so that cancellation
                propagates to whoever cancelled the task.
        """
        try:
            while self._client and self._client.is_active:
                result = await self._client.receive_response()

                if result is None:
                    break

                # Check if this is a PayloadPart with bytes
                payload = getattr(getattr(result, "value", None), "bytes_", None)
                if not payload:
                    continue

                try:
                    response_data = payload.decode("utf-8")
                except UnicodeDecodeError:
                    logger.warning(f"Undecodable response payload ({len(payload)} bytes)")
                    continue

                try:
                    # Parse JSON response from Deepgram
                    parsed = json.loads(response_data)
                except json.JSONDecodeError:
                    logger.warning(f"Non-JSON response: {response_data}")
                    continue

                # Extract and process transcript if available
                if isinstance(parsed, dict) and "channel" in parsed:
                    await self._handle_transcript_response(parsed)

        except asyncio.CancelledError:
            logger.debug("Response processor cancelled")
            raise
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            logger.debug("Response processor stopped")

    async def _handle_transcript_response(self, parsed: dict):
        """Handle a transcript response from Deepgram.

        Extracts the transcript text, determines if it's final or interim,
        extracts language information, and pushes the appropriate frame
        (TranscriptionFrame or InterimTranscriptionFrame) downstream.

        Messages without a usable transcript are ignored, including messages
        whose ``channel`` is not an object.

        Args:
            parsed: The parsed JSON response from Deepgram containing channel,
                alternatives, transcript, and metadata.
        """
        # Only transcript results carry `channel` as an object. Per Deepgram's
        # streaming docs, UtteranceEnd and SpeechStarted messages carry it as
        # a list of channel indices, which has no `.get`.
        channel = parsed.get("channel")
        if not isinstance(channel, dict):
            return

        alternatives = channel.get("alternatives") or []

        if not alternatives or not alternatives[0].get("transcript"):
            return

        transcript = alternatives[0]["transcript"]
        if not transcript.strip():
            return

        is_final = parsed.get("is_final", False)

        # Extract language if available
        language = None
        languages = alternatives[0].get("languages")
        if languages:
            try:
                language = Language(languages[0])
            except ValueError:
                # A code missing from the Language enum must not abort the
                # response loop; deliver the transcript without a language.
                logger.warning(f"Unrecognized language code from Deepgram: {languages[0]}")

        if is_final:
            # Check if this response is from a finalize() call.
            # Only mark as finalized when both we requested it AND Deepgram confirms it.
            from_finalize = parsed.get("from_finalize", False)
            if from_finalize:
                self.confirm_finalize()

            await self.push_frame(
                TranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                    result=parsed,
                )
            )
            await self._handle_transcription(transcript, is_final, language)
            await self.stop_processing_metrics()
        else:
            # Interim transcription
            await self.push_frame(
                InterimTranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                    result=parsed,
                )
            )

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Language | None = None
    ):
        """Handle a transcription result with tracing.

        This method is decorated with @traced_stt for observability and tracing
        integration. The body is intentionally empty; it exists only as a hook
        for the decorator.

        Args:
            transcript: The transcribed text.
            is_final: Whether this is a final transcription result.
            language: The detected language of the transcription, if available.
        """
        pass

    async def _start_metrics(self):
        """Start processing metrics collection."""
        await self.start_processing_metrics()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process frames with Deepgram SageMaker-specific handling.

        Starts processing metrics when the user starts speaking, and asks
        Deepgram to finalize the current utterance when the user stops.

        Args:
            frame: The frame to process.
            direction: The direction of frame processing.
        """
        await super().process_frame(frame, direction)

        # Start metrics when user starts speaking (if VAD is not provided by Deepgram)
        if isinstance(frame, VADUserStartedSpeakingFrame):
            await self._start_metrics()
        elif isinstance(frame, VADUserStoppedSpeakingFrame):
            # https://developers.deepgram.com/docs/finalize
            # Mark that we're awaiting a from_finalize response
            self.request_finalize()
            if self._client and self._client.is_active:
                try:
                    await self._client.send_json({"type": "Finalize"})
                except Exception as e:
                    logger.warning(f"Error sending Finalize message: {e}")
            logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")