#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Fish Audio text-to-speech service implementation.
This module provides integration with Fish Audio's real-time TTS WebSocket API
for streaming text-to-speech synthesis with customizable voice parameters.
"""
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass, field
from typing import Any, Literal, Self
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
TTSAudioRawFrame,
TTSStoppedFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, assert_given
from pipecat.services.tts_service import InterruptibleTTSService
from pipecat.transcriptions.language import Language
from pipecat.utils.tracing.service_decorators import traced_tts
try:
import ormsgpack
from websockets.asyncio.client import connect as websocket_connect
from websockets.protocol import State
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`.")
raise Exception(f"Missing module: {e}")
# Audio encodings that can be requested from Fish Audio. The chosen value is sent
# as the ``format`` field of the session ``start`` message.
FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]
@dataclass
class FishAudioTTSSettings(TTSSettings):
    """Settings for FishAudioTTSService.

    Every field defaults to ``NOT_GIVEN`` so that an instance can be used as a
    partial update; the concrete defaults are filled in by ``FishAudioTTSService``.

    Parameters:
        latency: Latency mode ("normal" or "balanced"). Service default: "balanced".
        normalize: Whether to normalize audio output. Service default: True.
        temperature: Controls randomness in speech generation (0.0-1.0).
            Service default: None (omitted from the request).
        top_p: Controls diversity via nucleus sampling (0.0-1.0).
            Service default: None (omitted from the request).
        prosody_speed: Speech speed multiplier (0.5-2.0). Service default: 1.0.
        prosody_volume: Volume adjustment in dB (-20 to 20). Service default: 0.
    """

    latency: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    normalize: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    temperature: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    top_p: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    prosody_speed: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    prosody_volume: int | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)

    @classmethod
    def from_mapping(cls, settings: Mapping[str, Any]) -> Self:
        """Construct settings from a plain mapping.

        Also accepts the legacy nested form ``{"prosody": {"speed": ..., "volume": ...}}``
        and flattens it into ``prosody_speed`` / ``prosody_volume``. Flat keys take
        precedence over nested ones.

        Args:
            settings: Mapping of setting names to values.

        Returns:
            A settings instance built by the parent ``from_mapping``.
        """
        flat = dict(settings)
        nested = flat.pop("prosody", None)
        if isinstance(nested, Mapping):
            for nested_key, flat_key in (("speed", "prosody_speed"), ("volume", "prosody_volume")):
                # Forward only keys that are actually present. A missing nested key
                # must stay unset (NOT_GIVEN) rather than becoming an explicit None
                # that would overwrite the current value when applied as a delta.
                if nested_key in nested:
                    flat.setdefault(flat_key, nested[nested_key])
        return super().from_mapping(flat)
class FishAudioTTSService(InterruptibleTTSService):
    """Fish Audio text-to-speech service with WebSocket streaming.

    Provides real-time text-to-speech synthesis using Fish Audio's WebSocket API.
    Supports various audio formats, customizable prosody controls, and streaming
    audio generation with interruption handling.
    """

    Settings = FishAudioTTSSettings
    _settings: Settings

    class InputParams(BaseModel):
        """Deprecated input parameters for Fish Audio TTS configuration.

        Only fields that are explicitly set (i.e. not ``None``) override the
        service defaults.

        .. deprecated:: 0.0.105
            Use ``settings=FishAudioTTSService.Settings(...)`` instead.

        Parameters:
            latency: Latency mode ("normal" or "balanced").
            normalize: Whether to normalize audio output.
            prosody_speed: Speech speed multiplier.
            prosody_volume: Volume adjustment in dB.
        """

        latency: str | None = None
        normalize: bool | None = None
        prosody_speed: float | None = None
        prosody_volume: int | None = None

    def __init__(
        self,
        *,
        api_key: str,
        reference_id: str | None = None,  # This is the voice ID
        model_id: str | None = None,
        output_format: FishAudioOutputFormat = "pcm",
        sample_rate: int | None = None,
        params: InputParams | None = None,
        settings: Settings | None = None,
        **kwargs,
    ):
        """Initialize the Fish Audio TTS service.

        Args:
            api_key: Fish Audio API key for authentication.
            reference_id: Reference ID of the voice model to use for synthesis.

                .. deprecated:: 0.0.105
                    Use ``settings=FishAudioTTSService.Settings(voice=...)`` instead.

            model_id: Specify which Fish Audio TTS model to use (e.g. "s1").

                .. deprecated:: 0.0.105
                    Use ``settings=FishAudioTTSService.Settings(model=...)`` instead.

            output_format: Audio output format. Defaults to "pcm".
            sample_rate: Audio sample rate. If None, uses default.
            params: Additional input parameters for voice customization. Only
                applied when ``settings`` is not provided.

                .. deprecated:: 0.0.105
                    Use ``settings=FishAudioTTSService.Settings(...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to the parent service.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = self.Settings(
            model="s2-pro",
            voice=None,
            language=None,
            latency="balanced",
            normalize=True,
            temperature=None,
            top_p=None,
            prosody_speed=1.0,
            prosody_volume=0,
        )

        # 2. Apply direct init arg overrides (deprecated)
        if reference_id is not None:
            self._warn_init_param_moved_to_settings("reference_id", "voice")
            default_settings.voice = reference_id
        if model_id is not None:
            self._warn_init_param_moved_to_settings("model_id", "model")
            default_settings.model = model_id

        # 3. Apply params overrides — only if settings not provided
        if params is not None:
            self._warn_init_param_moved_to_settings("params")
            if not settings:
                if params.latency is not None:
                    default_settings.latency = params.latency
                if params.normalize is not None:
                    default_settings.normalize = params.normalize
                if params.prosody_speed is not None:
                    default_settings.prosody_speed = params.prosody_speed
                if params.prosody_volume is not None:
                    default_settings.prosody_volume = params.prosody_volume

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            push_stop_frames=True,
            push_start_frame=True,
            pause_frame_processing=True,
            sample_rate=sample_rate,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._base_url = "wss://api.fish.audio/v1/tts/live"
        self._websocket = None
        self._receive_task = None

        # Init-only audio format config (not runtime-updatable).
        self._fish_sample_rate = 0  # Set in start()
        self._output_format = output_format

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Fish Audio service supports metrics generation.
        """
        return True

    async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
        """Apply a settings delta and reconnect if anything changed.

        Any changed field triggers a WebSocket reconnect. Fish Audio settings are
        only transmitted when a session is opened: the model as a connection
        header, and the voice, latency, normalization, sampling and prosody
        values in the ``start`` message.

        Args:
            delta: A :class:`TTSSettings` (or ``FishAudioTTSService.Settings``) delta.

        Returns:
            Dict mapping changed field names to their previous values.
        """
        changed = await super()._update_settings(delta)
        if changed:
            await self._disconnect()
            await self._connect()
        return changed

    async def start(self, frame: StartFrame):
        """Start the Fish Audio TTS service.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        # The sample rate is only known once the pipeline has started.
        self._fish_sample_rate = self.sample_rate
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the Fish Audio TTS service.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the Fish Audio TTS service.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Open the WebSocket and start the receive loop if it is not running."""
        await super()._connect()

        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Stop the receive loop, then close the WebSocket."""
        await super()._disconnect()

        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Connect to Fish Audio and send the session ``start`` message.

        Does nothing if a connection is already open. On failure the error is
        pushed, any partially opened socket is closed, and ``_websocket`` is
        reset to None.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            logger.debug("Connecting to Fish Audio")
            headers = {"Authorization": f"Bearer {self._api_key}"}
            model = assert_given(self._settings.model)
            if model is not None:
                headers["model"] = model
            self._websocket = await websocket_connect(self._base_url, additional_headers=headers)

            # Send initial start message with ormsgpack
            request_settings = {
                "sample_rate": self._fish_sample_rate,
                "latency": self._settings.latency,
                "format": self._output_format,
                "normalize": self._settings.normalize,
                "prosody": {
                    "speed": self._settings.prosody_speed,
                    "volume": self._settings.prosody_volume,
                },
                "reference_id": self._settings.voice,
            }
            # Sampling parameters are optional; omit them so the server uses its defaults.
            if self._settings.temperature is not None:
                request_settings["temperature"] = self._settings.temperature
            if self._settings.top_p is not None:
                request_settings["top_p"] = self._settings.top_p
            start_message = {"event": "start", "request": {"text": "", **request_settings}}
            await self._websocket.send(ormsgpack.packb(start_message))
            logger.debug("Sent start message to Fish Audio")

            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            # The socket may already be open if the failure happened after connecting
            # (e.g. while sending the start message); close it so it is not leaked.
            websocket, self._websocket = self._websocket, None
            if websocket is not None:
                try:
                    await websocket.close()
                except Exception as close_error:
                    logger.debug(f"{self}: error closing Fish Audio websocket: {close_error}")
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Send the ``stop`` event (if still possible) and close the WebSocket."""
        try:
            await self.stop_all_metrics()

            if self._websocket:
                logger.debug("Disconnecting from Fish Audio")
                # Only an open socket can carry the stop event; sending on a closing
                # or closed socket would raise and be reported as a spurious error.
                if self._websocket.state is State.OPEN:
                    # Send stop event with ormsgpack
                    stop_message = {"event": "stop"}
                    await self._websocket.send(ormsgpack.packb(stop_message))
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    async def flush_audio(self, context_id: str | None = None):
        """Flush any buffered audio by sending a flush event to Fish Audio.

        Does nothing unless the WebSocket is open.
        """
        logger.trace(f"{self}: Flushing audio buffers")
        if not self._websocket or self._websocket.state is not State.OPEN:
            return
        flush_message = {"event": "flush"}
        await self._get_websocket().send(ormsgpack.packb(flush_message))

    def _get_websocket(self):
        """Return the current WebSocket, raising if there is no connection."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def on_audio_context_interrupted(self, context_id: str):
        """Stop all metrics when audio context is interrupted."""
        await self.stop_all_metrics()
        await super().on_audio_context_interrupted(context_id)

    async def _receive_messages(self):
        """Decode msgpack messages from Fish Audio and route audio to the active context."""
        async for message in self._get_websocket():
            try:
                if isinstance(message, bytes):
                    msg = ormsgpack.unpackb(message)
                    if isinstance(msg, dict):
                        event = msg.get("event")
                        if event == "audio":
                            audio_data = msg.get("audio")
                            # Only process larger chunks to remove msgpack overhead.
                            # NOTE(review): chunks of 1024 bytes or fewer are dropped
                            # entirely; confirm this cannot truncate the tail of an
                            # utterance.
                            if audio_data and len(audio_data) > 1024:
                                context_id = self.get_active_audio_context_id()
                                frame = TTSAudioRawFrame(
                                    audio_data,
                                    self.sample_rate,
                                    1,
                                    context_id=context_id,
                                )
                                await self.append_to_audio_context(context_id, frame)
                                await self.stop_ttfb_metrics()
                        elif event == "finish":
                            reason = msg.get("reason", "unknown")
                            if reason == "error":
                                await self.push_error(
                                    error_msg="Fish Audio server error during synthesis"
                                )
                            else:
                                logger.debug(f"Fish Audio session finished: {reason}")
            except Exception as e:
                await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
        """Generate speech from text using Fish Audio's streaming API.

        Audio is delivered asynchronously by the receive loop; this generator only
        sends the text and yields error/control frames on failure.

        Args:
            text: The text to synthesize into speech.
            context_id: The context ID for tracking audio frames.

        Yields:
            Frame: Audio frames and control frames for the synthesized speech.
        """
        logger.debug(f"{self}: Generating Fish TTS: [{text}]")
        try:
            if not self._websocket or self._websocket.state is State.CLOSED:
                await self._connect()

            # Send the text
            text_message = {
                "event": "text",
                "text": text,
            }
            try:
                await self._get_websocket().send(ormsgpack.packb(text_message))
                await self.start_tts_usage_metrics(text)

                # Send flush event to force audio generation
                flush_message = {"event": "flush"}
                await self._get_websocket().send(ormsgpack.packb(flush_message))
            except Exception as e:
                yield ErrorFrame(error=f"Unknown error occurred: {e}")
                yield TTSStoppedFrame(context_id=context_id)
                # Start a fresh session so the next request has a usable connection.
                await self._disconnect()
                await self._connect()

            yield None
        except Exception as e:
            yield ErrorFrame(error=f"Unknown error occurred: {e}")