#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Local turn analyzer for on-device ML inference using the smart-turn-v3 model.
This module provides a smart turn analyzer that uses an ONNX model for
local end-of-turn detection without requiring network connectivity.
"""
from typing import Any
import numpy as np
import onnxruntime as ort
import soxr
from loguru import logger
from transformers import WhisperFeatureExtractor
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
from pipecat.utils.env import env_truthy
# The Whisper-based ONNX model expects 16 kHz audio input.
_MODEL_SAMPLE_RATE = 16000
class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
    """Local turn analyzer using the smart-turn-v3 ONNX model.

    Provides end-of-turn detection using a locally stored ONNX model,
    enabling offline operation without network dependencies.
    """

    # Package and file name of the bundled model used when no explicit
    # model path is supplied.
    _BUNDLED_MODEL_PACKAGE = "pipecat.audio.turn.smart_turn.data"
    _BUNDLED_MODEL_NAME = "smart-turn-v3.2-cpu.onnx"

    # Length (seconds) of the audio window fed to the model. The Whisper
    # feature extractor is configured with the same chunk length.
    _WINDOW_SECONDS = 8

    def __init__(self, *, smart_turn_model_path: str | None = None, cpu_count: int = 1, **kwargs):
        """Initialize the local ONNX smart-turn-v3 analyzer.

        Args:
            smart_turn_model_path: Path to the ONNX model file. If this is not
                set, the bundled smart-turn-v3.2-cpu model will be used.
            cpu_count: The number of CPUs to use for inference. Defaults to 1.
            **kwargs: Additional arguments passed to BaseSmartTurn.
        """
        super().__init__(**kwargs)

        self._log_data = env_truthy("PIPECAT_SMART_TURN_LOG_DATA", default=False)

        so = ort.SessionOptions()
        so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        so.inter_op_num_threads = 1
        so.intra_op_num_threads = cpu_count
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self._feature_extractor = WhisperFeatureExtractor(chunk_length=self._WINDOW_SECONDS)

        if smart_turn_model_path:
            self._session = self._create_session(smart_turn_model_path, so)
        else:
            # Function-scope import, matching the file's existing style for
            # optional/rarely-used imports.
            from importlib import resources

            model_resource = resources.files(self._BUNDLED_MODEL_PACKAGE).joinpath(
                self._BUNDLED_MODEL_NAME
            )
            # `as_file` yields a real filesystem path. If the package is not
            # installed as plain files (e.g. inside a zip), it extracts the
            # model to a temporary file that is deleted when the `with`
            # exits, so the session must be created inside this block.
            with resources.as_file(model_resource) as model_path:
                self._session = self._create_session(str(model_path), so)

        logger.debug("Loaded Local Smart Turn v3.x")

    @staticmethod
    def _create_session(model_path, so: ort.SessionOptions) -> ort.InferenceSession:
        """Create the ONNX Runtime session for the model at ``model_path``.

        Args:
            model_path: Location of the ONNX model, passed through unchanged
                to ``ort.InferenceSession``.
            so: Session options to apply.

        Returns:
            The constructed inference session.
        """
        logger.debug(f"Loading Local Smart Turn v3.x model from {model_path}...")
        return ort.InferenceSession(model_path, sess_options=so)

    @staticmethod
    def _float_to_int16(audio_array: np.ndarray) -> np.ndarray:
        """Convert float audio normalized to [-1, 1] into int16 PCM samples.

        Samples are clipped to [-1, 1] first. Without clipping, out-of-range
        values would wrap around when cast to int16 and produce loud clicks.
        The result is always a new array, independent of ``audio_array``.
        """
        return (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)

    def _write_audio_to_wav(
        self, audio_array: np.ndarray, sample_rate: int = _MODEL_SAMPLE_RATE, suffix: str = ""
    ) -> None:
        """Write audio data to a WAV file in a background thread.

        Files go to ``./smart_turn_audio_log`` and are named after the current
        local time (millisecond precision) plus ``suffix``. Write failures are
        logged rather than raised.

        Args:
            audio_array: The audio data as a numpy array (float32, normalized to [-1, 1]).
            sample_rate: The sample rate of the audio data.
            suffix: Optional suffix to append to the filename (e.g., "_raw", "_padded").
        """
        import os
        import threading
        import wave
        from datetime import datetime

        # Time fields are separated with '-' rather than ':' because ':' is
        # not a valid filename character on Windows.
        timestamp = datetime.now().strftime("%Y-%m-%d__%H-%M-%S.%f")[:-3]
        log_dir = "./smart_turn_audio_log"
        os.makedirs(log_dir, exist_ok=True)
        filename = os.path.join(log_dir, f"{timestamp}{suffix}.wav")

        # Convert on the calling thread. The conversion produces a new array,
        # so the background writer never sees later changes to the caller's
        # buffer.
        audio_int16 = self._float_to_int16(audio_array)

        def write_wav():
            try:
                with wave.open(filename, "wb") as wav_file:
                    wav_file.setnchannels(1)  # Mono
                    wav_file.setsampwidth(2)  # 2 bytes per int16 sample
                    wav_file.setframerate(sample_rate)
                    wav_file.writeframes(audio_int16.tobytes())
                logger.debug(f"Wrote audio to {filename}")
            except Exception as e:
                # Debug logging is best-effort; never break inference on it.
                logger.error(f"Failed to write audio to {filename}: {e}")

        threading.Thread(target=write_wav, daemon=True).start()

    def _resample_to_model_rate(self, audio_array: np.ndarray) -> np.ndarray:
        """Resample audio to the model's expected sample rate (16 kHz).

        Args:
            audio_array: Audio data as a float32 numpy array at
                ``self._sample_rate`` (treated as 16 kHz when unset).

        Returns:
            Resampled audio array at 16 kHz (the input itself when no
            resampling is needed).
        """
        actual_rate = self._sample_rate or _MODEL_SAMPLE_RATE
        if actual_rate == _MODEL_SAMPLE_RATE:
            return audio_array
        return soxr.resample(audio_array, actual_rate, _MODEL_SAMPLE_RATE, quality="VHQ")

    @staticmethod
    def _truncate_or_pad(audio_array: np.ndarray, max_samples: int) -> np.ndarray:
        """Fit audio to exactly ``max_samples`` samples.

        Longer input keeps only its last ``max_samples`` samples. Shorter
        input is left-padded with zeros, so the most recent audio always
        sits at the end of the window.
        """
        length = len(audio_array)
        if length > max_samples:
            # Explicit start index (rather than -max_samples) stays correct
            # even when max_samples == 0.
            return audio_array[length - max_samples :]
        if length < max_samples:
            return np.pad(
                audio_array, (max_samples - length, 0), mode="constant", constant_values=0
            )
        return audio_array

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Predict end-of-turn using the local ONNX model.

        Args:
            audio_array: Audio samples at ``self._sample_rate`` (treated as
                16 kHz when unset).

        Returns:
            A dict with ``"prediction"`` (1 for complete, 0 for incomplete)
            and ``"probability"`` (the model output the prediction is
            thresholded from at 0.5).
        """
        # Keep the original, pipeline-rate audio for optional debug logging.
        audio_for_logging = audio_array
        actual_rate = self._sample_rate or _MODEL_SAMPLE_RATE
        max_samples = self._WINDOW_SECONDS * _MODEL_SAMPLE_RATE

        # Resample to 16 kHz if the pipeline uses a different sample rate,
        # then keep only the most recent window (zero-padding if shorter).
        audio_array = self._resample_to_model_rate(audio_array)
        audio_array = self._truncate_or_pad(audio_array, max_samples)

        # Process audio using Whisper's feature extractor.
        inputs = self._feature_extractor(
            audio_array,
            sampling_rate=_MODEL_SAMPLE_RATE,
            return_tensors="np",
            padding="max_length",
            max_length=max_samples,
            truncation=True,
            do_normalize=True,
        )

        # Normalize to float32 with a single leading batch dimension for ONNX.
        input_features = inputs.input_features.squeeze(0).astype(np.float32)
        input_features = np.expand_dims(input_features, axis=0)

        outputs = self._session.run(None, {"input_features": input_features})

        # The ONNX model returns sigmoid probabilities.
        probability = outputs[0][0].item()
        prediction = 1 if probability > 0.5 else 0

        if self._log_data:
            suffix = "_complete" if prediction == 1 else "_incomplete"
            self._write_audio_to_wav(audio_for_logging, sample_rate=actual_rate, suffix=suffix)

        return {
            "prediction": prediction,
            "probability": probability,
        }