Source code for pipecat.services.resembleai.tts

#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Resemble AI text-to-speech service implementations."""

import base64
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass

from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    StartFrame,
    TTSAudioRawFrame,
    TTSStartedFrame,
    TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings
from pipecat.services.tts_service import WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts

try:
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Resemble AI, you need to `pip install pipecat-ai[resembleai]`.")
    raise Exception(f"Missing module: {e}")


[docs] @dataclass class ResembleAITTSSettings(TTSSettings): """Settings for ResembleAITTSService.""" pass
[docs] class ResembleAITTSService(WebsocketTTSService): """Resemble AI TTS service with WebSocket streaming and word timestamps. Provides text-to-speech using Resemble AI's streaming WebSocket API. Supports word-level timestamps and audio context management for handling multiple simultaneous synthesis requests with proper interruption support. """ Settings = ResembleAITTSSettings _settings: Settings
[docs] def __init__( self, *, api_key: str, voice_id: str | None = None, url: str = "wss://websocket.cluster.resemble.ai/stream", precision: str | None = "PCM_16", output_format: str | None = "wav", sample_rate: int | None = 22050, settings: Settings | None = None, **kwargs, ): """Initialize the Resemble AI TTS service. Args: api_key: Resemble AI API key for authentication. voice_id: Voice UUID to use for synthesis. .. deprecated:: 0.0.105 Use ``settings=ResembleAITTSService.Settings(voice=...)`` instead. url: WebSocket URL for Resemble AI TTS API. precision: PCM bit depth (PCM_32, PCM_24, PCM_16, or MULAW). output_format: Audio format (wav or mp3). sample_rate: Audio sample rate (8000, 16000, 22050, 32000, or 44100). Defaults to 22050. settings: Runtime-updatable settings. When provided alongside deprecated parameters, ``settings`` values take precedence. **kwargs: Additional arguments passed to the parent service. """ # 1. Initialize default_settings with hardcoded defaults default_settings = self.Settings( model=None, voice=None, language=None, ) # 2. Apply direct init arg overrides (deprecated) if voice_id is not None: self._warn_init_param_moved_to_settings("voice_id", "voice") default_settings.voice = voice_id # 3. (No step 3, as there's no params object to apply) # 4. Apply settings delta (canonical API, always wins) if settings is not None: default_settings.apply_update(settings) super().__init__( sample_rate=sample_rate, reuse_context_id_within_turn=False, settings=default_settings, **kwargs, ) self._api_key = api_key self._url = url # Init-only audio format config (not runtime-updatable). self._precision = precision or "PCM_16" self._output_format = output_format or "wav" self._resemble_sample_rate = 0 # Set in start() self._websocket = None self._request_id_counter = 0 self._receive_task = None # Map request_id to context_id for tracking TTS requests self._request_id_to_context: dict[int, str] = {} # Per-request audio buffers to handle concurrent TTS requests # ResembleAI may send odd-length data even for PCM_16, so buffering helps us # create properly aligned frames while maintaining smooth audio output self._audio_buffers: dict[str, bytearray] = {} self._buffer_threshold_bytes = 2208 # Jitter buffer: accumulate audio before starting playback to absorb network latency # ResembleAI sends audio in bursts with 300-450ms gaps between them # We need to buffer enough to cover these gaps before starting playback self._jitter_buffer_bytes = 44100 # ~1000ms at 22050Hz to handle 400ms+ network gaps self._playback_started: dict[str, bool] = {} # Track if we've started playback per request
[docs] def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. Returns: True, as Resemble AI service supports metrics generation. """ return True
def _build_msg(self, text: str = "") -> str: """Build a JSON message for the Resemble AI WebSocket API. Args: text: The text or SSML to synthesize. Returns: JSON string containing the request payload. """ msg = { "voice_uuid": self._settings.voice, "data": text, "binary_response": False, # Use JSON frames to get timestamps "request_id": self._request_id_counter, # ResembleAI only accepts number "output_format": self._output_format, "sample_rate": self._resemble_sample_rate, "precision": self._precision, "no_audio_header": True, } self._request_id_counter += 1 return json.dumps(msg)
[docs] async def start(self, frame: StartFrame): """Start the Resemble AI TTS service. Args: frame: The start frame containing initialization parameters. """ await super().start(frame) self._resemble_sample_rate = self.sample_rate await self._connect()
[docs] async def stop(self, frame: EndFrame): """Stop the Resemble AI TTS service. Args: frame: The end frame. """ await super().stop(frame) await self._disconnect()
[docs] async def cancel(self, frame: CancelFrame): """Cancel the Resemble AI TTS service. Args: frame: The cancel frame. """ await super().cancel(frame) await self._disconnect()
async def _connect(self): """Connect to the Resemble AI WebSocket.""" await self._connect_websocket() if self._websocket and not self._receive_task: self._receive_task = self.create_task(self._receive_task_handler(self._report_error)) async def _disconnect(self): """Disconnect from the Resemble AI WebSocket.""" if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None await self._disconnect_websocket() async def _connect_websocket(self): """Establish WebSocket connection to Resemble AI.""" try: if self._websocket and self._websocket.state is State.OPEN: return logger.debug("Connecting to Resemble AI TTS") headers = {"Authorization": f"Bearer {self._api_key}"} self._websocket = await websocket_connect(self._url, additional_headers=headers) await self._call_event_handler("on_connected") except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) self._websocket = None await self._call_event_handler("on_connection_error", f"{e}") async def _disconnect_websocket(self): """Close WebSocket connection to Resemble AI.""" try: await self.stop_all_metrics() if self._websocket: logger.debug("Disconnecting from Resemble AI") # ResembleAI doesn't send disconnect acknowledgement, set close_timeout to 0 self._websocket.close_timeout = 0 await self._websocket.close() except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: self._websocket = None self._audio_buffers.clear() self._playback_started.clear() self._request_id_to_context.clear() await self._call_event_handler("on_disconnected") def _get_websocket(self): """Get the current WebSocket connection. Returns: The active WebSocket connection. Raises: Exception: If websocket is not connected. """ if self._websocket: return self._websocket raise Exception("Websocket not connected")
[docs] async def on_audio_context_interrupted(self, context_id: str): """Stop metrics when the bot is interrupted.""" await self.stop_all_metrics() await super().on_audio_context_interrupted(context_id)
[docs] async def on_audio_context_completed(self, context_id: str): """Stop metrics after the Resemble AI context finishes playing. No close message is needed: Resemble AI signals completion with an ``audio_end`` message (handled in ``_process_messages``), after which the server-side context is already closed. """ await super().on_audio_context_completed(context_id)
[docs] async def flush_audio(self, context_id: str | None = None): """Flush any pending audio and finalize the current context.""" logger.trace(f"{self}: flushing audio")
# For Resemble AI, we just wait for the audio_end message # which is handled in _process_messages async def _process_messages(self): """Process incoming WebSocket messages from Resemble AI.""" async for message in self._get_websocket(): try: msg = json.loads(message) except json.JSONDecodeError: await self.push_error(error_msg=f"Received invalid JSON: {message}") continue if not msg: continue msg_type = msg.get("type") request_id = msg.get("request_id") # Convert request_id to string for audio context tracking context_id = self._request_id_to_context.get(request_id, str(request_id)) # Check if this message belongs to a valid audio context if not self.audio_context_available(context_id): continue if msg_type == "audio": # Decode base64 audio content audio_content = msg.get("audio_content", "") if not audio_content: continue audio_bytes = base64.b64decode(audio_content) if len(audio_bytes) == 0: continue # Get or create buffer for this request if context_id not in self._audio_buffers: self._audio_buffers[context_id] = bytearray() self._playback_started[context_id] = False buffer = self._audio_buffers[context_id] # Add to buffer buffer.extend(audio_bytes) # Wait for jitter buffer to fill before starting playback # This absorbs network latency gaps (ResembleAI sends in bursts) if not self._playback_started.get(context_id, False): if len(buffer) < self._jitter_buffer_bytes: continue self._playback_started[context_id] = True # Send complete (even-byte) chunks for PCM_16 alignment while len(buffer) >= self._buffer_threshold_bytes: chunk_size = self._buffer_threshold_bytes if chunk_size % 2 != 0: chunk_size -= 1 chunk_to_send = bytes(buffer[:chunk_size]) self._audio_buffers[context_id] = buffer[chunk_size:] buffer = self._audio_buffers[context_id] if len(chunk_to_send) == 0: continue frame = TTSAudioRawFrame( audio=chunk_to_send, sample_rate=self.sample_rate, num_channels=1, context_id=context_id, ) await self.append_to_audio_context(context_id, frame) # Process timestamps if available timestamps = msg.get("audio_timestamps", {}) if timestamps: graph_chars = timestamps.get("graph_chars", []) graph_times = timestamps.get("graph_times", []) # Convert graph_times (start, end pairs) to word timestamps word_times = [] for char, times in zip(graph_chars, graph_times): if times and len(times) >= 2: start_time = times[0] word_times.append((char, start_time)) if word_times: await self.add_word_timestamps(word_times, context_id) elif msg_type == "audio_end": await self.stop_ttfb_metrics() # Flush remaining buffer, ensuring even length for PCM_16 buffer = self._audio_buffers.get(context_id, bytearray()) if buffer: remaining = bytes(buffer) # PCM_16 requires even number of bytes if len(remaining) % 2 != 0: remaining = remaining[:-1] if remaining: frame = TTSAudioRawFrame( audio=remaining, sample_rate=self.sample_rate, num_channels=1, context_id=context_id, ) await self.append_to_audio_context(context_id, frame) # Clean up buffer and playback tracking for this request if context_id in self._audio_buffers: del self._audio_buffers[context_id] if context_id in self._playback_started: del self._playback_started[context_id] # Clean up request_id mapping if request_id in self._request_id_to_context: del self._request_id_to_context[request_id] await self.append_to_audio_context( context_id, TTSStoppedFrame(context_id=context_id) ) await self.remove_audio_context(context_id) elif msg_type == "error": error_name = msg.get("error_name", "Unknown") error_msg = msg.get("message", "Unknown error") status_code = msg.get("status_code", 0) await self.push_error( error_msg=f"Error: {error_name} (status {status_code}): {error_msg}" ) # Clean up buffer and playback tracking for this request if context_id in self._audio_buffers: del self._audio_buffers[context_id] if context_id in self._playback_started: del self._playback_started[context_id] await self.push_frame(TTSStoppedFrame(context_id=context_id)) await self.stop_all_metrics() await self.push_error(ErrorFrame(error=f"{self} error: {error_name} - {error_msg}")) # Check if this is an unrecoverable error (connection-level failure) if status_code in [401, 403]: # Close and reconnect for auth errors await self._disconnect_websocket() await self._connect_websocket() else: logger.warning(f"{self} unknown message type: {msg_type}") async def _receive_messages(self): """Main loop for receiving messages from Resemble AI.""" while True: try: await self._process_messages() except Exception as e: await self.push_error(error_msg=f"Error in receive loop: {e}", exception=e) # Try to reconnect logger.debug(f"{self} Resemble AI connection lost, reconnecting") await self._connect_websocket()
[docs] @traced_tts async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]: """Generate speech from text using Resemble AI's streaming API. Args: text: The text to synthesize into speech. context_id: Unique identifier for this TTS context. Yields: Frame: Audio frames containing the synthesized speech. """ logger.debug(f"{self}: Generating TTS [{text}]") try: if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() if not self.audio_context_available(context_id): await self.create_audio_context(context_id) await self.start_ttfb_metrics() yield TTSStartedFrame(context_id=context_id) # Map request_id to context_id for tracking self._request_id_to_context[self._request_id_counter] = context_id msg = self._build_msg(text=text) try: await self._get_websocket().send(msg) await self.start_tts_usage_metrics(text) except Exception as e: yield ErrorFrame(error=f"Unknown error occurred: {e}") yield TTSStoppedFrame(context_id=context_id) await self._disconnect() await self._connect() return yield None except Exception as e: yield ErrorFrame(error=f"Unknown error occurred: {e}")