#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base classes for Large Language Model services with function calling support."""
from __future__ import annotations
import asyncio
import json
import uuid
import warnings
from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass
from typing import (
Any,
Protocol,
)
from loguru import logger
from websockets.exceptions import ConnectionClosed
from websockets.protocol import State
from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
FunctionCallCancelFrame,
FunctionCallFromLLM,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
FunctionCallResultProperties,
FunctionCallsStartedFrame,
InterruptionFrame,
LLMConfigureOutputFrame,
LLMContextSummaryRequestFrame,
LLMContextSummaryResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMTextFrame,
LLMUpdateSettingsFrame,
StartFrame,
)
from pipecat.processors.aggregators.llm_context import (
LLMContext,
LLMSpecificMessage,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.services.settings import LLMSettings, assert_given
from pipecat.services.websocket_service import WebsocketService
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin
from pipecat.utils.async_tool_cancellation import (
ASYNC_TOOL_CANCELLATION_INSTRUCTIONS,
CANCEL_ASYNC_TOOL_NAME,
CANCEL_ASYNC_TOOL_SCHEMA,
)
from pipecat.utils.context.llm_context_summarization import (
DEFAULT_SUMMARIZATION_TIMEOUT,
LLMContextSummarizationUtil,
)
# Type alias for a callable that handles LLM function calls.
FunctionCallHandler = Callable[["FunctionCallParams"], Awaitable[None]]
# Type alias for a callback function that handles the result of an LLM function call.
[docs]
class FunctionCallResultCallback(Protocol):
"""Protocol for function call result callbacks.
Used for both final results and intermediate updates. Pass
``properties=FunctionCallResultProperties(is_final=False)`` to send an
intermediate update (only valid for async function calls registered with
``cancel_on_interruption=False``).
"""
async def __call__(
self, result: Any, *, properties: FunctionCallResultProperties | None = None
) -> None:
"""Call the result callback.
Args:
result: The result of the function call, or an intermediate update.
properties: Optional properties. Set ``is_final=False`` to send an
intermediate update instead of the final result.
"""
...
[docs]
@dataclass
class FunctionCallParams:
"""Parameters for a function call.
Parameters:
function_name: The name of the function being called.
tool_call_id: A unique identifier for the function call.
arguments: The arguments for the function.
llm: The LLMService instance being used.
context: The LLM context.
result_callback: Callback to deliver the result of the function call.
For async function calls (``cancel_on_interruption=False``), call
it with ``properties=FunctionCallResultProperties(is_final=False)``
to push intermediate updates before the final result.
app_resources: The application-defined resources passed to
``PipelineTask(..., app_resources=...)``. Same object — passed by
reference, not a copy. Use it to share DB handles, clients, state,
feature flags, etc. across all of a session's tool handlers.
"""
function_name: str
tool_call_id: str
arguments: Mapping[str, Any]
llm: LLMService
context: LLMContext
result_callback: FunctionCallResultCallback
app_resources: Any = None
@property
def tool_resources(self) -> Any:
"""Deprecated alias for :attr:`app_resources`.
.. deprecated:: 1.2.0
Use :attr:`app_resources` instead. ``tool_resources`` will be
removed in a future version.
"""
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`FunctionCallParams.tool_resources` is deprecated since 1.2.0, "
"use `app_resources` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.app_resources
[docs]
@dataclass
class FunctionCallRegistryItem:
"""Represents an entry in the function call registry.
This is what the user registers when calling register_function.
Parameters:
function_name: The name of the function (None for catch-all handler).
handler: The handler for processing function call parameters.
cancel_on_interruption: Whether to cancel the call on interruption.
When ``False`` the call is treated as asynchronous: the LLM
continues the conversation immediately without waiting for the
result, and the result is injected later via a developer message.
timeout_secs: Optional per-tool timeout in seconds. Overrides the global
``function_call_timeout_secs`` for this specific function.
"""
function_name: str | None
handler: FunctionCallHandler | DirectFunctionWrapper
cancel_on_interruption: bool
timeout_secs: float | None = None
[docs]
@dataclass
class FunctionCallRunnerItem:
"""Internal function call entry for the function call runner.
The runner executes function calls in order.
Parameters:
registry_item: The registry item containing handler information.
function_name: The name of the function.
tool_call_id: A unique identifier for the function call.
arguments: The arguments for the function.
context: The LLM context.
run_llm: Optional flag to control LLM execution after function call.
group_id: Shared identifier for all function calls from the same LLM
response batch. Used to trigger the LLM exactly once when the last
call in the group completes.
"""
registry_item: FunctionCallRegistryItem
function_name: str
tool_call_id: str
arguments: Mapping[str, Any]
context: LLMContext
run_llm: bool | None = None
group_id: str | None = None
[docs]
class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
"""Base class for all LLM services.
Handles function calling registration and execution with support for both
parallel and sequential execution modes. Provides event handlers for
completion timeouts and function call lifecycle events.
The service supports the following event handlers:
- on_completion_timeout: Called when an LLM completion timeout occurs
- on_function_calls_started: Called when function calls are received and
execution is about to start. Built-in tools (e.g. ``cancel_async_tool_call``)
are excluded from this event.
- on_function_calls_cancelled: Called after one or more async tool calls are
cancelled.
Example::
@task.event_handler("on_completion_timeout")
async def on_completion_timeout(service):
logger.warning("LLM completion timed out")
@task.event_handler("on_function_calls_started")
async def on_function_calls_started(service, function_calls: List[FunctionCallFromLLM]):
logger.info(f"Starting {len(function_calls)} function calls")
@task.event_handler("on_function_calls_cancelled")
async def on_function_calls_cancelled(service, function_calls: List[FunctionCallFromLLM]):
logger.info(f"Cancelled {len(function_calls)} function calls")
"""
_settings: LLMSettings
# OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations.
# However, subclasses should override this with a more specific adapter when necessary.
adapter_class: type[BaseLLMAdapter] = OpenAILLMAdapter
[docs]
def __init__(
self,
run_in_parallel: bool = True,
group_parallel_tools: bool = True,
function_call_timeout_secs: float | None = None,
enable_async_tool_cancellation: bool = False,
settings: LLMSettings | None = None,
**kwargs,
):
"""Initialize the LLM service.
Args:
run_in_parallel: Whether to run function calls in parallel or sequentially.
Defaults to True.
group_parallel_tools: Whether to group parallel function calls so the LLM
is triggered exactly once after all calls in the batch complete. When
False, each function call result triggers the LLM independently as it
arrives. Defaults to True.
function_call_timeout_secs: Optional timeout in seconds for deferred function
calls.
enable_async_tool_cancellation: When True and at least one async function
(``cancel_on_interruption=False``) is registered, automatically injects
the ``cancel_async_tool_call`` built-in tool and its system instructions
so the LLM can cancel stale in-progress calls. Defaults to False.
settings: The runtime-updatable settings for the LLM service.
**kwargs: Additional arguments passed to the parent AIService.
"""
super().__init__(
settings=settings
# Here in case subclass doesn't implement more specific settings
# (which hopefully should be rare)
or LLMSettings(),
**kwargs,
)
self._run_in_parallel = run_in_parallel
self._group_parallel_tools = group_parallel_tools
self._function_call_timeout_secs = function_call_timeout_secs
self._enable_async_tool_cancellation: bool = enable_async_tool_cancellation
self._filter_incomplete_user_turns: bool = False
self._async_tool_cancellation_enabled: bool = False
self._base_system_instruction: str | None = None
self._adapter = self.adapter_class()
self._functions: dict[str | None, FunctionCallRegistryItem] = {}
self._function_call_tasks: dict[asyncio.Task | None, FunctionCallRunnerItem] = {}
self._sequential_runner_task: asyncio.Task | None = None
self._skip_tts: bool | None = None
self._summary_task: asyncio.Task | None = None
self._register_event_handler("on_function_calls_started")
self._register_event_handler("on_function_calls_cancelled")
self._register_event_handler("on_completion_timeout")
[docs]
def get_llm_adapter(self) -> BaseLLMAdapter:
"""Get the LLM adapter instance.
Returns:
The adapter instance used for LLM communication.
"""
return self._adapter
[docs]
def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
"""Create an LLM-specific message (as opposed to a standard message) for use in an LLMContext.
Args:
message: The message content.
Returns:
A LLMSpecificMessage instance.
"""
return self.get_llm_adapter().create_llm_specific_message(message)
[docs]
async def run_inference(
self,
context: LLMContext,
max_tokens: int | None = None,
system_instruction: str | None = None,
) -> str | None:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context.
Must be implemented by subclasses.
Args:
context: The LLM context containing conversation history.
max_tokens: Optional maximum number of tokens to generate. If provided,
overrides the service's default max_tokens/max_completion_tokens setting.
system_instruction: Optional system instruction to use for this inference.
If provided, overrides any system instruction in the context.
Returns:
The LLM's response as a string, or None if no response is generated.
"""
raise NotImplementedError(f"run_inference() not supported by {self.__class__.__name__}")
[docs]
async def start(self, frame: StartFrame):
"""Start the LLM service.
Args:
frame: The start frame.
"""
await super().start(frame)
if not self._run_in_parallel:
await self._create_sequential_runner_task()
if self._enable_async_tool_cancellation and self._has_async_tools():
self._setup_async_tool_cancellation()
[docs]
async def stop(self, frame: EndFrame):
"""Stop the LLM service.
Args:
frame: The end frame.
"""
await super().stop(frame)
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
await self._cancel_summary_task()
[docs]
async def cancel(self, frame: CancelFrame):
"""Cancel the LLM service.
Args:
frame: The cancel frame.
"""
await super().cancel(frame)
if not self._run_in_parallel:
await self._cancel_sequential_runner_task()
await self._cancel_summary_task()
def _compose_system_instruction(self):
"""Compose system_instruction from the base and all active addon instructions.
Combines the base system instruction with turn completion instructions
(when enabled) and async tool cancellation instructions (when enabled),
writing the result to ``self._settings.system_instruction``.
"""
base = self._base_system_instruction
parts = [base] if base else []
if self._filter_incomplete_user_turns:
parts.append(self._user_turn_completion_config.completion_instructions)
if self._async_tool_cancellation_enabled:
parts.append(ASYNC_TOOL_CANCELLATION_INSTRUCTIONS)
composed = "\n\n".join(p for p in parts if p)
self._settings.system_instruction = composed or None
logger.debug(f"{self}: System instruction composed: {self._settings.system_instruction}")
async def _update_settings(self, delta: LLMSettings) -> dict[str, Any]:
"""Apply a settings delta, handling turn-completion fields.
Args:
delta: An LLM settings delta.
Returns:
Dict mapping changed field names to their previous values.
"""
changed = await super()._update_settings(delta)
if "filter_incomplete_user_turns" in changed:
self._filter_incomplete_user_turns = (
self._settings.filter_incomplete_user_turns or False
)
logger.info(
f"{self}: Incomplete turn filtering "
f"{'enabled' if self._filter_incomplete_user_turns else 'disabled'}"
)
if self._filter_incomplete_user_turns:
# Save the current system_instruction before composing
self._base_system_instruction = self._settings.system_instruction
self._compose_system_instruction()
else:
# Restore original system_instruction
self._settings.system_instruction = self._base_system_instruction
self._base_system_instruction = None
if "user_turn_completion_config" in changed and self._filter_incomplete_user_turns:
self.set_user_turn_completion_config(
assert_given(self._settings.user_turn_completion_config)
)
self._compose_system_instruction()
if (
"system_instruction" in changed
and (self._filter_incomplete_user_turns or self._async_tool_cancellation_enabled)
and "filter_incomplete_user_turns" not in changed
):
# system_instruction changed while composition is active.
# Treat the new value as the new base and recompose.
self._base_system_instruction = self._settings.system_instruction
self._compose_system_instruction()
return changed
[docs]
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process a frame.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
await self._handle_interruptions(frame)
elif isinstance(frame, LLMConfigureOutputFrame):
self._skip_tts = frame.skip_tts
elif isinstance(frame, LLMUpdateSettingsFrame):
if frame.service is not None and frame.service is not self:
await self.push_frame(frame, direction)
elif frame.delta is not None:
await self._update_settings(frame.delta)
elif frame.settings:
# Backward-compatible path: convert legacy dict to settings object.
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"Passing a dict via LLMUpdateSettingsFrame(settings={...}) is deprecated "
"since 0.0.104, use LLMUpdateSettingsFrame(delta=LLMSettings(...)) instead.",
DeprecationWarning,
stacklevel=2,
)
delta = type(self._settings).from_mapping(frame.settings)
await self._update_settings(delta)
elif isinstance(frame, LLMContextSummaryRequestFrame):
await self._handle_summary_request(frame)
[docs]
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
"""Pushes a frame.
Args:
frame: The frame to push.
direction: The direction of frame pushing.
"""
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
if self._skip_tts is not None:
frame.skip_tts = self._skip_tts
await super().push_frame(frame, direction)
async def _push_llm_text(self, text: str):
"""Push LLM text, using turn completion detection if enabled.
This helper method simplifies text pushing in LLM implementations by
handling the conditional logic for turn completion internally.
Args:
text: The text content from the LLM to push.
"""
if self._filter_incomplete_user_turns:
await self._push_turn_text(text)
else:
await self.push_frame(LLMTextFrame(text))
async def _handle_interruptions(self, _: InterruptionFrame):
for function_name, entry in self._functions.items():
if entry.cancel_on_interruption:
await self._cancel_function_call(function_name)
async def _handle_summary_request(self, frame: LLMContextSummaryRequestFrame):
"""Handle context summarization request from aggregator.
Processes a summarization request by generating a compressed summary
of conversation history. Uses the adapter to format the summary
according to the provider's requirements. Broadcasts the result back
to the aggregator for context reconstruction.
Args:
frame: The summary request frame containing context and parameters.
"""
logger.debug(f"{self}: Processing summarization request {frame.request_id}")
# Create a background task to generate the summary without blocking
self._summary_task = self.create_task(self._generate_summary_task(frame))
async def _generate_summary_task(self, frame: LLMContextSummaryRequestFrame):
"""Background task to generate summary without blocking the pipeline.
Args:
frame: The summary request frame containing context and parameters.
"""
summary = ""
last_index = -1
error = None
timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT
try:
summary, last_index = await asyncio.wait_for(
self._generate_summary(frame),
timeout=timeout,
)
except TimeoutError:
await self.push_error(error_msg=f"Context summarization timed out after {timeout}s")
except Exception as e:
error = f"Error generating context summary: {e}"
await self.push_error(error, exception=e)
await self.broadcast_frame(
LLMContextSummaryResultFrame,
request_id=frame.request_id,
summary=summary,
last_summarized_index=last_index,
error=error,
)
self._summary_task = None
async def _generate_summary(self, frame: LLMContextSummaryRequestFrame) -> tuple[str, int]:
"""Generate a compressed summary of conversation context.
Uses the message selection logic to identify which messages
to summarize, formats them as a transcript, and invokes the LLM to
generate a concise summary. The summary is formatted according to the
LLM provider's requirements using the adapter.
Args:
frame: The summary request frame containing context and configuration.
Returns:
Tuple of (formatted summary message, last_summarized_index).
Raises:
RuntimeError: If there are no messages to summarize, the service doesn't
support run_inference(), or the LLM returns an empty summary.
Note:
Requires the service to implement run_inference() method for
synchronous LLM calls.
"""
# Get messages to summarize using utility method
result = LLMContextSummarizationUtil.get_messages_to_summarize(
frame.context, frame.min_messages_to_keep
)
if not result.messages:
logger.debug(f"{self}: No messages to summarize")
raise RuntimeError("No messages to summarize")
logger.debug(
f"{self}: Generating summary for {len(result.messages)} messages "
f"(index 0 to {result.last_summarized_index}), "
f"target_context_tokens={frame.target_context_tokens}"
)
# Create summary context
transcript = LLMContextSummarizationUtil.format_messages_for_summary(result.messages)
summary_context = LLMContext(
messages=[{"role": "user", "content": f"Conversation history:\n{transcript}"}]
)
# Generate summary using run_inference
# This will be overridden by each LLM service implementation
try:
summary_text = await self.run_inference(
summary_context,
max_tokens=frame.target_context_tokens,
system_instruction=frame.summarization_prompt,
)
except NotImplementedError:
raise RuntimeError(
f"LLM service {self.__class__.__name__} does not implement run_inference"
)
if not summary_text:
raise RuntimeError("LLM returned empty summary")
summary_text = summary_text.strip()
logger.info(
f"{self}: Generated summary of {len(summary_text)} characters "
f"for {len(result.messages)} messages"
)
return summary_text, result.last_summarized_index
[docs]
def register_function(
self,
function_name: str | None,
handler: Any,
*,
cancel_on_interruption: bool = True,
timeout_secs: float | None = None,
):
"""Register a function handler for LLM function calls.
Args:
function_name: The name of the function to handle. Use None to handle
all function calls with a catch-all handler.
handler: The function handler. Should accept a single FunctionCallParams
parameter.
cancel_on_interruption: Whether to cancel this function call when an
interruption occurs. When ``False`` the call is treated as
asynchronous: the LLM continues the conversation immediately
without waiting for the result, and the result is injected later
via a developer message. Defaults to True.
timeout_secs: Optional per-tool timeout in seconds. Overrides the global
``function_call_timeout_secs`` for this specific function. Defaults to
None, which uses the global timeout.
"""
if function_name == CANCEL_ASYNC_TOOL_NAME:
raise ValueError(
f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be "
"registered by user code."
)
# Registering a function with the function_name set to None will run
# that handler for all functions
self._functions[function_name] = FunctionCallRegistryItem(
function_name=function_name,
handler=handler,
cancel_on_interruption=cancel_on_interruption,
timeout_secs=timeout_secs,
)
[docs]
def register_direct_function(
self,
handler: DirectFunction,
*,
cancel_on_interruption: bool = True,
timeout_secs: float | None = None,
):
"""Register a direct function handler for LLM function calls.
Direct functions have their metadata automatically extracted from their
signature and docstring, eliminating the need for accompanying
configurations (as FunctionSchemas or in provider-specific formats).
Args:
handler: The direct function to register. Must follow DirectFunction protocol.
cancel_on_interruption: Whether to cancel this function call when an
interruption occurs. When ``False`` the call is treated as
asynchronous: the LLM continues the conversation immediately
without waiting for the result, and the result is injected later
via a developer message. Defaults to True.
timeout_secs: Optional per-tool timeout in seconds. Overrides the global
``function_call_timeout_secs`` for this specific function. Defaults to
None, which uses the global timeout.
"""
wrapper = DirectFunctionWrapper(handler)
if wrapper.name == CANCEL_ASYNC_TOOL_NAME:
raise ValueError(
f"'{CANCEL_ASYNC_TOOL_NAME}' is a reserved built-in tool name and cannot be "
"registered by user code."
)
self._functions[wrapper.name] = FunctionCallRegistryItem(
function_name=wrapper.name,
handler=wrapper,
cancel_on_interruption=cancel_on_interruption,
timeout_secs=timeout_secs,
)
[docs]
def unregister_function(self, function_name: str | None):
"""Remove a registered function handler.
Args:
function_name: The name of the function handler to remove.
"""
del self._functions[function_name]
if self._async_tool_cancellation_enabled and not self._has_async_tools():
self._teardown_async_tool_cancellation()
[docs]
def unregister_direct_function(self, handler: Any):
"""Remove a registered direct function handler.
Args:
handler: The direct function handler to remove.
"""
wrapper = DirectFunctionWrapper(handler)
del self._functions[wrapper.name]
# Note: no need to remove start callback here, as direct functions don't support start callbacks.
if self._async_tool_cancellation_enabled and not self._has_async_tools():
self._teardown_async_tool_cancellation()
[docs]
def has_function(self, function_name: str):
"""Check if a function handler is registered.
Args:
function_name: The name of the function to check.
Returns:
True if the function is registered or if a catch-all handler (None)
is registered.
"""
if None in self._functions.keys():
return True
return function_name in self._functions.keys()
[docs]
async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
"""Execute a sequence of function calls from the LLM.
Triggers the on_function_calls_started event and executes functions
either in parallel or sequentially based on the run_in_parallel setting.
Args:
function_calls: The function calls to execute.
"""
if len(function_calls) == 0:
return
# Exclude the built-in cancel tool — it's an internal mechanism and
# should not be surfaced to user-facing event handlers or frames.
user_visible_calls = [
fc for fc in function_calls if fc.function_name != CANCEL_ASYNC_TOOL_NAME
]
if user_visible_calls:
await self._call_event_handler("on_function_calls_started", user_visible_calls)
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=user_visible_calls)
# When group_parallel_tools is True all calls share a group_id so the
# aggregator triggers the LLM exactly once after the last one completes.
# When False, group_id is None and each result triggers inference independently.
group_id = str(uuid.uuid4()) if self._group_parallel_tools else None
runner_items = []
for function_call in function_calls:
if function_call.function_name in self._functions.keys():
item = self._functions[function_call.function_name]
elif None in self._functions.keys():
item = self._functions[None]
else:
logger.warning(
f"{self} is calling '{function_call.function_name}', but it's not registered."
)
item = self._build_missing_function_call_registry_item(function_call.function_name)
runner_items.append(
FunctionCallRunnerItem(
registry_item=item,
function_name=function_call.function_name,
tool_call_id=function_call.tool_call_id,
arguments=function_call.arguments,
context=function_call.context,
group_id=group_id,
)
)
if self._run_in_parallel:
await self._run_parallel_function_calls(runner_items)
else:
await self._run_sequential_function_calls(runner_items)
async def _create_sequential_runner_task(self):
if not self._sequential_runner_task:
self._sequential_runner_queue = asyncio.Queue()
self._sequential_runner_task = self.create_task(self._sequential_runner_handler())
async def _cancel_sequential_runner_task(self):
if self._sequential_runner_task:
await self.cancel_task(self._sequential_runner_task)
self._sequential_runner_task = None
async def _cancel_summary_task(self):
if self._summary_task:
await self.cancel_task(self._summary_task)
self._summary_task = None
async def _sequential_runner_handler(self):
while True:
runner_item = await self._sequential_runner_queue.get()
task = self.create_task(self._run_function_call(runner_item))
self._function_call_tasks[task] = runner_item
# Since we run tasks sequentially we don't need to call
# task.add_done_callback(self._function_call_task_finished).
await task
del self._function_call_tasks[task]
async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
tasks = []
for runner_item in runner_items:
task = self.create_task(self._run_function_call(runner_item))
tasks.append(task)
self._function_call_tasks[task] = runner_item
task.add_done_callback(self._function_call_task_finished)
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
# Enqueue all function calls for background execution.
for runner_item in runner_items:
await self._sequential_runner_queue.put(runner_item)
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
# Re-resolve the registry item at execution time. The function may have
# been unregistered between queuing and execution, in which case we
# fall back to the missing-function handler so the call still terminates
# with a normal tool result.
if runner_item.function_name in self._functions.keys():
item = self._functions[runner_item.function_name]
elif None in self._functions.keys():
item = self._functions[None]
elif runner_item.registry_item.handler == self._missing_function_call_handler:
item = runner_item.registry_item
else:
logger.warning(
f"{self} is calling '{runner_item.function_name}', but it was just unregistered."
)
item = self._build_missing_function_call_registry_item(runner_item.function_name)
logger.debug(
f"{self} Calling function [{runner_item.function_name}:{runner_item.tool_call_id}] with arguments {runner_item.arguments}"
)
# Broadcast function call in-progress. This frame will let our assistant
# context aggregator know that we are in the middle of a function
# call. Some contexts/aggregators may not need this. But some definitely
# do (Anthropic, for example).
await self.broadcast_frame(
FunctionCallInProgressFrame,
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
cancel_on_interruption=item.cancel_on_interruption,
group_id=runner_item.group_id,
)
timeout_task: asyncio.Task | None = None
# Single callback for both intermediate updates and final results.
# Pass properties=FunctionCallResultProperties(is_final=False) for updates.
async def function_call_result_callback(
result: Any, *, properties: FunctionCallResultProperties | None = None
):
is_final = properties.is_final if properties else True
if not is_final and item.cancel_on_interruption:
logger.warning(
f"{self} result_callback called with is_final=False on sync function call"
f" [{runner_item.function_name}:{runner_item.tool_call_id}]."
" Intermediate updates are only valid for async function calls"
" (cancel_on_interruption=False)."
)
return
nonlocal timeout_task
# Cancel timeout task if it exists
if timeout_task and not timeout_task.done():
await self.cancel_task(timeout_task)
await self.broadcast_frame(
FunctionCallResultFrame,
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
result=result,
run_llm=runner_item.run_llm,
properties=properties,
)
# Start a timeout task for deferred function calls
async def timeout_handler():
try:
effective_timeout = item.timeout_secs or self._function_call_timeout_secs
await asyncio.sleep(effective_timeout)
logger.warning(
f"{self} Function call [{runner_item.function_name}:{runner_item.tool_call_id}] timed out after {effective_timeout} seconds."
f" You can increase this timeout by passing `timeout_secs` to `register_function()`,"
f" or set a global default via `function_call_timeout_secs` on the LLM constructor."
)
await function_call_result_callback(None)
except asyncio.CancelledError:
raise
if item.timeout_secs or self._function_call_timeout_secs:
timeout_task = self.create_task(timeout_handler())
# Yield to the event loop so the timeout task coroutine gets entered
# before it could be cancelled. Without this, cancelling the task before
# it starts would leave the coroutine in a "never awaited" state.
await asyncio.sleep(0)
# _pipeline_task may be unset when the service is driven without a PipelineTask.
app_resources = self._pipeline_task.app_resources if self._pipeline_task else None
try:
if isinstance(item.handler, DirectFunctionWrapper):
# Handler is a DirectFunctionWrapper
await item.handler.invoke(
args=runner_item.arguments,
params=FunctionCallParams(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
llm=self,
context=runner_item.context,
result_callback=function_call_result_callback,
app_resources=app_resources,
),
)
else:
# Handler is a FunctionCallHandler
params = FunctionCallParams(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
llm=self,
context=runner_item.context,
result_callback=function_call_result_callback,
app_resources=app_resources,
)
await item.handler(params)
except Exception as e:
error_message = f"Error executing function call [{runner_item.function_name}]: {e}"
logger.error(f"{self} {error_message}")
await self.push_error(error_msg=error_message, exception=e, fatal=False)
finally:
if timeout_task and not timeout_task.done():
await self.cancel_task(timeout_task)
def _build_missing_function_call_registry_item(
self, function_name: str
) -> FunctionCallRegistryItem:
"""Build a registry item that routes to the missing-function handler."""
return FunctionCallRegistryItem(
function_name=function_name,
handler=self._missing_function_call_handler,
cancel_on_interruption=True,
)
async def _missing_function_call_handler(self, params: FunctionCallParams):
"""Return a terminal tool result when the LLM calls an unknown function."""
await params.result_callback(f"Error: function '{params.function_name}' is not registered.")
def _has_async_tools(self) -> bool:
"""Return True if at least one non-builtin async tool is registered."""
return any(
not item.cancel_on_interruption
for name, item in self._functions.items()
if name != CANCEL_ASYNC_TOOL_NAME
)
def _setup_async_tool_cancellation(self):
"""Enable async tool cancellation.
Saves the base system instruction, recomposes to include cancellation
instructions, registers the built-in ``cancel_async_tool_call`` handler,
and injects its schema into the adapter's built-in tool dict.
"""
logger.debug(f"{self}: Enabling async tool cancellation")
self._async_tool_cancellation_enabled = True
if self._base_system_instruction is None:
self._base_system_instruction = self._settings.system_instruction
self._compose_system_instruction()
self._adapter.builtin_tools[CANCEL_ASYNC_TOOL_NAME] = CANCEL_ASYNC_TOOL_SCHEMA
if CANCEL_ASYNC_TOOL_NAME not in self._functions:
self._functions[CANCEL_ASYNC_TOOL_NAME] = FunctionCallRegistryItem(
function_name=CANCEL_ASYNC_TOOL_NAME,
handler=self._cancel_async_tool_call_handler,
cancel_on_interruption=True,
)
def _teardown_async_tool_cancellation(self):
"""Disable async tool cancellation.
Removes the built-in ``cancel_async_tool_call`` handler and its schema,
recomposes the system instruction without cancellation instructions.
"""
logger.debug(f"{self}: Disabling async tool cancellation")
self._async_tool_cancellation_enabled = False
self._adapter.builtin_tools.pop(CANCEL_ASYNC_TOOL_NAME, None)
self._functions.pop(CANCEL_ASYNC_TOOL_NAME, None)
self._compose_system_instruction()
async def _cancel_async_tool_call_handler(self, params: FunctionCallParams):
"""Handle a ``cancel_async_tool_call`` invocation from the LLM.
Args:
params: Function call parameters containing ``tool_call_id`` to cancel.
"""
logger.debug(f"{self}: cancel_async_tool_call invoked")
tool_call_id: str | None = params.arguments.get("tool_call_id")
if not tool_call_id:
logger.warning(f"{self} cancel_async_tool_call called with no tool_call_id")
await params.result_callback({"cancelled": None})
return
await self._cancel_function_calls_by_tool_call_id(tool_call_id)
await params.result_callback(
{"cancelled": tool_call_id},
properties=FunctionCallResultProperties(run_llm=True),
)
async def _cancel_function_calls_by_tool_call_id(self, tool_call_id: str):
"""Cancel in-progress function call tasks by their tool_call_id.
Args:
tool_call_id: tool_call_id to cancel.
"""
cancelled_tasks = set()
cancelled_items = []
for task, runner_item in self._function_call_tasks.items():
if runner_item.tool_call_id == tool_call_id:
name = runner_item.function_name
tool_call_id = runner_item.tool_call_id
logger.debug(
f"{self} Cancelling async function call [{name}:{tool_call_id}] "
"by LLM request..."
)
if task:
task.remove_done_callback(self._function_call_task_finished)
await self.cancel_task(task)
cancelled_tasks.add(task)
await self.broadcast_frame(
FunctionCallCancelFrame, function_name=name, tool_call_id=tool_call_id
)
cancelled_items.append(
FunctionCallFromLLM(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
context=runner_item.context,
)
)
logger.debug(f"{self} Async function call [{name}:{tool_call_id}] cancelled")
for task in cancelled_tasks:
self._function_call_task_finished(task)
if cancelled_items:
await self._call_event_handler("on_function_calls_cancelled", cancelled_items)
async def _cancel_function_call(self, function_name: str | None):
cancelled_tasks = set()
cancelled_items = []
for task, runner_item in self._function_call_tasks.items():
if runner_item.registry_item.function_name == function_name:
name = runner_item.function_name
tool_call_id = runner_item.tool_call_id
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
if task:
# We remove the callback because we are going to cancel the
# task next, otherwise we will be removing it from the set
# while we are iterating.
task.remove_done_callback(self._function_call_task_finished)
await self.cancel_task(task)
cancelled_tasks.add(task)
await self.broadcast_frame(
FunctionCallCancelFrame, function_name=name, tool_call_id=tool_call_id
)
cancelled_items.append(
FunctionCallFromLLM(
function_name=runner_item.function_name,
tool_call_id=runner_item.tool_call_id,
arguments=runner_item.arguments,
context=runner_item.context,
)
)
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
# Remove all cancelled tasks from our set.
for task in cancelled_tasks:
self._function_call_task_finished(task)
if cancelled_items:
await self._call_event_handler("on_function_calls_cancelled", cancelled_items)
def _function_call_task_finished(self, task: asyncio.Task):
if task in self._function_call_tasks:
del self._function_call_tasks[task]
# ---------------------------------------------------------------------------
# WebSocket LLM service base
# ---------------------------------------------------------------------------
[docs]
class WebsocketReconnectedError(Exception):
"""Raised by ``_ws_send``/``_ws_recv`` after a transparent reconnection.
Signals that the WebSocket connection was lost and automatically
re-established. The current inference should be restarted — any
connection-local state on the server (e.g. cached responses) is gone.
"""
pass
[docs]
class WebsocketLLMService(LLMService, WebsocketService):
"""Base class for websocket-based LLM services.
Each LLM inference is a discrete request/response exchange: send one
request, receive events inline until a terminal event, then wait for
the next frame to trigger an inference. This contrasts with
``WebsocketTTSService`` / ``WebsocketSTTService`` which stream data
continuously via a background receive loop
(``_receive_task_handler``). This class does **not** start a
background receive loop.
Provides connection lifecycle management (connect on start, disconnect
on stop/cancel), automatic reconnection with exponential backoff, and
three helpers for running each inference:
1. ``_ensure_connected()`` — verify the websocket is alive, reconnect
with exponential backoff if not.
2. ``_ws_send(message)`` — send the inference request as JSON.
3. ``_ws_recv()`` — receive and parse response events one at a time
until the caller sees a terminal event.
``_ws_send`` and ``_ws_recv`` catch ``ConnectionClosed`` transparently,
auto-reconnect via ``_try_reconnect``, and raise
``WebsocketReconnectedError`` so callers know the inference must be
restarted. If reconnection fails, the original ``ConnectionClosed``
propagates.
Subclasses must implement:
``_connect_websocket()``: Establish the websocket connection.
``_disconnect_websocket()``: Close the websocket and clean up.
Event handlers:
on_connection_error: Called when a websocket connection error occurs.
Example::
@llm.event_handler("on_connection_error")
async def on_connection_error(llm: LLMService, error: str):
logger.error(f"LLM connection error: {error}")
"""
[docs]
def __init__(self, *, reconnect_on_error: bool = True, **kwargs):
"""Initialize the Websocket LLM service.
Args:
reconnect_on_error: Whether to automatically reconnect on websocket errors.
**kwargs: Additional arguments passed to parent classes.
"""
LLMService.__init__(self, **kwargs)
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
self._register_event_handler("on_connection_error")
# -- lifecycle ------------------------------------------------------------
async def _connect(self):
"""Connect: reset flags and establish the websocket."""
await super()._connect()
await self._connect_websocket()
async def _disconnect(self):
"""Disconnect: set flags and close the websocket."""
await super()._disconnect()
await self._disconnect_websocket()
[docs]
async def start(self, frame: StartFrame):
"""Start the service and establish WebSocket connection.
Args:
frame: The start frame triggering service initialization.
"""
await super().start(frame)
await self._connect()
[docs]
async def stop(self, frame: EndFrame):
"""Stop the service and close WebSocket connection.
Args:
frame: The end frame triggering service shutdown.
"""
await super().stop(frame)
await self._disconnect()
[docs]
async def cancel(self, frame: CancelFrame):
"""Cancel the service and close WebSocket connection.
Args:
frame: The cancel frame triggering service cancellation.
"""
await super().cancel(frame)
await self._disconnect()
# -- per-inference helpers ------------------------------------------------
async def _ws_send(self, message: dict):
"""Send a JSON message over the websocket.
Guards against sends during intentional disconnect. If the send
fails with ``ConnectionClosed``, attempts to reconnect and raises
``WebsocketReconnectedError`` on success so the caller can restart
the inference. If reconnection fails, the original
``ConnectionClosed`` propagates.
Args:
message: The message dict to serialize and send.
"""
if self._disconnecting or not self._websocket:
return
try:
await self._websocket.send(json.dumps(message))
except ConnectionClosed:
if self._disconnecting:
return
success = await self._try_reconnect(report_error=self._report_error)
if success:
raise WebsocketReconnectedError()
raise
async def _ws_recv(self) -> dict:
"""Receive and parse a JSON message from the websocket.
If the receive fails with ``ConnectionClosed``, attempts to
reconnect and raises ``WebsocketReconnectedError`` on success.
If reconnection fails, the original ``ConnectionClosed``
propagates.
Returns:
The parsed JSON message as a dict.
"""
try:
raw = await self._websocket.recv()
return json.loads(raw)
except ConnectionClosed:
if self._disconnecting:
raise
success = await self._try_reconnect(report_error=self._report_error)
if success:
raise WebsocketReconnectedError()
raise
async def _ensure_connected(self):
"""Ensure the websocket is connected, reconnecting if needed.
Uses ``_try_reconnect`` with exponential backoff.
Raises:
ConnectionError: If the connection could not be established.
"""
if self._websocket and self._websocket.state is not State.CLOSED:
return
success = await self._try_reconnect(report_error=self._report_error)
if not success:
raise ConnectionError(f"{self} failed to establish WebSocket connection")
# -- WebsocketService interface -------------------------------------------
async def _receive_messages(self):
"""Not used — messages are received inline during each inference.
This satisfies the ``WebsocketService`` abstract method but is never
called because ``_receive_task_handler`` is never started.
"""
raise NotImplementedError(
"WebsocketLLMService receives messages inline during inference, "
"not via a continuous background loop"
)
async def _report_error(self, error: ErrorFrame):
await self._call_event_handler("on_connection_error", error.error)
await self.push_error_frame(error)