#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Kokoro TTS service implementation using kokoro-onnx."""
import os
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from loguru import logger
from pydantic import BaseModel
from pipecat.audio.utils import create_stream_resampler
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
)
from pipecat.services.settings import TTSSettings, assert_given
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts
try:
import requests
from kokoro_onnx import Kokoro
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Kokoro, you need to `pip install pipecat-ai[kokoro]`.")
raise Exception(f"Missing module: {e}")
KOKORO_CACHE_DIR = Path(os.path.expanduser("~/.cache/kokoro-onnx"))
KOKORO_MODEL_URL = "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/kokoro-v1.0.onnx"
KOKORO_VOICES_URL = (
"https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/voices-v1.0.bin"
)
def _download_file(url: str, dest: Path):
"""Download a file from a URL to a destination path."""
logger.debug(f"Downloading {url} to {dest}...")
dest.parent.mkdir(parents=True, exist_ok=True)
resp = requests.get(url, stream=True, timeout=300)
resp.raise_for_status()
with open(dest, "wb") as f:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
logger.debug(f"Downloaded {dest}")
def _ensure_model_files(model_path: Path, voices_path: Path):
"""Download model files if they don't exist."""
if not model_path.exists():
_download_file(KOKORO_MODEL_URL, model_path)
if not voices_path.exists():
_download_file(KOKORO_VOICES_URL, voices_path)
[docs]
def language_to_kokoro_language(language: Language) -> str:
"""Convert a Language enum to kokoro-onnx language code.
Args:
language: The Language enum value to convert.
Returns:
The corresponding kokoro-onnx locale code.
"""
LANGUAGE_MAP = {
Language.EN: "en-us",
Language.EN_US: "en-us",
Language.EN_GB: "en-gb",
Language.ES: "es",
Language.FR: "fr",
Language.HI: "hi",
Language.IT: "it",
Language.JA: "ja",
Language.PT: "pt",
Language.ZH: "zh",
}
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
[docs]
@dataclass
class KokoroTTSSettings(TTSSettings):
"""Settings for KokoroTTSService."""
pass
[docs]
class KokoroTTSService(TTSService):
"""Kokoro TTS service implementation.
Provides local text-to-speech synthesis using kokoro-onnx.
Automatically downloads model files on first use.
"""
Settings = KokoroTTSSettings
_settings: Settings
[docs]
def __init__(
self,
*,
voice_id: str | None = None,
model_path: str | None = None,
voices_path: str | None = None,
params: InputParams | None = None,
settings: Settings | None = None,
**kwargs,
):
"""Initialize the Kokoro TTS service.
Args:
voice_id: Voice identifier to use for synthesis.
.. deprecated:: 0.0.105
Use ``settings=KokoroTTSService.Settings(voice=...)`` instead.
model_path: Path to the kokoro ONNX model file. Defaults to auto-downloaded file.
voices_path: Path to the voices binary file. Defaults to auto-downloaded file.
params: Configuration parameters for synthesis.
.. deprecated:: 0.0.105
Use ``settings=KokoroTTSService.Settings(...)`` instead.
settings: Runtime-updatable settings. When provided alongside deprecated
parameters, ``settings`` values take precedence.
**kwargs: Additional arguments passed to parent `TTSService`.
"""
# 1. Initialize default_settings with hardcoded defaults
default_settings = self.Settings(
model=None,
voice=None,
language=Language.EN,
)
# 2. Apply direct init arg overrides (deprecated)
if voice_id is not None:
self._warn_init_param_moved_to_settings("voice_id", "voice")
default_settings.voice = voice_id
# 3. Apply params overrides — only if settings not provided
if params is not None:
self._warn_init_param_moved_to_settings("params")
if not settings:
default_settings.language = params.language
# 4. Apply settings delta (canonical API, always wins)
if settings is not None:
default_settings.apply_update(settings)
super().__init__(
push_start_frame=True,
push_stop_frames=True,
settings=default_settings,
**kwargs,
)
model_file = Path(model_path) if model_path else KOKORO_CACHE_DIR / "kokoro-v1.0.onnx"
voices = Path(voices_path) if voices_path else KOKORO_CACHE_DIR / "voices-v1.0.bin"
_ensure_model_files(model_file, voices)
self._kokoro = Kokoro(str(model_file), str(voices))
self._resampler = create_stream_resampler()
[docs]
def can_generate_metrics(self) -> bool:
"""Indicate that this service supports TTFB and usage metrics."""
return True
[docs]
def language_to_service_language(self, language: Language) -> str:
"""Convert a Language enum to kokoro-onnx language format.
Args:
language: The language to convert.
Returns:
The kokoro-onnx language code.
"""
return language_to_kokoro_language(language)
[docs]
@traced_tts
async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
"""Synthesize speech from text using kokoro-onnx.
Uses the async streaming API to generate audio frames.
Args:
text: The text to synthesize.
context_id: Unique identifier for this TTS context.
"""
logger.debug(f"{self}: Generating TTS [{text}]")
try:
await self.start_tts_usage_metrics(text)
voice = assert_given(self._settings.voice)
lang = assert_given(self._settings.language)
if lang is None:
raise ValueError("Kokoro TTS language must be specified")
stream = self._kokoro.create_stream(text, voice=voice, lang=lang, speed=1.0)
async for samples, sample_rate in stream:
await self.stop_ttfb_metrics()
audio_int16 = (samples * 32767).astype(np.int16).tobytes()
audio_data = await self._resampler.resample(
audio_int16, sample_rate, self.sample_rate
)
yield TTSAudioRawFrame(
audio=audio_data,
sample_rate=self.sample_rate,
num_channels=1,
context_id=context_id,
)
except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
finally:
await self.stop_ttfb_metrics()