#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""SmallWebRTC request handler for managing peer connections.
This module provides a client for handling web requests and managing WebRTC connections.
"""
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any
from aiortc.sdp import candidate_from_sdp
from fastapi import HTTPException
from loguru import logger
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
[docs]
@dataclass
class SmallWebRTCRequest:
"""Small WebRTC transport session arguments for the runner.
Parameters:
sdp: The SDP string (Session Description Protocol).
type: The type of the SDP, either "offer" or "answer".
pc_id: Optional identifier for the peer connection.
restart_pc: Optional whether to restart the peer connection.
request_data: Optional custom data sent by the customer.
"""
sdp: str
type: str
pc_id: str | None = None
restart_pc: bool | None = None
request_data: Any | None = None
[docs]
@classmethod
def from_dict(cls, data: dict):
"""Accept both snake_case and camelCase for the request_data field."""
if "requestData" in data and "request_data" not in data:
data["request_data"] = data.pop("requestData")
return cls(**data)
[docs]
@dataclass
class IceCandidate:
    """A remote ICE candidate received from the peer connection.

    Parameters:
        candidate: The ICE candidate patch as an SDP string
            (Session Description Protocol).
        sdp_mid: The SDP mid that the candidate patch applies to.
        sdp_mline_index: The SDP mline index that the candidate patch applies to.
    """

    candidate: str
    sdp_mid: str
    sdp_mline_index: int
[docs]
@dataclass
class SmallWebRTCPatchRequest:
    """Request carrying ICE candidate patches for a peer connection.

    Parameters:
        pc_id: Identifier for the peer connection.
        candidates: A list of ICE candidate patches.
    """

    pc_id: str
    candidates: list[IceCandidate]
[docs]
class ConnectionMode(Enum):
    """Connection handling modes."""

    # Only one active connection allowed.
    SINGLE = "single"
    # Multiple simultaneous connections allowed.
    MULTIPLE = "multiple"
[docs]
class SmallWebRTCRequestHandler:
    """SmallWebRTC request handler for managing peer connections.

    This class is responsible for:

    - Handling incoming SmallWebRTC requests.
    - Creating and managing WebRTC peer connections.
    - Supporting ESP32-specific SDP munging if enabled.
    - Invoking callbacks for newly initialized connections.
    - Supporting both single and multiple connection modes.
    """

    def __init__(
        self,
        ice_servers: list[IceServer] | None = None,
        esp32_mode: bool = False,
        host: str | None = None,
        connection_mode: ConnectionMode = ConnectionMode.MULTIPLE,
    ) -> None:
        """Initialize a SmallWebRTC request handler.

        Args:
            ice_servers (Optional[List[IceServer]]): List of ICE servers to use for WebRTC
                connections.
            esp32_mode (bool): If True, enables ESP32-specific SDP munging.
            host (Optional[str]): Host address used for SDP munging in ESP32 mode.
                Ignored if `esp32_mode` is False.
            connection_mode (ConnectionMode): Mode of operation for handling connections.
                SINGLE allows only one active connection, MULTIPLE allows several.
        """
        self._ice_servers = ice_servers
        self._esp32_mode = esp32_mode
        self._host = host
        self._connection_mode = connection_mode

        # Store connections by pc_id
        self._pcs_map: dict[str, SmallWebRTCConnection] = {}

    def _check_single_connection_constraints(self, pc_id: str | None) -> None:
        """Check if the connection request satisfies single connection mode constraints.

        Does nothing unless the handler is in SINGLE mode and a connection is
        already tracked.

        Args:
            pc_id: The peer connection ID from the request.

        Raises:
            HTTPException: With status 400 if, in single connection mode, the
                request carries a pc_id different from the existing connection
                or carries no pc_id while a connection exists.
        """
        if self._connection_mode != ConnectionMode.SINGLE:
            return

        if not self._pcs_map:  # No existing connections
            return

        # Get the existing connection (should be only one in single mode)
        existing_connection = next(iter(self._pcs_map.values()))

        if existing_connection.pc_id != pc_id and pc_id:
            logger.warning(
                f"Connection pc_id mismatch: existing={existing_connection.pc_id}, received={pc_id}"
            )
            raise HTTPException(status_code=400, detail="PC ID mismatch with existing connection")

        if not pc_id:
            logger.warning(
                "Cannot create new connection: existing connection found but no pc_id received"
            )
            raise HTTPException(
                status_code=400,
                detail="Cannot create new connection with existing connection active",
            )

    def update_ice_servers(self, ice_servers: list[IceServer] | None = None):
        """Update the list of ICE servers used for WebRTC connections.

        Args:
            ice_servers: The new ICE server list passed to connections created
                after this call, or None.
        """
        self._ice_servers = ice_servers

    async def handle_web_request(
        self,
        request: SmallWebRTCRequest,
        webrtc_connection_callback: Callable[[Any], Awaitable[None]],
    ) -> dict[str, str] | None:
        """Handle a SmallWebRTC request and resolve the pending answer.

        This method will:

        - Reuse (renegotiate) an existing WebRTC connection if `pc_id` is known.
        - Otherwise, create a new `SmallWebRTCConnection` and invoke the
          provided callback with it.
        - Manage ESP32-specific munging if enabled.
        - Enforce single/multiple connection mode constraints.

        Args:
            request (SmallWebRTCRequest): The incoming WebRTC request, containing
                SDP, type, and optionally a `pc_id`.
            webrtc_connection_callback (Callable[[Any], Awaitable[None]]): An
                asynchronous callback function that is invoked with the WebRTC connection.

        Returns:
            Dictionary containing SDP answer, type, and peer connection ID,
            or None if no answer is available.

        Raises:
            HTTPException: If connection mode constraints are violated.
            Exception: Any other exception raised during request handling is
                logged and propagated. Exceptions raised by
                `webrtc_connection_callback` are logged but NOT propagated.
        """
        try:
            pc_id = request.pc_id

            # Check connection mode constraints first
            self._check_single_connection_constraints(pc_id)

            # After constraints are satisfied, get the existing connection if any
            existing_connection = self._pcs_map.get(pc_id) if pc_id else None

            if existing_connection:
                pipecat_connection = existing_connection
                logger.info(f"Reusing existing connection for pc_id: {pc_id}")
                await pipecat_connection.renegotiate(
                    sdp=request.sdp,
                    type=request.type,
                    restart_pc=request.restart_pc or False,
                )
            else:
                pipecat_connection = SmallWebRTCConnection(ice_servers=self._ice_servers)
                await pipecat_connection.initialize(sdp=request.sdp, type=request.type)

                # Drop the connection from the map once it reports "closed".
                @pipecat_connection.event_handler("closed")
                async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
                    logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
                    self._pcs_map.pop(webrtc_connection.pc_id, None)

                # Invoke callback provided in runner arguments. A failing callback is
                # logged and does not abort the request.
                try:
                    await webrtc_connection_callback(pipecat_connection)
                    logger.debug(
                        f"webrtc_connection_callback executed successfully for peer: {pipecat_connection.pc_id}"
                    )
                except Exception as callback_error:
                    logger.error(
                        f"webrtc_connection_callback failed for peer {pipecat_connection.pc_id}: {callback_error}"
                    )

            answer = pipecat_connection.get_answer()

            # Honor the documented contract: no answer means a None result rather
            # than a TypeError from subscripting None below.
            if answer is None:
                logger.warning(
                    f"No answer available for peer connection: {pipecat_connection.pc_id}"
                )
                return None

            if self._esp32_mode:
                # Imported lazily, only when ESP32 munging is actually needed.
                from pipecat.runner.utils import smallwebrtc_sdp_munging

                answer["sdp"] = smallwebrtc_sdp_munging(answer["sdp"], self._host)

            self._pcs_map[answer["pc_id"]] = pipecat_connection

            return answer
        except Exception as e:
            logger.error(f"Error processing SmallWebRTC request: {e}")
            logger.debug(f"SmallWebRTC request details: {request}")
            raise

    async def handle_patch_request(self, request: SmallWebRTCPatchRequest):
        """Handle a SmallWebRTC patch candidate request.

        Parses each candidate in the request and adds it to the peer
        connection identified by `request.pc_id`.

        Args:
            request: The patch request with the peer connection ID and the
                ICE candidates to add.

        Raises:
            HTTPException: With status 404 if no peer connection is tracked
                for `request.pc_id`.
        """
        peer_connection = self._pcs_map.get(request.pc_id)

        if not peer_connection:
            raise HTTPException(status_code=404, detail="Peer connection not found")

        for c in request.candidates:
            candidate = candidate_from_sdp(c.candidate)
            candidate.sdpMid = c.sdp_mid
            candidate.sdpMLineIndex = c.sdp_mline_index
            await peer_connection.add_ice_candidate(candidate)

    async def close(self):
        """Disconnect every tracked peer connection and clear the connection map.

        All connections are disconnected concurrently and every disconnect is
        awaited even if some fail. The map is always cleared.

        Raises:
            BaseException: The first error raised by a `disconnect()` call,
                re-raised after all disconnects have finished and each failure
                has been logged.
        """
        # Snapshot the connections: "closed" handlers may pop entries from the
        # map while the disconnects are running.
        connections = list(self._pcs_map.values())
        try:
            results = await asyncio.gather(
                *(pc.disconnect() for pc in connections),
                return_exceptions=True,
            )
        finally:
            self._pcs_map.clear()

        failures = [result for result in results if isinstance(result, BaseException)]
        for failure in failures:
            logger.error(f"Error disconnecting peer connection: {failure}")
        if failures:
            # Preserve the original contract: a disconnect failure still surfaces.
            raise failures[0]