Source code for pipecat.services.xai.tts

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""xAI text-to-speech service implementation.

Provides two TTS services against xAI's voice API:

- :class:`XAIHttpTTSService` uses the batch HTTP endpoint at
  ``https://api.x.ai/v1/tts``.
- :class:`XAITTSService` uses the streaming WebSocket endpoint at
  ``wss://api.x.ai/v1/tts``.

See https://docs.x.ai/developers/rest-api-reference/inference/voice.
"""

import base64
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlencode

import aiohttp
from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    StartFrame,
    TTSAudioRawFrame,
    TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts

# The WebSocket client is an optional dependency; the log message below names
# the extra that provides it. Fail at import time with an actionable hint
# rather than later, on the first connection attempt.
try:
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use XAITTSService, you need to `pip install pipecat-ai[xai]`.")
    # Chain the original error so the traceback still shows which module was missing.
    raise Exception(f"Missing module: {e}") from e


def language_to_xai_language(language: Language) -> str | None:
    """Translate a Pipecat ``Language`` into the code xAI expects.

    Args:
        language: The Language enum value to translate.

    Returns:
        The xAI language code produced by ``resolve_language`` for the table
        below, or None if the language is not supported.
    """
    xai_codes = {
        # Arabic: the generic code maps to the Egyptian variant.
        Language.AR: "ar-EG",
        Language.AR_EG: "ar-EG",
        Language.AR_SA: "ar-SA",
        Language.AR_AE: "ar-AE",
        Language.BN: "bn",
        Language.DE: "de",
        Language.EN: "en",
        # Spanish: the generic code maps to the Spain variant.
        Language.ES: "es-ES",
        Language.ES_ES: "es-ES",
        Language.ES_MX: "es-MX",
        Language.FR: "fr",
        Language.HI: "hi",
        Language.ID: "id",
        Language.IT: "it",
        Language.JA: "ja",
        Language.KO: "ko",
        # Portuguese: the generic code maps to the Portugal variant.
        Language.PT: "pt-PT",
        Language.PT_BR: "pt-BR",
        Language.PT_PT: "pt-PT",
        Language.RU: "ru",
        Language.TR: "tr",
        Language.VI: "vi",
        Language.ZH: "zh",
    }
    return resolve_language(language, xai_codes, use_base_code=True)
@dataclass
class XAITTSSettings(TTSSettings):
    """Runtime settings used by ``XAIHttpTTSService``.

    Declares no fields of its own; everything comes from ``TTSSettings``.
    """
class XAIHttpTTSService(TTSService):
    """xAI HTTP text-to-speech service.

    The service requests raw PCM audio so emitted ``TTSAudioRawFrame`` objects
    match Pipecat's downstream expectations without extra decoding.
    """

    Settings = XAITTSSettings
    _settings: Settings

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "https://api.x.ai/v1/tts",
        sample_rate: int | None = None,
        encoding: str | None = "pcm",
        aiohttp_session: aiohttp.ClientSession | None = None,
        settings: Settings | None = None,
        **kwargs,
    ):
        """Initialize the xAI TTS service.

        Args:
            api_key: xAI API key for authentication.
            base_url: xAI TTS endpoint. Defaults to ``https://api.x.ai/v1/tts``.
            sample_rate: Audio sample rate. If None, uses default.
            encoding: Output encoding format. Defaults to "pcm".
            aiohttp_session: Optional shared aiohttp session. When omitted the
                service creates its own session and closes it on stop/cancel.
            settings: Runtime-updatable settings, applied on top of the
                defaults (voice ``"eve"``, language ``Language.EN``).
            **kwargs: Additional keyword arguments passed to ``TTSService``.
        """
        default_settings = self.Settings(
            model=None,
            voice="eve",
            language=Language.EN,
        )
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            push_start_frame=True,
            push_stop_frames=True,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._base_url = base_url
        self._session = aiohttp_session
        # Only sessions created by this service are closed by it.
        self._session_owner = aiohttp_session is None
        self._encoding = encoding

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics."""
        return True

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert a Language enum to xAI language format.

        Args:
            language: The language to convert.

        Returns:
            The xAI-specific language code, or None if not supported.
        """
        return language_to_xai_language(language)

    async def start(self, frame):
        """Start the xAI TTS service and make sure an HTTP session exists."""
        await super().start(frame)
        self._ensure_session()

    async def stop(self, frame):
        """Stop the xAI TTS service and release the owned HTTP session."""
        await super().stop(frame)
        await self._close_session()

    async def cancel(self, frame):
        """Cancel the xAI TTS service and release the owned HTTP session."""
        await super().cancel(frame)
        await self._close_session()

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return a usable HTTP session, creating an owned one if needed.

        A missing or already-closed session (including a caller-supplied one
        that has been closed) is replaced by a fresh session owned by this
        service.
        """
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
            self._session_owner = True
        return self._session

    async def _close_session(self):
        """Close and forget the HTTP session, but only if this service owns it."""
        if not self._session_owner:
            return
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

    def _resolve_language(self) -> str | None:
        """Return the language value to send to xAI, or None when unset.

        A ``Language`` enum is mapped through ``language_to_xai_language``
        (falling back to the enum's raw value), which is the same resolution
        the WebSocket service applies when it builds its URL. Any other value
        is sent as its string form.
        """
        language = self._settings.language
        if not language:
            return None
        if isinstance(language, Language):
            return language_to_xai_language(language) or language.value
        return str(language)

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
        """Generate speech from text using xAI's TTS API.

        Args:
            text: The text to synthesize.
            context_id: Identifier attached to every emitted audio frame.

        Yields:
            ``TTSAudioRawFrame`` objects for each non-empty response chunk, or
            a single ``ErrorFrame`` if the request fails or returns a non-200
            status.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        session = self._ensure_session()

        payload: dict[str, Any] = {
            "text": text,
            "voice_id": self._settings.voice,
            "output_format": {
                "codec": self._encoding,
                "sample_rate": self.sample_rate,
            },
        }
        language = self._resolve_language()
        if language is not None:
            payload["language"] = language

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        # TTFB is stopped once, on the first non-empty audio chunk.
        measuring_ttfb = True

        try:
            async with session.post(self._base_url, json=payload, headers=headers) as response:
                if response.status != 200:
                    error = await response.text(errors="ignore")
                    yield ErrorFrame(
                        error=f"Error getting audio (status: {response.status}, error: {error})"
                    )
                    return

                await self.start_tts_usage_metrics(text)

                async for chunk in response.content.iter_chunked(self.chunk_size):
                    if not chunk:
                        continue
                    if measuring_ttfb:
                        await self.stop_ttfb_metrics()
                        measuring_ttfb = False
                    yield TTSAudioRawFrame(
                        chunk,
                        self.sample_rate,
                        1,
                        context_id=context_id,
                    )
        except Exception as e:
            # Pipeline boundary: surface any failure as a frame instead of raising.
            yield ErrorFrame(error=f"Unknown error occurred: {e}")
@dataclass
class XAIWebsocketTTSSettings(TTSSettings):
    """Runtime settings used by ``XAITTSService`` (WebSocket streaming).

    Declares no fields of its own; everything comes from ``TTSSettings``.
    """
class XAITTSService(InterruptibleTTSService):
    """xAI streaming text-to-speech service.

    Connects to xAI's WebSocket TTS endpoint and streams audio chunks back as
    they are synthesized. Text is sent via ``text.delta`` messages and each
    utterance is terminated with ``text.done``. The server responds with
    ``audio.delta`` chunks followed by an ``audio.done`` message.

    Audio parameters (voice, language, codec, sample rate) are passed as query
    string parameters on the WebSocket URL; a settings update that changes
    anything reconnects the WebSocket so the new URL takes effect.
    """

    Settings = XAIWebsocketTTSSettings
    _settings: Settings

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "wss://api.x.ai/v1/tts",
        sample_rate: int | None = None,
        codec: str = "pcm",
        settings: Settings | None = None,
        **kwargs,
    ):
        """Initialize the xAI WebSocket TTS service.

        Args:
            api_key: xAI API key for authentication.
            base_url: xAI TTS WebSocket endpoint. Defaults to
                ``wss://api.x.ai/v1/tts``.
            sample_rate: Output audio sample rate in Hz. If None, uses the
                pipeline default.
            codec: Output audio codec. One of ``pcm``, ``wav``, ``mulaw``,
                ``alaw``. Defaults to ``pcm`` so emitted ``TTSAudioRawFrame``
                objects need no decoding downstream.
            settings: Runtime-updatable settings, applied on top of the
                defaults (voice ``"eve"``, language ``Language.EN``).
            **kwargs: Additional arguments passed to parent
                ``InterruptibleTTSService``.
        """
        default_settings = self.Settings(
            model=None,
            voice="eve",
            language=Language.EN,
        )
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            push_start_frame=True,
            push_stop_frames=True,
            sample_rate=sample_rate,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._base_url = base_url
        self._codec = codec
        self._receive_task = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics."""
        return True

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert a Language enum to xAI language format."""
        return language_to_xai_language(language)

    async def start(self, frame: StartFrame):
        """Start the xAI WebSocket TTS service and open the connection."""
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the xAI WebSocket TTS service and close the connection."""
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the xAI WebSocket TTS service and close the connection."""
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Open the WebSocket and start the receive task if it is not running."""
        await super()._connect()
        await self._connect_websocket()
        # _connect_websocket leaves self._websocket as None on failure, in
        # which case no receive task is started.
        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Cancel the receive task, then close the WebSocket."""
        await super()._disconnect()
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None
        await self._disconnect_websocket()

    async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
        """Apply a settings delta and reconnect if anything changed.

        Voice and language are baked into the WebSocket URL, so a reconnect is
        needed for them to take effect. The reconnect is triggered by any
        reported change, not only by those two fields.

        Args:
            delta: The settings update to apply.

        Returns:
            Whatever the parent implementation reports as changed.
        """
        changed = await super()._update_settings(delta)
        if changed:
            await self._disconnect()
            await self._connect()
        return changed

    def _build_url(self) -> str:
        """Build the WebSocket URL with the audio parameters as a query string."""
        language = self._settings.language
        if isinstance(language, Language):
            # Fall back to the raw enum value when xAI has no dedicated mapping.
            language_value = language_to_xai_language(language) or language.value
        else:
            language_value = str(language) if language is not None else "auto"

        params: dict[str, Any] = {
            "voice": self._settings.voice,
            "language": language_value,
            "codec": self._codec,
            "sample_rate": self.sample_rate,
        }
        return f"{self._base_url}?{urlencode(params)}"

    async def _connect_websocket(self):
        """Open the WebSocket unless one is already open.

        Failures are reported through ``push_error`` and the
        ``on_connection_error`` event instead of being raised;
        ``self._websocket`` is left as None in that case.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            logger.debug("Connecting to xAI TTS")
            url = self._build_url()
            headers = {"Authorization": f"Bearer {self._api_key}"}
            self._websocket = await websocket_connect(url, additional_headers=headers)
            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            self._websocket = None
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Close the WebSocket (if any) and always fire ``on_disconnected``."""
        try:
            await self.stop_all_metrics()
            if self._websocket:
                logger.debug("Disconnecting from xAI TTS")
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Error disconnecting from xAI TTS: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Return the current WebSocket.

        Raises:
            Exception: If no WebSocket is connected.
        """
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def flush_audio(self, context_id: str | None = None):
        """Signal end-of-utterance so xAI begins synthesizing what it has buffered.

        Does nothing unless the WebSocket is open. A socket that is still
        closing is skipped as well, because sending on it would raise.

        Args:
            context_id: Unused by this implementation.
        """
        if not self._websocket or self._websocket.state is not State.OPEN:
            return
        await self._websocket.send(json.dumps({"type": "text.done"}))

    async def _finish_xai_context(self, context_id: str | None):
        """Queue a ``TTSStoppedFrame`` for the context and remove it; no-op if unset."""
        if not context_id:
            return
        await self.append_to_audio_context(context_id, TTSStoppedFrame(context_id=context_id))
        await self.remove_audio_context(context_id)

    async def _receive_messages(self):
        """Consume server messages and route them to the active audio context.

        Handles ``audio.delta`` (base64 audio), ``audio.done`` and ``error``
        messages; binary frames and undecodable JSON are logged and skipped.
        """
        async for message in self._get_websocket():
            if isinstance(message, bytes):
                logger.warning(f"{self}: unexpected binary frame from xAI TTS")
                continue

            try:
                msg = json.loads(message)
            except json.JSONDecodeError:
                logger.error(f"{self}: invalid JSON message: {message}")
                continue

            msg_type = msg.get("type")
            context_id = self.get_active_audio_context_id()

            if msg_type == "audio.delta":
                audio_b64 = msg.get("delta")
                if not audio_b64:
                    continue
                audio = base64.b64decode(audio_b64)
                await self.stop_ttfb_metrics()
                # Audio that arrives with no active context is dropped.
                if context_id:
                    frame = TTSAudioRawFrame(
                        audio=audio,
                        sample_rate=self.sample_rate,
                        num_channels=1,
                        context_id=context_id,
                    )
                    await self.append_to_audio_context(context_id, frame)
            elif msg_type == "audio.done":
                await self.stop_all_metrics()
                await self._finish_xai_context(context_id)
            elif msg_type == "error":
                await self.stop_all_metrics()
                error_detail = msg.get("message") or msg.get("error") or str(msg)
                await self._finish_xai_context(context_id)
                await self.push_error(error_msg=f"xAI TTS error: {error_detail}")
            else:
                logger.debug(f"{self}: unhandled xAI message type: {msg_type}")

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
        """Generate TTS audio from text using xAI's streaming WebSocket API.

        The text is sent as a ``text.delta`` message. This generator yields no
        audio itself; audio frames are queued by ``_receive_messages``.

        Args:
            text: The text to synthesize.
            context_id: Identifier used for the ``TTSStoppedFrame`` emitted
                when sending fails.

        Yields:
            None on success. On a send failure, an ``ErrorFrame`` followed by a
            ``TTSStoppedFrame``, after which the connection is re-established.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        try:
            if not self._websocket or self._websocket.state is State.CLOSED:
                await self._connect()

            try:
                await self._get_websocket().send(json.dumps({"type": "text.delta", "delta": text}))
                await self.start_tts_usage_metrics(text)
            except Exception as e:
                yield ErrorFrame(error=f"Unknown error occurred: {e}")
                yield TTSStoppedFrame(context_id=context_id)
                # Reset the connection so the next utterance starts clean.
                await self._disconnect()
                await self._connect()
                return
            yield None
        except Exception as e:
            yield ErrorFrame(error=f"Unknown error occurred: {e}")