# Source code for pipecat.services.fish.tts

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Fish Audio text-to-speech service implementation.

This module provides integration with Fish Audio's real-time TTS WebSocket API
for streaming text-to-speech synthesis with customizable voice parameters.
"""

from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from typing import Any, Literal, Self

from loguru import logger
from pydantic import BaseModel

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    StartFrame,
    TTSAudioRawFrame,
    TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, assert_given
from pipecat.services.tts_service import InterruptibleTTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.tracing.service_decorators import traced_tts

try:
    import ormsgpack
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`.")
    raise Exception(f"Missing module: {e}")

# Audio encodings that may be requested from Fish Audio. The chosen value is
# taken by the service as ``output_format`` and forwarded as the "format"
# field of the session start request.
FishAudioOutputFormat = Literal[
    "opus",
    "mp3",
    "pcm",
    "wav",
]


@dataclass
class FishAudioTTSSettings(TTSSettings):
    """Settings for FishAudioTTSService.

    Every field below defaults to ``NOT_GIVEN``, so an instance can describe a
    partial settings delta in which only explicitly set fields are meaningful.
    The "Defaults to" values are the ones ``FishAudioTTSService`` starts from.

    Parameters:
        latency: Latency mode ("normal" or "balanced"). Defaults to "balanced".
        normalize: Whether to normalize audio output. Defaults to True.
        temperature: Controls randomness in speech generation (0.0-1.0).
        top_p: Controls diversity via nucleus sampling (0.0-1.0).
        prosody_speed: Speech speed multiplier (0.5-2.0). Defaults to 1.0.
        prosody_volume: Volume adjustment in dB (-20 to 20). Defaults to 0.
    """

    latency: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    normalize: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    temperature: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    top_p: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    prosody_speed: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    prosody_volume: int | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)

    @classmethod
    def from_mapping(cls, settings: Mapping[str, Any]) -> Self:
        """Construct settings from a plain mapping, destructuring legacy nested ``prosody``.

        A legacy ``{"prosody": {"speed": ..., "volume": ...}}`` entry is removed and
        its sub-keys are copied to the flat ``prosody_speed`` / ``prosody_volume``
        keys. Flat keys that are already present take precedence. Sub-keys that are
        absent from the nested mapping are simply not copied, rather than being
        inserted as an explicit ``None``. This keeps a partial legacy update such as
        ``{"prosody": {"volume": 3}}`` from also carrying ``prosody_speed=None``.

        Args:
            settings: Mapping of setting names to values.

        Returns:
            A settings instance built by the parent ``from_mapping`` from the
            flattened mapping.
        """
        flat = dict(settings)
        nested = flat.pop("prosody", None)
        # Accept any mapping (not just dict) for the legacy nested form.
        if isinstance(nested, Mapping):
            for legacy_key, flat_key in (
                ("speed", "prosody_speed"),
                ("volume", "prosody_volume"),
            ):
                if legacy_key in nested:
                    flat.setdefault(flat_key, nested[legacy_key])
        return super().from_mapping(flat)
class FishAudioTTSService(InterruptibleTTSService):
    """Fish Audio text-to-speech service with WebSocket streaming.

    Provides real-time text-to-speech synthesis using Fish Audio's WebSocket API.
    Supports various audio formats, customizable prosody controls, and streaming
    audio generation with interruption handling.
    """

    Settings = FishAudioTTSSettings
    _settings: Settings

    class InputParams(BaseModel):
        """Input parameters for Fish Audio TTS configuration.

        .. deprecated:: 0.0.105
            Use ``settings=FishAudioTTSService.Settings(...)`` instead.

        Parameters:
            language: Language for synthesis. Defaults to English.
            latency: Latency mode ("normal" or "balanced"). Defaults to "normal".
            normalize: Whether to normalize audio output. Defaults to True.
            prosody_speed: Speech speed multiplier (0.5-2.0). Defaults to 1.0.
            prosody_volume: Volume adjustment in dB. Defaults to 0.
        """

        # NOTE(review): ``language`` is declared here but __init__ never copies it
        # into the settings; confirm whether that is intentional.
        language: Language | None = Language.EN
        latency: str | None = "normal"  # "normal" or "balanced"
        normalize: bool | None = True
        prosody_speed: float | None = 1.0  # Speech speed (0.5-2.0)
        prosody_volume: int | None = 0  # Volume adjustment in dB

    def __init__(
        self,
        *,
        api_key: str,
        reference_id: str | None = None,  # This is the voice ID
        model_id: str | None = None,
        output_format: FishAudioOutputFormat = "pcm",
        sample_rate: int | None = None,
        params: InputParams | None = None,
        settings: Settings | None = None,
        **kwargs,
    ):
        """Initialize the Fish Audio TTS service.

        Args:
            api_key: Fish Audio API key for authentication.
            reference_id: Reference ID of the voice model to use for synthesis.

                .. deprecated:: 0.0.105
                    Use ``settings=FishAudioTTSService.Settings(voice=...)`` instead.

            model_id: Specify which Fish Audio TTS model to use (e.g. "s1").

                .. deprecated:: 0.0.105
                    Use ``settings=FishAudioTTSService.Settings(model=...)`` instead.

            output_format: Audio output format. Defaults to "pcm".
            sample_rate: Audio sample rate. If None, uses default.
            params: Additional input parameters for voice customization.

                .. deprecated:: 0.0.105
                    Use ``settings=FishAudioTTSService.Settings(...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to the parent service.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = self.Settings(
            model="s2-pro",
            voice=None,
            language=None,
            latency="balanced",
            normalize=True,
            temperature=None,
            top_p=None,
            prosody_speed=1.0,
            prosody_volume=0,
        )

        # 2. Apply direct init arg overrides (deprecated)
        if reference_id is not None:
            self._warn_init_param_moved_to_settings("reference_id", "voice")
            default_settings.voice = reference_id
        if model_id is not None:
            self._warn_init_param_moved_to_settings("model_id", "model")
            default_settings.model = model_id

        # 3. Apply params overrides — only if settings not provided
        if params is not None:
            self._warn_init_param_moved_to_settings("params")
            if settings is None:
                if params.latency is not None:
                    default_settings.latency = params.latency
                if params.normalize is not None:
                    default_settings.normalize = params.normalize
                if params.prosody_speed is not None:
                    default_settings.prosody_speed = params.prosody_speed
                if params.prosody_volume is not None:
                    default_settings.prosody_volume = params.prosody_volume

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            push_stop_frames=True,
            push_start_frame=True,
            pause_frame_processing=True,
            sample_rate=sample_rate,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._base_url = "wss://api.fish.audio/v1/tts/live"
        self._websocket = None
        self._receive_task = None

        # Init-only audio format config (not runtime-updatable).
        self._fish_sample_rate = 0  # Set in start()
        self._output_format = output_format

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Fish Audio service supports metrics generation.
        """
        return True

    async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
        """Apply a settings delta and reconnect if anything changed.

        All synthesis settings are sent once, in the session start message.
        Any changed field therefore triggers a WebSocket reconnect so the new
        values take effect.

        Args:
            delta: A :class:`TTSSettings` (or ``FishAudioTTSService.Settings``) delta.

        Returns:
            Dict mapping changed field names to their previous values.
        """
        changed = await super()._update_settings(delta)

        if changed:
            await self._disconnect()
            await self._connect()

        return changed

    async def start(self, frame: StartFrame):
        """Start the Fish Audio TTS service.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        # The sample rate is only known after the base class has processed the StartFrame.
        self._fish_sample_rate = self.sample_rate
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the Fish Audio TTS service.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the Fish Audio TTS service.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Open the WebSocket and start the receive task if it is not already running."""
        await super()._connect()

        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Cancel the receive task, then close the WebSocket."""
        await super()._disconnect()

        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    def _build_request_settings(self) -> dict[str, Any]:
        """Build the ``request`` payload (without ``text``) for the session start event."""
        request_settings: dict[str, Any] = {
            "sample_rate": self._fish_sample_rate,
            "latency": self._settings.latency,
            "format": self._output_format,
            "normalize": self._settings.normalize,
            "prosody": {
                "speed": self._settings.prosody_speed,
                "volume": self._settings.prosody_volume,
            },
            "reference_id": self._settings.voice,
        }
        # Sampling controls are only included when explicitly set (not None).
        if self._settings.temperature is not None:
            request_settings["temperature"] = self._settings.temperature
        if self._settings.top_p is not None:
            request_settings["top_p"] = self._settings.top_p
        return request_settings

    async def _connect_websocket(self):
        """Open the WebSocket (unless already open) and send the msgpack ``start`` event.

        Errors are reported via ``push_error`` and the ``on_connection_error`` event.
        They are not raised to the caller.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            logger.debug("Connecting to Fish Audio")
            headers = {"Authorization": f"Bearer {self._api_key}"}
            model = assert_given(self._settings.model)
            if model is not None:
                headers["model"] = model
            self._websocket = await websocket_connect(self._base_url, additional_headers=headers)

            # Send initial start message with ormsgpack
            start_message = {
                "event": "start",
                "request": {"text": "", **self._build_request_settings()},
            }
            await self._websocket.send(ormsgpack.packb(start_message))
            logger.debug("Sent start message to Fish Audio")

            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            # The socket may have opened before the failure (e.g. while sending the
            # start event). Close it before dropping the reference so it is not leaked.
            websocket = self._websocket
            self._websocket = None
            if websocket is not None:
                try:
                    await websocket.close()
                except Exception as close_error:
                    logger.debug(f"{self}: error closing Fish Audio websocket: {close_error}")
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Send the ``stop`` event (if the socket is still open) and close the WebSocket."""
        try:
            await self.stop_all_metrics()

            if self._websocket:
                logger.debug("Disconnecting from Fish Audio")
                # Sending on a closing/closed socket raises. Only end the session
                # politely while it is open, so a normal shutdown is not reported
                # as an error.
                if self._websocket.state is State.OPEN:
                    stop_message = {"event": "stop"}
                    await self._websocket.send(ormsgpack.packb(stop_message))
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    async def flush_audio(self, context_id: str | None = None):
        """Flush any buffered audio by sending a flush event to Fish Audio.

        Does nothing when there is no WebSocket or it is no longer open.
        """
        logger.trace(f"{self}: Flushing audio buffers")
        if not self._websocket or self._websocket.state is not State.OPEN:
            return
        flush_message = {"event": "flush"}
        await self._get_websocket().send(ormsgpack.packb(flush_message))

    def _get_websocket(self):
        """Return the current WebSocket, raising if there is none."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def on_audio_context_interrupted(self, context_id: str):
        """Stop all metrics when audio context is interrupted."""
        await self.stop_all_metrics()
        await super().on_audio_context_interrupted(context_id)

    async def _receive_messages(self):
        """Decode msgpack events from the WebSocket and route them.

        ``audio`` events are appended to the active audio context. ``finish``
        events are logged, or reported as errors when their reason is "error".
        """
        async for message in self._get_websocket():
            try:
                if isinstance(message, bytes):
                    msg = ormsgpack.unpackb(message)
                    if isinstance(msg, dict):
                        event = msg.get("event")
                        if event == "audio":
                            audio_data = msg.get("audio")
                            # Only process larger chunks to remove msgpack overhead
                            # NOTE(review): chunks of <= 1024 bytes are dropped entirely,
                            # which may discard real audio (e.g. a short final chunk).
                            # Also, the data is wrapped as a raw-audio frame whatever
                            # ``output_format`` was requested. Confirm both are intended.
                            if audio_data and len(audio_data) > 1024:
                                context_id = self.get_active_audio_context_id()
                                frame = TTSAudioRawFrame(
                                    audio_data,
                                    self.sample_rate,
                                    1,
                                    context_id=context_id,
                                )
                                await self.append_to_audio_context(context_id, frame)
                                await self.stop_ttfb_metrics()
                        elif event == "finish":
                            reason = msg.get("reason", "unknown")
                            if reason == "error":
                                await self.push_error(
                                    error_msg="Fish Audio server error during synthesis"
                                )
                            else:
                                logger.debug(f"Fish Audio session finished: {reason}")
            except Exception as e:
                await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
        """Generate speech from text using Fish Audio's streaming API.

        Audio arrives asynchronously via the receive task. This generator only
        sends the text and a flush event, then yields ``None`` on success. If
        sending fails, it yields an error and a stopped frame and reconnects.

        Args:
            text: The text to synthesize into speech.
            context_id: The context ID for tracking audio frames.

        Yields:
            Frame: Audio frames and control frames for the synthesized speech.
        """
        logger.debug(f"{self}: Generating Fish TTS: [{text}]")
        try:
            if not self._websocket or self._websocket.state is State.CLOSED:
                await self._connect()

            # Send the text
            text_message = {
                "event": "text",
                "text": text,
            }
            try:
                await self._get_websocket().send(ormsgpack.packb(text_message))
                await self.start_tts_usage_metrics(text)

                # Send flush event to force audio generation
                flush_message = {"event": "flush"}
                await self._get_websocket().send(ormsgpack.packb(flush_message))
            except Exception as e:
                yield ErrorFrame(error=f"Unknown error occurred: {e}")
                yield TTSStoppedFrame(context_id=context_id)
                await self._disconnect()
                await self._connect()

            yield None

        except Exception as e:
            yield ErrorFrame(error=f"Unknown error occurred: {e}")