#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""AWS Transcribe Speech-to-Text service implementation.
This module provides a WebSocket-based connection to AWS Transcribe for real-time
speech-to-text transcription with support for multiple languages and audio formats.
"""
import json
import os
import random
import string
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
from pipecat.services.settings import STTSettings, assert_given
from pipecat.services.stt_latency import AWS_TRANSCRIBE_TTFS_P99
from pipecat.services.stt_service import WebsocketSTTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
raise Exception(f"Missing module: {e}")
[docs]
@dataclass
class AWSTranscribeSTTSettings(STTSettings):
    """Settings for AWSTranscribeSTTService.

    Declares no fields of its own; every field comes from ``STTSettings``.
    """
[docs]
class AWSTranscribeSTTService(WebsocketSTTService):
    """Speech-to-text service that streams audio to AWS Transcribe over a WebSocket.

    Audio chunks are wrapped in AWS event-stream messages and sent over the
    websocket; transcription events received on the same connection are
    pushed downstream as interim or final transcription frames.
    """

    # Settings type used by this service; exposed so callers can write
    # ``AWSTranscribeSTTService.Settings(...)``.
    Settings = AWSTranscribeSTTSettings
    _settings: Settings
[docs]
def __init__(
    self,
    *,
    api_key: str | None = None,
    aws_access_key_id: str | None = None,
    aws_session_token: str | None = None,
    region: str | None = None,
    sample_rate: int | None = None,
    language: Language | None = None,
    settings: Settings | None = None,
    ttfs_p99_latency: float | None = AWS_TRANSCRIBE_TTFS_P99,
    **kwargs,
):
    """Initialize the AWS Transcribe STT service.

    Args:
        api_key: AWS secret access key. If not provided, the
            AWS_SECRET_ACCESS_KEY environment variable is used.
        aws_access_key_id: AWS access key ID. If not provided, the
            AWS_ACCESS_KEY_ID environment variable is used.
        aws_session_token: AWS session token for temporary credentials. If
            not provided, the AWS_SESSION_TOKEN environment variable is used.
        region: AWS region for the service. If not provided, the AWS_REGION
            environment variable is used, falling back to ``us-east-1``.
        sample_rate: Audio sample rate in Hz. If None, uses the pipeline
            sample rate. AWS Transcribe only supports 8000 or 16000 Hz;
            other values are clamped to 16000 Hz at connect time.
        language: Language for transcription.

            .. deprecated:: 0.0.105
                Use ``settings=AWSTranscribeSTTService.Settings(language=...)`` instead.

        settings: Runtime-updatable settings. When provided alongside
            deprecated parameters, ``settings`` values take precedence.
        ttfs_p99_latency: P99 latency from speech end to final transcript in
            seconds. Override for your deployment.
            See https://github.com/pipecat-ai/stt-benchmark
        **kwargs: Additional arguments passed to parent STTService class.
    """
    # Layer 1: hardcoded defaults.
    resolved_settings = self.Settings(model=None, language=Language.EN)

    # Layer 2: the deprecated ``language`` init argument overrides the default.
    if language is not None:
        self._warn_init_param_moved_to_settings("language", "language")
        resolved_settings.language = language

    # Layer 3: the ``settings`` delta is applied last, so it always wins.
    if settings is not None:
        resolved_settings.apply_update(settings)

    super().__init__(
        settings=resolved_settings,
        sample_rate=sample_rate,
        ttfs_p99_latency=ttfs_p99_latency,
        **kwargs,
    )

    # Connection configuration fixed at construction time (not runtime-updatable).
    self._media_encoding = "linear16"
    self._number_of_channels = 1
    self._show_speaker_label = False
    self._enable_channel_identification = False

    # Explicit arguments take priority; falsy values fall back to the environment.
    access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
    secret_access_key = api_key or os.getenv("AWS_SECRET_ACCESS_KEY")
    session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
    aws_region = region or os.getenv("AWS_REGION", "us-east-1")
    self._credentials = {
        "aws_access_key_id": access_key_id,
        "aws_secret_access_key": secret_access_key,
        "aws_session_token": session_token,
        "region": aws_region,
    }

    # Populated by _connect() once the websocket is up.
    self._receive_task = None
[docs]
def can_generate_metrics(self) -> bool:
    """Report whether this service can generate processing metrics.

    Returns:
        Always True for AWS Transcribe STT.
    """
    return True
[docs]
def get_service_encoding(self, encoding: str) -> str:
    """Translate an internal encoding name into the one AWS Transcribe expects.

    Args:
        encoding: Internal encoding format string.

    Returns:
        The AWS Transcribe name for ``encoding`` when a translation exists,
        otherwise ``encoding`` unchanged.
    """
    # AWS calls 16-bit linear PCM simply "pcm".
    aws_names = {"linear16": "pcm"}
    try:
        return aws_names[encoding]
    except KeyError:
        return encoding
async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
    """Apply a settings delta, reconnecting when something changed.

    Args:
        delta: Settings update forwarded to the parent implementation.

    Returns:
        Whatever the parent implementation reported as changed.
    """
    changed = await super()._update_settings(delta)
    if not changed:
        return changed
    if not self._websocket:
        # No connection yet; the change is picked up by the next connect.
        return changed

    # Cycle the connection so it is re-established with the updated settings.
    await self._disconnect()
    await self._connect()
    return changed
[docs]
async def start(self, frame: StartFrame):
    """Start the service and open the AWS Transcribe connection.

    Args:
        frame: Start frame signaling service initialization.
    """
    # The parent start runs first, then the connection is opened.
    await super().start(frame)
    await self._connect()
[docs]
async def stop(self, frame: EndFrame):
    """Stop the service, then disconnect from AWS Transcribe.

    Args:
        frame: End frame signaling service shutdown.
    """
    # The parent stop runs first, then the connection is torn down.
    await super().stop(frame)
    await self._disconnect()
[docs]
async def cancel(self, frame: CancelFrame):
    """Cancel the service, then disconnect from AWS Transcribe.

    Args:
        frame: Cancel frame signaling service cancellation.
    """
    # The parent cancel runs first, then the connection is torn down.
    await super().cancel(frame)
    await self._disconnect()
[docs]
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame | None, None]:
    """Send one chunk of audio to AWS Transcribe.

    Nothing is sent when the websocket is missing or not open; the chunk is
    dropped. Transcription results are not produced here.

    Args:
        audio: Raw audio bytes to transcribe.

    Yields:
        ErrorFrame: If sending the chunk raises.
        None: Always yielded last.
    """
    socket = self._websocket
    if not socket or socket.state is not State.OPEN:
        yield None
        return

    try:
        # Wrap the raw audio in an AWS event-stream message before sending.
        await socket.send(build_event_message(audio))
        # Start metrics once the chunk has been sent.
        await self.start_processing_metrics()
    except Exception as e:
        yield ErrorFrame(error=f"Error sending audio: {e}")

    yield None
async def _connect(self):
    """Connect to AWS Transcribe and make sure a receive task is running.

    Opens the websocket and, if one is not already running, starts the task
    that reads messages from it.
    """
    await super()._connect()
    await self._connect_websocket()

    if not self._websocket or self._receive_task:
        # Either there is nothing to read from or a reader already exists.
        return

    self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
async def _disconnect(self):
    """Disconnect from the AWS Transcribe service.

    Cancels the receive task, signals end-of-stream on the websocket if it
    is still open, and then closes the websocket. A failure while sending
    the end-of-stream signal is reported via ``push_error`` and does not
    prevent the websocket from being closed.
    """
    await super()._disconnect()

    if self._receive_task:
        await self.cancel_task(self._receive_task)
        self._receive_task = None

    # Signal end-of-stream before closing.
    if self._websocket and self._websocket.state is State.OPEN:
        try:
            # AWS Transcribe streaming ends the audio stream on an *empty
            # audio chunk in an event-stream-encoded message*. A JSON text
            # frame is not part of that protocol, so send an empty audio
            # event using the same encoder as regular audio chunks.
            await self._websocket.send(build_event_message(b""))
        except Exception as e:
            await self.push_error(error_msg=f"Error sending end-stream: {e}", exception=e)

    await self._disconnect_websocket()
async def _connect_websocket(self):
    """Open the websocket connection to AWS Transcribe.

    Does nothing when a websocket is already open. Otherwise builds a
    presigned URL from the stored credentials and connection settings,
    connects, and fires the ``on_connected`` event handler.

    Raises:
        Exception: Any failure is reported via ``push_error`` and re-raised.
    """
    try:
        existing = self._websocket
        if existing and existing.state is State.OPEN:
            return

        logger.debug("Connecting to AWS Transcribe WebSocket")

        language_code = assert_given(self._settings.language)
        if not language_code:
            raise ValueError(f"Unsupported language: {language_code}")

        # AWS Transcribe only accepts 8000 or 16000 Hz; anything else is
        # replaced by 16000 Hz for the connection.
        requested_rate = self.sample_rate
        if requested_rate in (8000, 16000):
            connect_sample_rate = requested_rate
        else:
            logger.warning(
                f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. "
                f"Converting from {requested_rate} Hz to 16000 Hz."
            )
            connect_sample_rate = 16000

        # Random 20-character alphanumeric websocket key.
        key_alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits
        websocket_key = "".join(random.choices(key_alphabet, k=20))

        # Headers sent with the websocket handshake.
        handshake_headers = {
            "Origin": "https://localhost",
            "Sec-WebSocket-Key": websocket_key,
            "Sec-WebSocket-Version": "13",
            "Connection": "keep-alive",
        }

        stored = self._credentials
        signing_credentials = {
            "access_key": stored["aws_access_key_id"],
            "secret_key": stored["aws_secret_access_key"],
            "session_token": stored["aws_session_token"],
        }
        # Internal encoding name -> AWS encoding name.
        aws_media_encoding = self.get_service_encoding(self._media_encoding)

        presigned_url = get_presigned_url(
            region=stored["region"],
            credentials=signing_credentials,
            language_code=language_code,
            media_encoding=aws_media_encoding,
            sample_rate=connect_sample_rate,
            number_of_channels=self._number_of_channels,
            enable_partial_results_stabilization=True,
            partial_results_stability="high",
            show_speaker_label=self._show_speaker_label,
            enable_channel_identification=self._enable_channel_identification,
        )

        # Only the first 100 characters of the URL are logged.
        logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")

        self._websocket = await websocket_connect(
            presigned_url,
            additional_headers=handshake_headers,
            subprotocols=["mqtt"],
            ping_interval=None,
            ping_timeout=None,
            compression=None,
        )

        await self._call_event_handler("on_connected")
        logger.info(f"{self} Successfully connected to AWS Transcribe")
    except Exception as e:
        await self.push_error(
            error_msg=f"Unable to connect to AWS Transcribe: {e}", exception=e
        )
        raise
async def _disconnect_websocket(self):
    """Close the websocket connection to AWS Transcribe.

    A failure while closing is reported via ``push_error``. In every case
    the stored websocket is cleared and the ``on_disconnected`` event
    handler is called, even when no websocket was open.
    """
    try:
        socket = self._websocket
        if socket:
            logger.debug("Disconnecting from AWS Transcribe WebSocket")
            await socket.close()
    except Exception as e:
        await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
    finally:
        self._websocket = None
        await self._call_event_handler("on_disconnected")
[docs]
def language_to_service_language(self, language: Language) -> str | None:
    """Map an internal ``Language`` value to an AWS Transcribe language code.

    The table follows the AWS list of languages that support streaming:
    https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html

    Base languages without a region map to a default regional code (for
    example ``Language.EN`` -> ``en-US``). The final lookup is delegated to
    ``resolve_language`` with ``use_base_code=False``.

    Args:
        language: Internal language enumeration value.

    Returns:
        AWS Transcribe compatible language code, or None if unsupported.
    """
    aws_language_codes = {
        # --- Afrikaans ---
        Language.AF: "af-ZA",
        Language.AF_ZA: "af-ZA",
        # --- Arabic (bare "ar" defaults to Modern Standard Arabic) ---
        Language.AR: "ar-SA",
        Language.AR_AE: "ar-AE",  # Gulf Arabic
        Language.AR_SA: "ar-SA",  # Modern Standard Arabic
        # --- Basque ---
        Language.EU: "eu-ES",
        Language.EU_ES: "eu-ES",
        # --- Catalan ---
        Language.CA: "ca-ES",
        Language.CA_ES: "ca-ES",
        # --- Chinese (bare "zh" defaults to Simplified) ---
        Language.ZH: "zh-CN",
        Language.ZH_CN: "zh-CN",  # Simplified
        Language.ZH_TW: "zh-TW",  # Traditional
        Language.ZH_HK: "zh-HK",  # Cantonese (also yue-HK)
        Language.YUE: "zh-HK",  # Cantonese fallback
        # --- Croatian ---
        Language.HR: "hr-HR",
        Language.HR_HR: "hr-HR",
        # --- Czech ---
        Language.CS: "cs-CZ",
        Language.CS_CZ: "cs-CZ",
        # --- Danish ---
        Language.DA: "da-DK",
        Language.DA_DK: "da-DK",
        # --- Dutch ---
        Language.NL: "nl-NL",
        Language.NL_NL: "nl-NL",
        # --- English (bare "en" defaults to US) ---
        # Scottish (en-AB) and Welsh (en-WL) have no direct Language enum match.
        Language.EN: "en-US",
        Language.EN_AU: "en-AU",  # Australian
        Language.EN_GB: "en-GB",  # British
        Language.EN_IN: "en-IN",  # Indian
        Language.EN_IE: "en-IE",  # Irish
        Language.EN_NZ: "en-NZ",  # New Zealand
        Language.EN_ZA: "en-ZA",  # South African
        Language.EN_US: "en-US",  # US
        # --- Persian / Farsi ---
        Language.FA: "fa-IR",
        Language.FA_IR: "fa-IR",
        # --- Finnish ---
        Language.FI: "fi-FI",
        Language.FI_FI: "fi-FI",
        # --- French (bare "fr" defaults to France) ---
        Language.FR: "fr-FR",
        Language.FR_FR: "fr-FR",
        Language.FR_CA: "fr-CA",  # Canadian
        # --- Galician ---
        Language.GL: "gl-ES",
        Language.GL_ES: "gl-ES",
        # --- Georgian ---
        Language.KA: "ka-GE",
        Language.KA_GE: "ka-GE",
        # --- German (bare "de" defaults to Germany) ---
        Language.DE: "de-DE",
        Language.DE_DE: "de-DE",
        Language.DE_CH: "de-CH",  # Swiss
        # --- Greek ---
        Language.EL: "el-GR",
        Language.EL_GR: "el-GR",
        # --- Hebrew ---
        Language.HE: "he-IL",
        Language.HE_IL: "he-IL",
        # --- Hindi ---
        Language.HI: "hi-IN",
        Language.HI_IN: "hi-IN",
        # --- Indonesian ---
        Language.ID: "id-ID",
        Language.ID_ID: "id-ID",
        # --- Italian ---
        Language.IT: "it-IT",
        Language.IT_IT: "it-IT",
        # --- Japanese ---
        Language.JA: "ja-JP",
        Language.JA_JP: "ja-JP",
        # --- Korean ---
        Language.KO: "ko-KR",
        Language.KO_KR: "ko-KR",
        # --- Latvian ---
        Language.LV: "lv-LV",
        Language.LV_LV: "lv-LV",
        # --- Malay ---
        Language.MS: "ms-MY",
        Language.MS_MY: "ms-MY",
        # --- Norwegian (Bokmål and bare "no" share one code) ---
        Language.NB: "no-NO",
        Language.NB_NO: "no-NO",
        Language.NO: "no-NO",
        # --- Polish ---
        Language.PL: "pl-PL",
        Language.PL_PL: "pl-PL",
        # --- Portuguese (bare "pt" defaults to Portugal) ---
        Language.PT: "pt-PT",
        Language.PT_PT: "pt-PT",
        Language.PT_BR: "pt-BR",  # Brazilian
        # --- Romanian ---
        Language.RO: "ro-RO",
        Language.RO_RO: "ro-RO",
        # --- Russian ---
        Language.RU: "ru-RU",
        Language.RU_RU: "ru-RU",
        # --- Serbian ---
        Language.SR: "sr-RS",
        Language.SR_RS: "sr-RS",
        # --- Slovak ---
        Language.SK: "sk-SK",
        Language.SK_SK: "sk-SK",
        # --- Somali ---
        Language.SO: "so-SO",
        Language.SO_SO: "so-SO",
        # --- Spanish (bare "es" defaults to Spain) ---
        Language.ES: "es-ES",
        Language.ES_ES: "es-ES",
        Language.ES_US: "es-US",  # US Spanish
        # --- Swedish ---
        Language.SV: "sv-SE",
        Language.SV_SE: "sv-SE",
        # --- Tagalog / Filipino (Filipino maps to Tagalog) ---
        Language.TL: "tl-PH",
        Language.FIL: "tl-PH",
        Language.FIL_PH: "tl-PH",
        # --- Thai ---
        Language.TH: "th-TH",
        Language.TH_TH: "th-TH",
        # --- Ukrainian ---
        Language.UK: "uk-UA",
        Language.UK_UA: "uk-UA",
        # --- Vietnamese ---
        Language.VI: "vi-VN",
        Language.VI_VN: "vi-VN",
        # --- Zulu ---
        Language.ZU: "zu-ZA",
        Language.ZU_ZA: "zu-ZA",
    }
    return resolve_language(language, aws_language_codes, use_base_code=False)
@traced_stt
async def _handle_transcription(
    self, transcript: str, is_final: bool, language: str | None = None
):
    """No-op hook called for each final transcript.

    The body does nothing; any effect comes from the ``traced_stt``
    decorator (presumably tracing of the transcription — confirm against
    the decorator's implementation).

    Args:
        transcript: Transcribed text.
        is_final: Whether the transcript is final.
        language: Language associated with the transcript, if any.
    """
def _get_websocket(self):
    """Return the current websocket connection.

    Returns:
        The websocket connection.

    Raises:
        Exception: If no websocket is connected.
    """
    socket = self._websocket
    if not socket:
        raise Exception("Websocket not connected")
    return socket
async def _receive_messages(self):
    """Read messages from the websocket and turn them into frames.

    For every message the event is decoded and handled by message type:

    - ``event``: the first alternative of the first result is pushed as a
      ``TranscriptionFrame`` (final) or ``InterimTranscriptionFrame``
      (partial). Empty transcripts are skipped.
    - ``exception``: the payload's ``Message`` is reported via ``push_error``.
    - anything else is only logged at debug level.

    An error while processing one message is logged as a warning and does
    not stop the loop.
    """
    async for response in self._get_websocket():
        try:
            headers, payload = decode_event(response)
            message_type = headers.get(":message-type")

            if message_type == "exception":
                error_msg = payload.get("Message", "Unknown error")
                await self.push_error(error_msg=f"AWS Transcribe error: {error_msg}")
                continue

            if message_type != "event":
                logger.debug(f"{self} Other message type received: {headers}")
                logger.debug(f"{self} Payload: {payload}")
                continue

            # Transcription event: only the first result / alternative is used.
            results = payload.get("Transcript", {}).get("Results", [])
            if not results:
                continue
            result = results[0]

            alternatives = result.get("Alternatives", [])
            if not alternatives:
                continue

            transcript = alternatives[0].get("Transcript", "")
            # A result with no "IsPartial" key is treated as partial.
            is_final = not result.get("IsPartial", True)
            if not transcript:
                continue

            language = assert_given(self._settings.language)

            if not is_final:
                interim_frame = InterimTranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                    result=result,
                )
                await self.push_frame(interim_frame)
                continue

            final_frame = TranscriptionFrame(
                transcript,
                self._user_id,
                time_now_iso8601(),
                language,
                result=result,
            )
            await self.push_frame(final_frame)
            await self._handle_transcription(transcript, is_final, language)
            # Processing metrics are stopped only on final transcripts.
            await self.stop_processing_metrics()
        except Exception as e:
            logger.warning(f"Error processing message: {e}")