# Source code for pipecat.audio.turn.smart_turn.local_smart_turn_v2

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Local PyTorch turn analyzer for on-device ML inference using the smart-turn-v2 model.

This module provides a smart turn analyzer that uses PyTorch models for
local end-of-turn detection without requiring network connectivity.
"""

import warnings
from typing import Any

import numpy as np
from loguru import logger

from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn

try:
    import torch
    import torch.nn.functional as F
    from torch import nn
    from transformers import (
        Wav2Vec2Config,
        Wav2Vec2Model,
        Wav2Vec2PreTrainedModel,
        Wav2Vec2Processor,
    )
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use LocalSmartTurnAnalyzerV2, you need to `pip install pipecat-ai[local-smart-turn]`."
    )
    raise Exception(f"Missing module: {e}")


class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
    """End-of-turn analyzer that runs the smart-turn-v2 PyTorch model in-process.

    The model (a Wav2Vec2-based sequence classifier) and its processor are
    loaded with ``from_pretrained`` and executed on the local machine, so
    end-of-turn detection is performed on-device.

    .. deprecated:: 0.0.106
        LocalSmartTurnAnalyzerV2 is deprecated and will be removed in a future
        version. Use LocalSmartTurnAnalyzerV3 instead.
    """

    def __init__(self, *, smart_turn_model_path: str, **kwargs):
        """Load the smart-turn-v2 model and processor and pick a compute device.

        Args:
            smart_turn_model_path: Location handed to ``from_pretrained`` for both
                the model and the processor. An empty value falls back to the
                ``pipecat-ai/smart-turn-v2`` model identifier.
            **kwargs: Forwarded unchanged to ``BaseSmartTurn.__init__``.
        """
        super().__init__(**kwargs)

        deprecation_message = (
            "LocalSmartTurnAnalyzerV2 is deprecated and will be removed in a future version. "
            "Use LocalSmartTurnAnalyzerV3 instead."
        )
        # Force the warning to be shown regardless of the caller's active filters;
        # stacklevel=2 attributes it to the code constructing the analyzer.
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

        # An empty path means "use the published model identifier".
        model_source = smart_turn_model_path or "pipecat-ai/smart-turn-v2"

        logger.debug("Loading Local Smart Turn v2 model...")

        # Classification model and the matching audio preprocessor.
        self._turn_model = _Wav2Vec2ForEndpointing.from_pretrained(model_source)
        self._turn_processor = Wav2Vec2Processor.from_pretrained(model_source)

        # Prefer an accelerated backend when one is present: MPS first, then
        # CUDA, otherwise stay on the CPU.
        if torch.backends.mps.is_available():
            selected_device = "mps"
        elif torch.cuda.is_available():
            selected_device = "cuda"
        else:
            selected_device = "cpu"
        self._device = selected_device

        # Place the model on the chosen device and switch it to evaluation mode.
        self._turn_model = self._turn_model.to(self._device)
        self._turn_model.eval()

        logger.debug("Loaded Local Smart Turn v2")

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Run the local model on an audio buffer and report whether the turn ended.

        Args:
            audio_array: Audio samples passed to the processor, which is told the
                sampling rate is 16 kHz.

        Returns:
            A dict with ``"prediction"`` (1 for complete, 0 for incomplete) and
            ``"probability"`` (the model's output for the first batch item).
        """
        sample_rate = 16000
        max_samples = sample_rate * 16  # 16 seconds at 16kHz

        features = self._turn_processor(
            audio_array,
            sampling_rate=sample_rate,
            padding="max_length",
            truncation=True,
            max_length=max_samples,
            return_attention_mask=True,
            return_tensors="pt",
        )

        # Tensors must live on the same device as the model.
        model_inputs = {name: tensor.to(self._device) for name, tensor in features.items()}

        # Inference only; no gradients needed.
        with torch.no_grad():
            model_outputs = self._turn_model(**model_inputs)

        # The "logits" entry already holds a sigmoid probability.
        probability = model_outputs["logits"][0].item()

        return {
            "prediction": int(probability > 0.5),
            "probability": probability,
        }


class _Wav2Vec2ForEndpointing(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder with attention pooling and a single-output classifier head.

    ``forward`` returns sigmoid probabilities (stored under the ``"logits"``
    key) rather than raw logits.
    """

    def __init__(self, config: Wav2Vec2Config):
        """Build the encoder, the pooling scorer and the classifier head.

        Args:
            config: Wav2Vec2 configuration; ``config.hidden_size`` sets the input
                width of both heads.
        """
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)

        hidden_size = config.hidden_size

        # Produces one score per hidden state, used as pooling weights.
        self.pool_attention = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1),
        )

        # Maps the pooled representation to a single output value.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

        # Initialise every Linear layer of both heads (classifier first, then
        # the pooling scorer): normal(0, 0.1) weights and zero biases.
        for head in (self.classifier, self.pool_attention):
            for layer in head:
                if not isinstance(layer, nn.Linear):
                    continue
                layer.weight.data.normal_(mean=0.0, std=0.1)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def attention_pool(self, hidden_states, attention_mask):
        """Collapse dim 1 of ``hidden_states`` into a softmax-weighted sum.

        Args:
            hidden_states: Encoder output; dim 1 is the axis being pooled
                (presumably (batch, frames, hidden) - confirm against caller).
            attention_mask: Mask aligned with dim 1 of ``hidden_states``; positions
                with value 0 receive a large negative score before the softmax.

        Returns:
            The weighted sum of ``hidden_states`` over dim 1.

        Raises:
            ValueError: If ``attention_mask`` is None.
        """
        scores = self.pool_attention(hidden_states)

        if attention_mask is None:
            raise ValueError("attention_mask must be provided for attention pooling")

        # 1.0 where the mask is off, 0.0 where it is on.
        masked_out = 1.0 - attention_mask.unsqueeze(-1).to(scores.dtype)
        scores = scores + (masked_out * -1e9)
        weights = F.softmax(scores, dim=1)

        return torch.sum(hidden_states * weights, dim=1)

    def forward(self, input_values, attention_mask=None, labels=None):
        """Encode the input, pool it and return end-of-turn probabilities.

        Args:
            input_values: Input passed straight to the Wav2Vec2 encoder.
            attention_mask: Mask over the input positions. It is resampled to the
                encoder's output length before pooling; pooling raises if it is
                None.
            labels: Optional targets. When given, a weighted BCE-with-logits loss
                plus a norm penalty on the classifier parameters is also returned.

        Returns:
            ``{"logits": probs}`` without labels, or
            ``{"loss": loss, "logits": probs}`` with labels, where ``probs`` is
            the sigmoid of the classifier output.

        Raises:
            ValueError: If the classifier output contains NaN, or if
                ``attention_mask`` is None.
        """
        encoder_outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        hidden_states = encoder_outputs[0]

        # Resample the input-level mask to the hidden-state length by picking
        # evenly spaced input positions.
        if attention_mask is not None:
            input_length = attention_mask.size(1)
            hidden_length = hidden_states.size(1)
            ratio = input_length / hidden_length
            positions = torch.arange(hidden_length, device=attention_mask.device)
            indices = (positions * ratio).long()
            attention_mask = attention_mask[:, indices].bool()

        pooled = self.attention_pool(hidden_states, attention_mask)
        logits = self.classifier(pooled)

        if torch.isnan(logits).any():
            raise ValueError("NaN values detected in logits")

        # Inference path: no loss to compute.
        if labels is None:
            return {"logits": torch.sigmoid(logits)}

        # Weight positives by the negative/positive ratio of this batch,
        # clamped to [0.1, 10].
        negatives = (labels == 0).sum()
        positives = (labels == 1).sum()
        pos_weight = (negatives / positives).clamp(min=0.1, max=10.0)

        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        targets = labels.float()
        loss = criterion(logits.view(-1), targets.view(-1))

        # Regularisation: add the summed norms of all classifier parameters.
        l2_lambda = 0.01
        l2_reg = torch.tensor(0.0, device=logits.device)
        for param in self.classifier.parameters():
            l2_reg = l2_reg + torch.norm(param)
        loss = loss + l2_lambda * l2_reg

        # Detached so the reported probabilities carry no gradient.
        probs = torch.sigmoid(logits.detach())
        return {"loss": loss, "logits": probs}