#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""HTTP-based smart turn analyzer for remote ML inference.
This module provides a smart turn analyzer that sends audio data to remote
HTTP endpoints for ML-based end-of-turn detection.
"""
import asyncio
import io
from typing import Any
import aiohttp
import numpy as np
from loguru import logger
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn, SmartTurnTimeoutException
[docs]
class HttpSmartTurnAnalyzer(BaseSmartTurn):
    """Smart turn analyzer using HTTP-based ML inference.

    Sends audio data to remote HTTP endpoints for ML-based end-of-turn
    prediction. Handles serialization, HTTP communication, and error recovery.
    """

    def __init__(
        self,
        *,
        url: str,
        aiohttp_session: aiohttp.ClientSession,
        headers: dict[str, str] | None = None,
        **kwargs,
    ):
        """Initialize the HTTP smart turn analyzer.

        Args:
            url: HTTP endpoint URL for the smart turn ML service.
            aiohttp_session: HTTP client session for making requests.
            headers: Optional HTTP headers to include in requests.
            **kwargs: Additional arguments passed to BaseSmartTurn.
        """
        super().__init__(**kwargs)
        self._url = url
        self._headers = headers or {}
        self._aiohttp_session = aiohttp_session
        # Remember the event loop that is running at construction time (if
        # any). `_predict_endpoint` is synchronous and must hand the HTTP
        # coroutine to an event loop owned by *another* thread; it cannot
        # discover that loop from the thread it is called on.
        # NOTE(review): assumes the analyzer is constructed on the same loop
        # that owns `aiohttp_session` -- confirm against callers.
        try:
            self._request_loop = asyncio.get_running_loop()
        except RuntimeError:
            # Constructed outside of a running loop; no loop to submit to.
            self._request_loop = None

    def _serialize_array(self, audio_array: np.ndarray) -> bytes:
        """Serialize NumPy audio array to bytes for HTTP transmission.

        Args:
            audio_array: Array to serialize with ``np.save`` (``.npy`` format).

        Returns:
            The serialized ``.npy`` payload.
        """
        logger.trace("Serializing NumPy array to bytes...")
        buffer = io.BytesIO()
        np.save(buffer, audio_array)
        serialized_bytes = buffer.getvalue()
        logger.trace(f"Serialized size: {len(serialized_bytes)} bytes")
        return serialized_bytes

    async def _send_raw_request(self, data_bytes: bytes) -> dict[str, Any]:
        """Send raw audio data to the HTTP endpoint for prediction.

        Args:
            data_bytes: Request body, sent as ``application/octet-stream``
                unless overridden by the configured headers.

        Returns:
            The decoded JSON body of the response.

        Raises:
            SmartTurnTimeoutException: If the request exceeds
                ``self._params.stop_secs`` seconds.
            Exception: On an HTTP 500 response, a non-JSON response body, or
                any ``aiohttp.ClientError`` (including other error statuses
                surfaced by ``raise_for_status``).
        """
        # User-supplied headers are applied last so they can override the
        # default content type.
        headers = {"Content-Type": "application/octet-stream"}
        headers.update(self._headers)

        try:
            timeout = aiohttp.ClientTimeout(total=self._params.stop_secs)
            async with self._aiohttp_session.post(
                self._url, data=data_bytes, headers=headers, timeout=timeout
            ) as response:
                logger.trace("\n--- Response ---")
                logger.trace(f"Status Code: {response.status}")

                # Check if successful
                if response.status != 200:
                    error_text = await response.text()
                    logger.trace("Response Content (Error):")
                    logger.trace(error_text)
                    if response.status == 500:
                        logger.warning(f"Smart turn service returned 500 error: {error_text}")
                        raise Exception(f"Server returned HTTP 500: {error_text}")
                    else:
                        # Raises aiohttp.ClientResponseError for status >= 400;
                        # other non-200 statuses fall through to JSON parsing.
                        response.raise_for_status()

                # Process successful response
                try:
                    json_data = await response.json()
                    logger.trace("Response JSON:")
                    logger.trace(json_data)
                    return json_data
                except aiohttp.ContentTypeError as e:
                    # Non-JSON response
                    text = await response.text()
                    logger.trace("Response Content (non-JSON):")
                    logger.trace(text)
                    raise Exception(f"Non-JSON response: {text}") from e
        except (TimeoutError, asyncio.TimeoutError) as e:
            # asyncio.TimeoutError only became an alias of the builtin
            # TimeoutError in Python 3.11; catch both so request timeouts are
            # translated on earlier interpreters too.
            logger.error(f"Request timed out after {self._params.stop_secs} seconds")
            raise SmartTurnTimeoutException(
                f"Request exceeded {self._params.stop_secs} seconds."
            ) from e
        except aiohttp.ClientError as e:
            logger.error(f"Failed to send raw request to Daily Smart Turn: {e}")
            raise Exception("Failed to send raw request to Daily Smart Turn.") from e

    def _resolve_request_loop(self) -> asyncio.AbstractEventLoop:
        """Return the event loop the HTTP coroutine should be submitted to.

        Raises:
            RuntimeError: If no usable loop was captured at construction, or
                if the caller is running on that very loop (blocking on the
                result there would deadlock).
        """
        target = self._request_loop
        if target is None or target.is_closed():
            raise RuntimeError(
                "No usable event loop: HttpSmartTurnAnalyzer must be created "
                "while the event loop that owns the aiohttp session is running."
            )

        try:
            running = asyncio.get_running_loop()
        except RuntimeError:
            # Plain worker thread: the expected situation.
            running = None

        if running is target:
            # Waiting synchronously on the loop's own thread would prevent the
            # submitted coroutine from ever being scheduled.
            raise RuntimeError(
                "_predict_endpoint was called on the event loop thread; "
                "it must run in a worker thread."
            )
        return target

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Predict end-of-turn using remote HTTP ML service.

        Serializes the audio, submits the HTTP request to the event loop
        captured at construction, and blocks the calling thread until the
        response is available.

        Args:
            audio_array: Audio samples to send to the service.

        Returns:
            The JSON payload returned by the service, or -- on any failure --
            a fallback "incomplete" prediction with ``prediction`` 0,
            ``probability`` 0.0 and zeroed ``metrics``.
        """
        # NOTE(review): SmartTurnTimeoutException raised by _send_raw_request
        # is also swallowed by the broad handler below (if it derives from
        # Exception) -- confirm that is the intended contract.
        try:
            serialized_array = self._serialize_array(audio_array)
            loop = self._resolve_request_loop()
            future = asyncio.run_coroutine_threadsafe(
                self._send_raw_request(serialized_array), loop
            )
            return future.result()
        except Exception as e:
            logger.error(f"Smart turn prediction failed: {str(e)}")
            # Return an incomplete prediction when a failure occurs
            return {
                "prediction": 0,
                "probability": 0.0,
                "metrics": {"inference_time": 0.0, "total_time": 0.0},
            }