Source code for pipecat.transports.websocket.client

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""WebSocket client transport implementation for Pipecat.

This module provides a WebSocket client transport that enables bidirectional
communication over WebSocket connections, with support for audio streaming,
frame serialization, and connection management.
"""

import asyncio
import io
import time
import wave
from collections.abc import Awaitable, Callable

import websockets
from loguru import logger
from pydantic.main import BaseModel
from websockets.asyncio.client import connect as websocket_connect

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    Frame,
    InputAudioRawFrame,
    InputTransportMessageFrame,
    OutputAudioRawFrame,
    OutputTransportMessageFrame,
    OutputTransportMessageUrgentFrame,
    StartFrame,
)
from pipecat.processors.frame_processor import FrameProcessorSetup
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.serializers.protobuf import ProtobufFrameSerializer
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.utils.asyncio.task_manager import BaseTaskManager


class WebsocketClientParams(TransportParams):
    """Configuration parameters for WebSocket client transport.

    Parameters:
        add_wav_header: Whether to add WAV headers to audio frames. When
            enabled, the output transport wraps each outgoing audio chunk in a
            WAV container before serializing it.
        additional_headers: Optional extra headers forwarded as
            ``additional_headers`` to the WebSocket connect call.
        serializer: Frame serializer for encoding/decoding messages. When left
            unset, ``WebsocketClientTransport`` falls back to
            ``ProtobufFrameSerializer``.
    """

    # Wrap outgoing audio in a WAV container (see write_audio_frame()).
    add_wav_header: bool = True
    # Passed through unchanged when the session opens the connection.
    additional_headers: dict[str, str] | None = None
    # Without a serializer the input/output transports drop all traffic.
    serializer: FrameSerializer | None = None
class WebsocketClientCallbacks(BaseModel):
    """Callback functions for WebSocket client events.

    All callbacks are coroutines awaited by ``WebsocketClientSession`` and
    receive the session's websocket object as their first argument.

    Parameters:
        on_connected: Called when WebSocket connection is established (after
            the session has started its receive task).
        on_disconnected: Called when WebSocket connection is closed, i.e. when
            the session's receive loop ends, normally or after an error.
        on_message: Called when a message is received from the WebSocket, with
            the websocket and the raw message data.
    """

    on_connected: Callable[[websockets.WebSocketClientProtocol], Awaitable[None]]
    on_disconnected: Callable[[websockets.WebSocketClientProtocol], Awaitable[None]]
    on_message: Callable[[websockets.WebSocketClientProtocol, websockets.Data], Awaitable[None]]
class WebsocketClientSession:
    """Manages a WebSocket client connection session.

    Handles connection lifecycle, message sending/receiving, and provides
    callback mechanisms for connection events.

    A single session is shared by the input and output transports. Each of
    them calls ``setup()`` once and ``disconnect()`` once; the connection is
    only torn down when the last user disconnects.
    """

    def __init__(
        self,
        uri: str,
        params: WebsocketClientParams,
        callbacks: WebsocketClientCallbacks,
        transport_name: str,
    ):
        """Initialize the WebSocket client session.

        Args:
            uri: The WebSocket URI to connect to.
            params: Configuration parameters for the session.
            callbacks: Callback functions for session events.
            transport_name: Name of the parent transport for logging.
        """
        self._uri = uri
        self._params = params
        self._callbacks = callbacks
        self._transport_name = transport_name

        # Number of users (setup() calls) that have not yet disconnected.
        self._leave_counter = 0
        self._task_manager: BaseTaskManager | None = None
        self._websocket: websockets.WebSocketClientProtocol | None = None
        # Receive-loop task created by connect(). Declared here so that
        # disconnect() can never fail with AttributeError.
        self._client_task: asyncio.Task | None = None

    @property
    def task_manager(self) -> BaseTaskManager:
        """Get the task manager for this session.

        Returns:
            The task manager instance.

        Raises:
            Exception: If task manager is not initialized.
        """
        if not self._task_manager:
            raise Exception(
                f"{self._transport_name}::WebsocketClientSession: TaskManager not initialized (pipeline not started?)"
            )
        return self._task_manager

    async def setup(self, task_manager: BaseTaskManager):
        """Set up the session with a task manager.

        Every call registers one more user of the session; only the first
        task manager received is kept.

        Args:
            task_manager: The task manager to use for session tasks.
        """
        self._leave_counter += 1
        if not self._task_manager:
            self._task_manager = task_manager

    async def connect(self):
        """Connect to the WebSocket server.

        Does nothing if a connection already exists. On success the receive
        task is started and ``on_connected`` is awaited. A connection timeout
        is logged and swallowed; other errors propagate to the caller.
        """
        if self._websocket:
            return

        try:
            self._websocket = await websocket_connect(
                uri=self._uri,
                open_timeout=10,
                additional_headers=self._params.additional_headers,
            )
            self._client_task = self.task_manager.create_task(
                self._client_task_handler(),
                f"{self._transport_name}::WebsocketClientSession::_client_task_handler",
            )
            await self._callbacks.on_connected(self._websocket)
        except TimeoutError:
            logger.error(f"Timeout connecting to {self._uri}")

    async def disconnect(self):
        """Disconnect from the WebSocket server.

        Releases one user of the session. The receive task is cancelled and
        the socket closed only when no users remain and a connection exists.
        """
        self._leave_counter -= 1
        if not self._websocket or self._leave_counter > 0:
            return

        if self._client_task is not None:
            await self.task_manager.cancel_task(self._client_task)
            self._client_task = None

        try:
            await self._websocket.close()
        finally:
            # Drop the reference even if close() fails, so the session does
            # not keep reporting (or reusing) a broken connection.
            self._websocket = None

    async def send(self, message: websockets.Data) -> bool:
        """Send a message through the WebSocket connection.

        Args:
            message: The message data to send.

        Returns:
            True if the message was handed to the WebSocket, False if there is
            no connection or sending raised an error (which is logged).
        """
        if not self._websocket:
            return False

        # NOTE: no `return` inside a `finally` here. That would silently
        # discard BaseExceptions such as asyncio.CancelledError and break
        # task cancellation while a send is in flight.
        try:
            await self._websocket.send(message)
        except Exception as e:
            logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})")
            return False
        return True

    @property
    def is_connected(self) -> bool:
        """Check if the WebSocket is currently connected.

        Returns:
            True if the WebSocket is in connected state.
        """
        return self._websocket.state == websockets.State.OPEN if self._websocket else False

    @property
    def is_closing(self) -> bool:
        """Check if the WebSocket is currently closing.

        Returns:
            True if the WebSocket is in the process of closing.
        """
        return self._websocket.state == websockets.State.CLOSING if self._websocket else False

    async def _client_task_handler(self):
        """Handle incoming messages from the WebSocket connection.

        Forwards every received message to ``on_message``. When the loop ends,
        either normally or through a logged error, ``on_disconnected`` is
        awaited. Cancellation is not caught by ``except Exception``, so it
        skips that callback.
        """
        try:
            # Handle incoming messages
            async for message in self._websocket:
                await self._callbacks.on_message(self._websocket, message)
        except Exception as e:
            logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")

        await self._callbacks.on_disconnected(self._websocket)

    def __str__(self):
        """String representation of the WebSocket client session."""
        return f"{self._transport_name}::WebsocketClientSession"
class WebsocketClientInputTransport(BaseInputTransport):
    """Input half of the WebSocket client transport.

    Takes raw messages delivered by the shared session, decodes them into
    frames with the configured serializer and feeds them into the pipeline.
    """

    def __init__(
        self,
        transport: BaseTransport,
        session: WebsocketClientSession,
        params: WebsocketClientParams,
    ):
        """Create the input transport.

        Args:
            transport: The owning transport; cleaned up together with this one.
            session: The shared WebSocket session used for the connection.
            params: Transport configuration (serializer, audio settings).
        """
        super().__init__(params)

        self._params = params
        self._session = session
        self._transport = transport

        # Set by the first StartFrame so that repeated starts are no-ops.
        self._initialized = False

    async def setup(self, setup: FrameProcessorSetup):
        """Run processor setup and register with the shared session.

        Args:
            setup: The frame processor setup configuration.
        """
        await super().setup(setup)
        await self._session.setup(setup.task_manager)

    async def start(self, frame: StartFrame):
        """Start the transport; on the first start, also open the connection.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)

        if not self._initialized:
            self._initialized = True

            serializer = self._params.serializer
            if serializer:
                await serializer.setup(frame)
            await self._session.connect()
            await self.set_transport_ready(frame)

    async def stop(self, frame: EndFrame):
        """Stop the transport and release the session.

        Args:
            frame: The end frame signaling transport shutdown.
        """
        await super().stop(frame)
        await self._session.disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the transport and release the session.

        Args:
            frame: The cancel frame signaling immediate cancellation.
        """
        await super().cancel(frame)
        await self._session.disconnect()

    async def cleanup(self):
        """Clean up this processor, then the owning transport."""
        await super().cleanup()
        await self._transport.cleanup()

    async def on_message(self, websocket, message):
        """Decode one incoming WebSocket message and route the resulting frame.

        Messages are dropped when no serializer is configured or when the
        serializer yields no frame.

        Args:
            websocket: The WebSocket connection that received the message.
            message: The received message data.
        """
        serializer = self._params.serializer
        frame = await serializer.deserialize(message) if serializer else None
        if not frame:
            return

        # Audio goes through the audio input path only when input audio is on.
        if isinstance(frame, InputAudioRawFrame) and self._params.audio_in_enabled:
            await self.push_audio_frame(frame)
            return

        # Transport messages are re-created and broadcast.
        if isinstance(frame, InputTransportMessageFrame):
            await self.broadcast_frame(InputTransportMessageFrame, message=frame.message)
            return

        # Anything else is pushed as-is.
        await self.push_frame(frame)
class WebsocketClientOutputTransport(BaseOutputTransport):
    """Output half of the WebSocket client transport.

    Serializes outgoing frames, sends them over the shared session and paces
    audio writes so the socket behaves like a real-time audio device.
    """

    def __init__(
        self,
        transport: BaseTransport,
        session: WebsocketClientSession,
        params: WebsocketClientParams,
    ):
        """Create the output transport.

        Args:
            transport: The owning transport; cleaned up together with this one.
            session: The shared WebSocket session used for the connection.
            params: Transport configuration (serializer, audio settings).
        """
        super().__init__(params)

        self._params = params
        self._session = session
        self._transport = transport

        # Audio reaches write_audio_frame() as fast as it is produced (e.g. by
        # a TTS service). A network socket would accept it all immediately, so
        # we deliberately block between writes to emulate an audio device.
        # The pacing interval is computed when the StartFrame arrives.
        self._send_interval = 0
        self._next_send_time = 0

        # Set by the first StartFrame so that repeated starts are no-ops.
        self._initialized = False

    async def setup(self, setup: FrameProcessorSetup):
        """Run processor setup and register with the shared session.

        Args:
            setup: The frame processor setup configuration.
        """
        await super().setup(setup)
        await self._session.setup(setup.task_manager)

    async def start(self, frame: StartFrame):
        """Start the transport; on the first start, also open the connection.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)

        if not self._initialized:
            self._initialized = True

            chunk_span = self.audio_chunk_size / self.sample_rate
            self._send_interval = chunk_span / 2

            serializer = self._params.serializer
            if serializer:
                await serializer.setup(frame)
            await self._session.connect()
            await self.set_transport_ready(frame)

    async def stop(self, frame: EndFrame):
        """Stop the transport and release the session.

        Args:
            frame: The end frame signaling transport shutdown.
        """
        await super().stop(frame)
        await self._session.disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the transport and release the session.

        Args:
            frame: The cancel frame signaling immediate cancellation.
        """
        await super().cancel(frame)
        await self._session.disconnect()

    async def cleanup(self):
        """Clean up this processor, then the owning transport."""
        await super().cleanup()
        await self._transport.cleanup()

    async def send_message(
        self, frame: OutputTransportMessageFrame | OutputTransportMessageUrgentFrame
    ):
        """Serialize and send a transport message frame.

        Args:
            frame: The transport message frame to send.
        """
        await self._write_frame(frame)

    async def write_audio_frame(self, frame: OutputAudioRawFrame) -> bool:
        """Send one audio frame, optionally wrapped in a WAV container.

        The audio is re-tagged with this transport's sample rate and output
        channel count before being sent, and the call then sleeps to pace
        playback.

        Args:
            frame: The output audio frame to write.

        Returns:
            False if the session is closing or not connected, True otherwise.
        """
        session = self._session
        if session.is_closing or not session.is_connected:
            return False

        outgoing = OutputAudioRawFrame(
            audio=frame.audio,
            sample_rate=self.sample_rate,
            num_channels=self._params.audio_out_channels,
        )
        if self._params.add_wav_header:
            outgoing = self._wrap_in_wav(outgoing)

        await self._write_frame(outgoing)

        # Block for a while to mimic real audio playback.
        await self._write_audio_sleep()

        return True

    def _wrap_in_wav(self, frame: OutputAudioRawFrame) -> OutputAudioRawFrame:
        """Return a copy of the frame whose audio is a complete WAV file."""
        with io.BytesIO() as wav_buffer:
            with wave.open(wav_buffer, "wb") as writer:
                writer.setsampwidth(2)
                writer.setnchannels(frame.num_channels)
                writer.setframerate(frame.sample_rate)
                writer.writeframes(frame.audio)
            # Read the bytes after the writer is closed but before the buffer
            # itself is closed.
            wav_bytes = wav_buffer.getvalue()

        return OutputAudioRawFrame(
            wav_bytes,
            sample_rate=frame.sample_rate,
            num_channels=frame.num_channels,
        )

    async def _write_frame(self, frame: Frame):
        """Serialize a frame and send it, if connected and a serializer exists."""
        session = self._session
        if session.is_closing or not session.is_connected:
            return

        serializer = self._params.serializer
        if not serializer:
            return

        payload = await serializer.serialize(frame)
        if payload:
            await session.send(payload)

    async def _write_audio_sleep(self):
        """Sleep until the next send slot, then schedule the following one."""
        # Emulate an audio clock.
        now = time.monotonic()
        wait = max(0, self._next_send_time - now)
        await asyncio.sleep(wait)

        # If we were already late, restart the schedule from the current time;
        # otherwise keep stepping along the existing schedule.
        base = self._next_send_time if wait else time.monotonic()
        self._next_send_time = base + self._send_interval
class WebsocketClientTransport(BaseTransport):
    """WebSocket client transport for bidirectional communication.

    Bundles one shared WebSocket session with lazily created input and output
    transports, and exposes connection events to user code.

    Event handlers available:

    - on_connected: Connected to WebSocket server
    - on_disconnected: Disconnected from WebSocket server

    Both events are dispatched through ``_call_event_handler`` with the
    websocket connection as an argument.

    NOTE(review): the exact handler signature depends on the base class;
    presumably ``(transport, websocket)`` -- confirm.

    Example::

        @transport.event_handler("on_connected")
        async def on_connected(transport, websocket):
            ...
    """

    def __init__(
        self,
        uri: str,
        params: WebsocketClientParams | None = None,
    ):
        """Create the transport and its shared session.

        Args:
            uri: The WebSocket URI to connect to.
            params: Optional configuration parameters for the transport. When
                no serializer is set, ``ProtobufFrameSerializer`` is used.
        """
        super().__init__()

        effective_params = params if params else WebsocketClientParams()
        effective_params.serializer = effective_params.serializer or ProtobufFrameSerializer()
        self._params = effective_params

        session_callbacks = WebsocketClientCallbacks(
            on_connected=self._on_connected,
            on_disconnected=self._on_disconnected,
            on_message=self._on_message,
        )
        self._session = WebsocketClientSession(uri, self._params, session_callbacks, self.name)

        # Both halves are created on first access.
        self._input: WebsocketClientInputTransport | None = None
        self._output: WebsocketClientOutputTransport | None = None

        # Only the events registered here can be hooked by the user.
        self._register_event_handler("on_connected")
        self._register_event_handler("on_disconnected")

    def input(self) -> WebsocketClientInputTransport:
        """Return the input transport, creating it on first use.

        Returns:
            The WebSocket client input transport instance.
        """
        transport_in = self._input
        if not transport_in:
            transport_in = WebsocketClientInputTransport(self, self._session, self._params)
            self._input = transport_in
        return transport_in

    def output(self) -> WebsocketClientOutputTransport:
        """Return the output transport, creating it on first use.

        Returns:
            The WebSocket client output transport instance.
        """
        transport_out = self._output
        if not transport_out:
            transport_out = WebsocketClientOutputTransport(self, self._session, self._params)
            self._output = transport_out
        return transport_out

    async def _on_connected(self, websocket):
        """Relay the session's connected callback to the user event handler."""
        await self._call_event_handler("on_connected", websocket)

    async def _on_disconnected(self, websocket):
        """Relay the session's disconnected callback to the user event handler."""
        await self._call_event_handler("on_disconnected", websocket)

    async def _on_message(self, websocket, message):
        """Forward an incoming message to the input transport, if one exists."""
        receiver = self._input
        if receiver:
            await receiver.on_message(websocket, message)