#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Krisp turn analyzer for end-of-turn detection using Krisp VIVA SDK.
This module provides a turn analyzer implementation using Krisp's turn detection
v3 (Tt) API to determine when a user has finished speaking in a conversation.
The Tt API accepts an external VAD flag alongside audio frames, allowing the
model to leverage voice activity information for more accurate turn detection.
Note: This analyzer uses a different model than KrispVivaFilter. The model path
can be specified via the KRISP_VIVA_TURN_MODEL_PATH environment variable or
passed directly to the constructor.
"""
import os
import time
import numpy as np
from loguru import logger
from pipecat.audio.krisp_instance import (
KrispVivaSDKManager,
int_to_krisp_frame_duration,
int_to_krisp_sample_rate,
)
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, BaseTurnParams, EndOfTurnState
from pipecat.metrics.metrics import MetricsData, TurnMetricsData
try:
import krisp_audio
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use KrispVivaTurn, you need to install krisp_audio.")
raise ImportError(f"Missing module: {e}") from e
[docs]
class KrispTurnParams(BaseTurnParams):
    """Tunable settings for the Krisp end-of-turn analyzer.

    Parameters:
        threshold: Minimum model probability (0.0 to 1.0) at which a turn is
            treated as finished. A higher value demands more confidence
            before the turn is reported complete.
        frame_duration_ms: Length, in milliseconds, of each audio frame handed
            to the turn detection session.
            Supported values: 10, 15, 20, 30, 32.
    """

    threshold: float = 0.5
    frame_duration_ms: int = 20
[docs]
class KrispVivaTurn(BaseTurnAnalyzer):
"""Turn analyzer using Krisp VIVA SDK for end-of-turn detection.
Uses Krisp's turn detection v3 (Tt) API to determine when a user has
finished speaking. The Tt API receives an external VAD flag with each
audio frame, which the ``is_speech`` parameter of ``append_audio``
provides. This analyzer requires a valid Krisp model file to operate.
"""
[docs]
def __init__(
    self,
    *,
    model_path: str | None = None,
    sample_rate: int | None = None,
    params: KrispTurnParams | None = None,
    api_key: str = "",
) -> None:
    """Initialize the Krisp turn analyzer.

    Acquires a Krisp SDK reference, validates the model path, initializes the
    analyzer state, and creates a preload session. If anything after the SDK
    acquisition fails, the SDK reference is released before the error is
    re-raised.

    Args:
        model_path: Path to the Krisp turn detection model file (.kef extension).
            If None, uses KRISP_VIVA_TURN_MODEL_PATH environment variable.
        sample_rate: Optional initial sample rate for audio processing.
            If provided, this will be used as the fixed sample rate.
        params: Configuration parameters for turn analysis behavior.
            Defaults to ``KrispTurnParams()``.
        api_key: Krisp SDK API key. If empty, falls back to
            the KRISP_VIVA_API_KEY environment variable.

    Raises:
        ValueError: If model_path is not provided and KRISP_VIVA_TURN_MODEL_PATH is not set.
        Exception: If model file doesn't have .kef extension.
        FileNotFoundError: If model file doesn't exist.
        RuntimeError: If Krisp SDK initialization or session creation fails.
    """
    super().__init__(sample_rate=sample_rate)

    # Acquire SDK reference (will initialize on first call)
    try:
        KrispVivaSDKManager.acquire(api_key=api_key)
    except Exception as e:
        self._sdk_acquired = False
        # Chain explicitly, consistent with the other RuntimeError wrappers here.
        raise RuntimeError(f"Failed to initialize Krisp SDK: {e}") from e
    self._sdk_acquired = True

    try:
        # Set model path, checking environment if not specified
        self._model_path = model_path or os.getenv("KRISP_VIVA_TURN_MODEL_PATH")
        if not self._model_path:
            logger.error(
                "Model path is not provided and KRISP_VIVA_TURN_MODEL_PATH is not set."
            )
            raise ValueError("Model path for KrispVivaTurn must be provided.")
        if not self._model_path.endswith(".kef"):
            raise Exception("Model is expected with .kef extension")
        if not os.path.isfile(self._model_path):
            raise FileNotFoundError(f"Model file not found: {self._model_path}")

        self._params = params or KrispTurnParams()
        self._tt_session = None
        self._preload_tt_session = None
        self._samples_per_frame = None
        self._audio_buffer = bytearray()

        # State tracking
        self._speech_triggered = False
        self._last_probability = None
        self._frame_probabilities = []
        self._last_state = EndOfTurnState.INCOMPLETE
        self._speech_stopped_time: float | None = None
        self._e2e_processing_time_ms: float | None = None
        self._last_metrics: TurnMetricsData | None = None

        # Create a session with the provided sample rate, or 16000 Hz by default.
        # The session is kept only as a preload (intended to warm the model so a
        # later set_sample_rate() call is faster); it is not used for processing.
        preload_sample_rate = sample_rate or 16000
        try:
            self._preload_tt_session = self._create_tt_session(preload_sample_rate)
        except Exception as e:
            # loguru does not understand stdlib's ``exc_info``; extra keyword
            # arguments make it str.format() the message instead, which can
            # itself raise when the exception text contains braces. Attach
            # the traceback through opt(exception=True) and pass the
            # exception as a format argument.
            logger.opt(exception=True).error(
                "Failed to create turn detection session: {}", e
            )
            self._preload_tt_session = None
            raise RuntimeError(f"Failed to create turn detection session: {e}") from e
    except Exception:
        # If initialization fails, release the SDK reference
        if self._sdk_acquired:
            KrispVivaSDKManager.release()
            self._sdk_acquired = False
        raise
[docs]
async def cleanup(self):
"""Release SDK reference when analyzer is destroyed."""
if self._sdk_acquired:
try:
# Clean up session first
if hasattr(self, "_tt_session") and self._tt_session is not None:
self._tt_session = None
if hasattr(self, "_preload_tt_session") and self._preload_tt_session is not None:
self._preload_tt_session = None
KrispVivaSDKManager.release()
self._sdk_acquired = False
except Exception as e:
logger.error(f"Error in __del__: {e}", exc_info=True)
def _create_tt_session(self, sample_rate: int):
    """Create a turn detection session with the specified sample rate.

    On success this also updates ``self._samples_per_frame`` to match
    ``sample_rate`` and the configured frame duration. On failure that
    attribute is left untouched, so it never describes a session that does
    not exist.

    Args:
        sample_rate: Sample rate for the session.

    Returns:
        The object returned by ``krisp_audio.TtFloat.create``.

    Raises:
        RuntimeError: If anything goes wrong while configuring or creating
            the session. Every underlying exception (including one raised by
            the sample-rate / frame-duration conversion helpers) is wrapped,
            with the original attached as ``__cause__``.
    """
    try:
        model_info = krisp_audio.ModelInfo()
        model_info.path = self._model_path

        tt_cfg = krisp_audio.TtSessionConfig()
        tt_cfg.inputSampleRate = int_to_krisp_sample_rate(sample_rate)
        tt_cfg.inputFrameDuration = int_to_krisp_frame_duration(self._params.frame_duration_ms)
        tt_cfg.modelInfo = model_info

        tt_instance = krisp_audio.TtFloat.create(tt_cfg)

        # Record the frame size only once the session actually exists.
        self._samples_per_frame = int((sample_rate * self._params.frame_duration_ms) / 1000)
        return tt_instance
    except Exception as e:
        # loguru ignores stdlib's ``exc_info`` and would str.format() the
        # message because of the extra keyword; use opt(exception=True) and
        # pass the exception as a format argument instead.
        logger.opt(exception=True).error(
            "Failed to create Krisp turn detection session: {}", e
        )
        raise RuntimeError(f"Failed to create Krisp turn detection session: {e}") from e
[docs]
def set_sample_rate(self, sample_rate: int):
"""Set the sample rate and create/update the turn detection session.
Args:
sample_rate: The sample rate to set.
"""
if self._sample_rate == sample_rate:
return
super().set_sample_rate(sample_rate)
# Create session when sample rate is set
try:
self._tt_session = self._create_tt_session(self._sample_rate)
self.clear()
except Exception as e:
logger.error(f"Failed to create turn detection session: {e}", exc_info=True)
self._tt_session = None
@property
def frame_probabilities(self) -> list:
"""Get all probabilities from the last append_audio call.
Returns:
List of probability values for each frame processed in the last append_audio call.
"""
return self._frame_probabilities
@property
def last_probability(self) -> float | None:
"""Get the last turn probability value computed.
Returns:
Last probability value, or None if no frames have been processed yet.
"""
return self._last_probability
@property
def speech_triggered(self) -> bool:
"""Check if speech has been detected and triggered analysis.
Returns:
True if speech has been detected and turn analysis is active.
"""
return self._speech_triggered
@property
def params(self) -> KrispTurnParams:
"""Get the current turn analyzer parameters.
Returns:
Current turn analyzer configuration parameters.
"""
return self._params
[docs]
def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState:
    """Append audio data for turn analysis.

    Incoming bytes are added to an internal buffer, split into whole frames of
    ``self._samples_per_frame`` 16-bit samples, scaled to float32 by dividing
    by 32768, and passed frame by frame to the turn detection session together
    with ``is_speech``. Bytes beyond the last complete frame stay buffered for
    the next call.

    Args:
        buffer: Raw audio data bytes to append for analysis (read as int16 samples).
        is_speech: Whether the audio buffer contains detected speech. The same
            flag is applied to every frame processed in this call.

    Returns:
        ``EndOfTurnState.COMPLETE`` when speech has been seen and a frame's
        probability reaches ``params.threshold``; otherwise
        ``EndOfTurnState.INCOMPLETE`` — including when no session exists, when
        there is not yet a full frame, and when processing raises (the error
        is logged, not propagated).
    """
    if self._tt_session is None:
        logger.warning("Turn detection session not initialized, returning INCOMPLETE")
        self._last_state = EndOfTurnState.INCOMPLETE
        return EndOfTurnState.INCOMPLETE

    if self._samples_per_frame is None:
        logger.warning("Samples per frame not initialized, returning INCOMPLETE")
        self._last_state = EndOfTurnState.INCOMPLETE
        return EndOfTurnState.INCOMPLETE

    try:
        # Add incoming audio to our buffer
        self._audio_buffer.extend(buffer)

        # Probabilities are reported per call, so start from an empty list.
        self._frame_probabilities = []

        total_samples = len(self._audio_buffer) // 2  # 2 bytes per int16 sample
        num_complete_frames = total_samples // self._samples_per_frame
        if num_complete_frames == 0:
            # Not enough samples for a complete frame yet
            self._last_state = EndOfTurnState.INCOMPLETE
            return EndOfTurnState.INCOMPLETE

        bytes_to_process = num_complete_frames * self._samples_per_frame * 2
        audio_to_process = bytes(self._audio_buffer[:bytes_to_process])
        # Drop the consumed prefix in place instead of copying the remainder
        # into a new bytearray.
        del self._audio_buffer[:bytes_to_process]

        audio_int16 = np.frombuffer(audio_to_process, dtype=np.int16)
        audio_float32 = audio_int16.astype(np.float32) / 32768.0
        frames = audio_float32.reshape(-1, self._samples_per_frame)

        # VAD bookkeeping depends only on ``is_speech``, which is constant for
        # the whole call, so it is done once here rather than once per frame.
        if is_speech:
            if not self._speech_triggered:
                logger.trace("Speech detected, turn analysis started")
                self._e2e_processing_time_ms = None
            self._speech_triggered = True
            # Reset speech stopped time when speech resumes
            self._speech_stopped_time = None
        elif self._speech_triggered and self._speech_stopped_time is None:
            # Record the moment speech transitions to non-speech.
            # Silence alone never completes the turn; the probability check
            # in the loop below has to confirm it.
            self._speech_stopped_time = time.perf_counter()

        state = EndOfTurnState.INCOMPLETE
        for frame in frames:
            # NOTE(review): the meaning of the third positional argument is
            # defined by the Krisp SDK and is not visible here — confirm.
            prob = self._tt_session.process(frame.tolist(), is_speech, False)

            # Store the probability for external access
            self._last_probability = prob
            self._frame_probabilities.append(prob)

            # Only mark as complete if we've detected speech and the model
            # confirms with sufficient confidence
            if self._speech_triggered and prob >= self._params.threshold:
                # e2e processing time: from speech stop to threshold crossing
                if self._speech_stopped_time is not None:
                    self._e2e_processing_time_ms = (
                        time.perf_counter() - self._speech_stopped_time
                    ) * 1000
                self._last_metrics = TurnMetricsData(
                    processor="KrispVivaTurn",
                    is_complete=True,
                    probability=prob,
                    e2e_processing_time_ms=self._e2e_processing_time_ms,
                )
                logger.debug("Krisp turn complete")
                state = EndOfTurnState.COMPLETE
                # clear() also discards any buffered partial frame; frames
                # after this one in the current call are not processed.
                self.clear()
                break

        # Store the last state for analyze_end_of_turn(). This runs after
        # clear(), so a COMPLETE result is not lost to the reset.
        self._last_state = state
        return state
    except Exception as e:
        # loguru ignores stdlib's ``exc_info`` and would str.format() the
        # message because of the extra keyword; use opt(exception=True) and
        # pass the exception as a format argument instead.
        logger.opt(exception=True).error("Error during Krisp turn detection: {}", e)
        error_state = EndOfTurnState.INCOMPLETE
        self._last_state = error_state
        return error_state
[docs]
async def analyze_end_of_turn(self) -> tuple[EndOfTurnState, MetricsData | None]:
"""Analyze the current audio state to determine if turn has ended.
Returns:
Tuple containing the end-of-turn state and optional metrics data.
Returns the last state determined by append_audio().
"""
# For real-time processing, the state is determined in append_audio.
# Consume metrics so they aren't pushed twice.
metrics = self._last_metrics
self._last_metrics = None
return self._last_state, metrics
[docs]
def clear(self):
    """Return the per-turn tracking state to its starting values.

    Resets the speech-triggered flag, empties the audio buffer in place,
    forgets the speech-stop timestamp, and sets the last state to INCOMPLETE.
    The last probability, the per-frame probabilities and any pending metrics
    are left untouched.
    """
    self._speech_triggered = False
    del self._audio_buffer[:]
    self._speech_stopped_time = None
    self._last_state = EndOfTurnState.INCOMPLETE