# Source code for pipecat.services.websocket_service

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Base websocket service with automatic reconnection and error handling."""

import asyncio
import time
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable

import websockets
from loguru import logger
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from websockets.protocol import State

from pipecat.frames.frames import ErrorFrame
from pipecat.utils.network import exponential_backoff_time


class WebsocketService(ABC):
    """Base class for websocket-based services with automatic reconnection.

    Provides websocket connection management, automatic reconnection with
    exponential backoff, connection verification, and error handling.
    Subclasses implement service-specific connection and message handling logic.
    """

    # Rapid failure detection: when a server accepts the WebSocket handshake but
    # immediately closes the connection (e.g. invalid API key, policy rejection),
    # exponential backoff won't help because the handshake keeps succeeding. We
    # detect this by tracking how long the connection survives after being established.
    _MIN_STABLE_CONNECTION_DURATION = 5.0  # seconds
    _MAX_CONSECUTIVE_QUICK_FAILURES = 3

    def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
        """Initialize the websocket service.

        Args:
            reconnect_on_error: Whether to automatically reconnect on connection errors.
            **kwargs: Additional arguments (unused, for compatibility).
        """
        self._websocket: websockets.WebSocketClientProtocol | None = None
        self._reconnect_on_error = reconnect_on_error
        # Guards against overlapping reconnect loops (e.g. the receive task and
        # send_with_retry both reacting to the same failure).
        self._reconnect_in_progress: bool = False
        # Set by _disconnect() so failures during an intentional shutdown are not
        # treated as reasons to reconnect.
        self._disconnecting: bool = False
        # Number of consecutive connections that died in under
        # _MIN_STABLE_CONNECTION_DURATION seconds.
        self._quick_failure_count: int = 0
        # time.monotonic() timestamp of the latest (re)connection; 0.0 means
        # no connection has been timed yet.
        self._last_connect_time: float = 0.0

    async def _verify_connection(self) -> bool:
        """Verify the websocket connection is active and responsive.

        Returns:
            True if connection is verified working, False otherwise.
        """
        try:
            if not self._websocket or self._websocket.state is State.CLOSED:
                return False
            # NOTE(review): this awaits sending the ping; it does not wait for
            # the matching pong frame.
            await self._websocket.ping()
            return True
        except Exception as e:
            logger.error(f"{self} connection verification failed: {e}")
            return False

    async def _reconnect_websocket(self, attempt_number: int) -> bool:
        """Reconnect the websocket with the current attempt number.

        Args:
            attempt_number: Current retry attempt number for logging.

        Returns:
            True if reconnection and verification successful, False otherwise.
        """
        logger.warning(f"{self} reconnecting (attempt: {attempt_number})")
        await self._disconnect_websocket()
        await self._connect_websocket()
        return await self._verify_connection()

    async def _try_reconnect(
        self,
        max_retries: int = 3,
        report_error: Callable[[ErrorFrame], Awaitable[None]] | None = None,
    ) -> bool:
        """Try to re-establish the websocket connection with exponential backoff.

        Args:
            max_retries: Maximum number of reconnection attempts.
            report_error: Optional callback used to report each attempt that
                raised, plus the final failure after all attempts are exhausted.

        Returns:
            True if an attempt reconnected and passed verification. False if
            another reconnect was already in progress, an intentional disconnect
            started, or every attempt failed.
        """
        # Prevent concurrent reconnection attempts
        if self._reconnect_in_progress:
            logger.warning(f"{self} reconnect attempt aborted: already in progress")
            return False

        self._reconnect_in_progress = True
        last_exception: Exception | None = None
        try:
            for attempt in range(1, max_retries + 1):
                # An intentional disconnect may have begun while we were backing
                # off; reconnecting now would reopen a connection being closed.
                if self._disconnecting:
                    logger.debug(f"{self} reconnect aborted: service is disconnecting")
                    return False

                try:
                    logger.warning(f"{self} reconnecting, attempt {attempt}")
                    if await self._reconnect_websocket(attempt):
                        logger.info(f"{self} reconnected successfully on attempt {attempt}")
                        # Restart the quick-failure timer for the new connection.
                        self._last_connect_time = time.monotonic()
                        return True
                    logger.warning(
                        f"{self} reconnection attempt {attempt} could not be verified"
                    )
                except Exception as e:
                    last_exception = e
                    logger.error(f"{self} reconnection attempt {attempt} failed: {e}")
                    if report_error:
                        await report_error(
                            ErrorFrame(f"{self} reconnection attempt {attempt} failed: {e}")
                        )

                # Back off between attempts, but don't delay the final failure report.
                if attempt < max_retries:
                    wait_time = exponential_backoff_time(attempt)
                    await asyncio.sleep(wait_time)

            msg = f"{self} failed to reconnect after {max_retries} attempts"
            if last_exception:
                msg += f": {last_exception}"
            logger.error(msg)
            if report_error:
                await report_error(ErrorFrame(msg))
            return False
        finally:
            self._reconnect_in_progress = False

    async def send_with_retry(self, message, report_error: Callable[[ErrorFrame], Awaitable[None]]):
        """Attempt to send a message, retrying after reconnect if necessary.

        Args:
            message: Payload passed unchanged to the websocket's ``send``.
            report_error: Callback used by the reconnect logic to report errors.

        Only one retry is made. If the send after a successful reconnect also
        fails, its exception propagates to the caller.
        """
        try:
            await self._websocket.send(message)
        except Exception as e:
            logger.error(f"{self} send failed: {e}, will try to reconnect")
            # Try to reconnect before retrying
            success = await self._try_reconnect(report_error=report_error)
            if success:
                logger.info(f"{self} reconnected successfully, will retry send the message")
                # trying to send the message one more time
                await self._websocket.send(message)
            else:
                logger.error(f"{self} send failed; unable to reconnect")

    async def _maybe_try_reconnect(
        self,
        error_message: str,
        report_error: Callable[[ErrorFrame], Awaitable[None]],
        error: Exception | None = None,
    ) -> bool:
        """Check if reconnection should be attempted and try if appropriate.

        Args:
            error_message: Human-readable error message for logging.
            report_error: Callback function to report connection errors.
            error: The exception that occurred (optional, may be None for graceful closes).

        Returns:
            True if should continue the receive loop, False if should break.
        """
        # Don't reconnect if we're intentionally disconnecting
        if self._disconnecting:
            if error:
                logger.warning(f"{self} error during disconnect: {error}")
            else:
                logger.debug(f"{self} receive loop ended during disconnect")
            return False

        # Check if the connection died too quickly after being established. This
        # catches cases where the handshake succeeds but the server immediately
        # closes (e.g. invalid API key). Exponential backoff won't help here
        # because the handshake keeps succeeding — we need to stop the loop.
        if self._last_connect_time > 0:
            connection_duration = time.monotonic() - self._last_connect_time
            if connection_duration < self._MIN_STABLE_CONNECTION_DURATION:
                self._quick_failure_count += 1
                logger.warning(
                    f"{self} connection lasted only {connection_duration:.1f}s "
                    f"({self._quick_failure_count}/{self._MAX_CONSECUTIVE_QUICK_FAILURES} "
                    f"consecutive quick failures)"
                )
                if self._quick_failure_count >= self._MAX_CONSECUTIVE_QUICK_FAILURES:
                    msg = (
                        f"{self} connection failed {self._MAX_CONSECUTIVE_QUICK_FAILURES} "
                        f"times immediately after connecting"
                    )
                    logger.error(msg)
                    await report_error(ErrorFrame(msg))
                    return False
            else:
                # Connection was stable — reset the counter.
                self._quick_failure_count = 0

        # Log the message
        logger.warning(error_message)

        # Try to reconnect if enabled
        if self._reconnect_on_error:
            success = await self._try_reconnect(report_error=report_error)
            return success
        else:
            # Reconnection disabled
            await report_error(ErrorFrame(error_message))
            return False

    async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Awaitable[None]]):
        """Handle websocket message receiving with automatic retry logic.

        Continuously receives messages with automatic reconnection on errors.
        Uses exponential backoff between retry attempts and reports fatal errors
        after maximum retries are exhausted.

        Args:
            report_error: Callback function to report connection errors.
        """
        while True:
            # Timestamp the start of this receive session so _maybe_try_reconnect
            # can tell how long the connection survived.
            self._last_connect_time = time.monotonic()
            try:
                await self._receive_messages()
                # _receive_messages() returned normally. This happens when the websocket
                # closes gracefully (server sent close frame). The async for loop over
                # the websocket exits without raising an exception in this case.
                # We must handle this to avoid an infinite loop.
                message = f"{self} connection closed by server"
                should_continue = await self._maybe_try_reconnect(message, report_error)
                if not should_continue:
                    break
            except ConnectionClosedOK as e:
                # Normal closure, don't retry
                logger.debug(f"{self} connection closed normally: {e}")
                break
            except ConnectionClosedError as e:
                # Connection closed with error (e.g., no close frame received/sent)
                # This often indicates network issues, server problems, or abrupt disconnection
                message = f"{self} connection closed, but with an error: {e}"
                should_continue = await self._maybe_try_reconnect(message, report_error, e)
                if not should_continue:
                    break
            except Exception as e:
                # General error during message receiving
                message = f"{self} error receiving messages: {e}"
                should_continue = await self._maybe_try_reconnect(message, report_error, e)
                if not should_continue:
                    break

    async def _connect(self):
        """Connect to the service and reset disconnecting flag.

        Manages the disconnecting flag to enable reconnection. Subclasses should
        call super()._connect() first, then implement their specific connection
        logic including websocket connection via _connect_websocket() and any
        additional setup required.
        """
        self._disconnecting = False
        self._quick_failure_count = 0

    async def _disconnect(self):
        """Disconnect from the service and set disconnecting flag.

        Manages the disconnecting flag to prevent reconnection during intentional
        disconnect. Subclasses should call super()._disconnect() first, then
        implement their specific disconnection logic including websocket
        disconnection via _disconnect_websocket() and any cleanup required.
        """
        self._disconnecting = True

    @abstractmethod
    async def _connect_websocket(self):
        """Establish the websocket connection.

        Implement the low-level websocket connection logic specific to the service.
        Should only handle websocket connection, not additional service setup.
        """
        pass

    @abstractmethod
    async def _disconnect_websocket(self):
        """Close the websocket connection.

        Implement the low-level websocket disconnection logic specific to the service.
        Should only handle websocket disconnection, not additional service cleanup.
        """
        pass

    @abstractmethod
    async def _receive_messages(self):
        """Receive and process websocket messages.

        Implement service-specific logic for receiving and handling messages from
        the websocket connection. Called continuously by the receive task handler.
        """
        pass