# Source code for pipecat.processors.frameworks.strands_agents

"""Strands Agent integration for Pipecat.

This module provides integration with Strands Agents for handling conversational AI
interactions. It supports both single agent and multi-agent graphs.
"""

from loguru import logger

from pipecat.frames.frames import (
    Frame,
    LLMContextFrame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMTextFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor

try:
    from strands import Agent
    from strands.multiagent.graph import Graph
except ModuleNotFoundError as e:
    # Strands may not be installed: log an actionable hint, then abort the import.
    logger.error("In order to use Strands Agents, you need to `pip install strands-agents`.")
    # Chain explicitly so the original ModuleNotFoundError is recorded as the direct cause.
    raise Exception(f"Missing module: {e}") from e


class StrandsAgentsProcessor(FrameProcessor):
    """Processor that integrates Strands Agents with Pipecat's frame pipeline.

    This processor takes LLM context frames, extracts the text of the last message
    in the context, and processes it through either a single Strands Agent or a
    multi-agent Graph. The response is emitted as text frames wrapped in
    response start/end markers.

    Supports both single agent streaming and graph-based multi-agent workflows.
    When both an agent and a graph are configured, the graph is used.
    """

    def __init__(
        self,
        agent: Agent | None = None,
        graph: Graph | None = None,
        graph_exit_node: str | None = None,
    ):
        """Initialize the Strands Agents processor.

        Args:
            agent: The Strands Agent to use for single-agent processing.
            graph: The Strands multi-agent Graph to use for graph-based processing.
            graph_exit_node: The exit node name when using graph-based processing.

        Raises:
            AssertionError: If neither agent nor graph is provided, or if graph is
                provided without a graph_exit_node.
        """
        super().__init__()
        self.agent = agent
        self.graph = graph
        self.graph_exit_node = graph_exit_node

        # Raised explicitly rather than via `assert` so the validation still runs
        # under `python -O`; AssertionError is kept because it is the documented
        # exception type for this constructor.
        if not (self.agent or self.graph):
            raise AssertionError("Either agent or graph must be provided")
        if self.graph and not self.graph_exit_node:
            raise AssertionError("graph_exit_node must be provided if graph is provided")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames and handle LLM context frames.

        An ``LLMContextFrame`` is consumed: the text of its last message is sent to
        the agent or graph. Every other frame is pushed on unchanged.

        Args:
            frame: The incoming frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        if not isinstance(frame, LLMContextFrame):
            await self.push_frame(frame, direction)
            return

        messages = frame.context.get_messages()
        if not messages:
            # Nothing to send; leave a trace instead of dropping the frame silently.
            logger.debug(f"{self} received an LLMContextFrame with no messages")
            return

        await self._ainvoke(self._message_text(messages[-1]))

    @staticmethod
    def _message_text(message) -> str:
        """Return the stripped text carried by a context message.

        Args:
            message: A context message exposing its payload under ``"content"``.

        Returns:
            The content itself when it is a string; the space-joined ``"text"``
            values when it is a list of dict blocks carrying a ``"text"`` key (the
            same block shape read from graph agent results in ``_ainvoke``); otherwise
            ``str(content)``. The result is always stripped.
        """
        content = message["content"]
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts = [
                str(block["text"])
                for block in content
                if isinstance(block, dict) and "text" in block
            ]
            if parts:
                # Send the actual text rather than the Python repr of the list.
                return " ".join(parts).strip()
        return str(content).strip()

    async def _ainvoke(self, text: str):
        """Invoke the Strands agent with the provided text and stream results as Pipecat frames.

        Pushes an ``LLMFullResponseStartFrame``, then one ``LLMTextFrame`` per piece
        of text produced, and always finishes with an ``LLMFullResponseEndFrame``.
        Errors raised while invoking are reported through ``push_error``.

        Args:
            text: The user input text to process through the agent or graph.
        """
        logger.debug(f"Invoking Strands agent with: {text}")
        # True until the TTFB metric has been stopped, so it is stopped exactly once.
        ttfb_tracking = True
        try:
            await self.push_frame(LLMFullResponseStartFrame())
            await self.start_processing_metrics()
            await self.start_ttfb_metrics()

            if self.graph:
                # Graph does not stream; await full result then emit assistant text
                graph_result = await self.graph.invoke_async(text)
                if ttfb_tracking:
                    await self.stop_ttfb_metrics()
                    ttfb_tracking = False
                try:
                    node_result = graph_result.results[self.graph_exit_node]
                    logger.debug(f"Node result: {node_result}")
                    for agent_result in node_result.get_agent_results():
                        # Push to TTS service
                        message = getattr(agent_result, "message", None)
                        if isinstance(message, dict) and "content" in message:
                            for block in message["content"]:
                                if isinstance(block, dict) and "text" in block:
                                    await self.push_frame(LLMTextFrame(str(block["text"])))
                        # Update usage metrics
                        accumulated = agent_result.metrics.accumulated_usage
                        await self._report_usage_metrics(
                            accumulated.get("inputTokens", 0),
                            accumulated.get("outputTokens", 0),
                            accumulated.get("totalTokens", 0),
                        )
                except Exception as parse_err:
                    # Best effort: a result that cannot be parsed is logged, and the
                    # response is still closed normally by the `finally` below.
                    logger.warning(f"Failed to extract messages from GraphResult: {parse_err}")
            else:
                # Agent supports streaming events via async iterator
                async for event in self.agent.stream_async(text):
                    if not isinstance(event, dict):
                        continue

                    # Push to TTS service
                    if "data" in event:
                        await self.push_frame(LLMTextFrame(str(event["data"])))
                        if ttfb_tracking:
                            await self.stop_ttfb_metrics()
                            ttfb_tracking = False

                    # Update usage metrics
                    if "event" in event and "metadata" in event["event"]:
                        metadata = event["event"]["metadata"]
                        if "usage" in metadata:
                            usage = metadata["usage"]
                            await self._report_usage_metrics(
                                usage.get("inputTokens", 0),
                                usage.get("outputTokens", 0),
                                usage.get("totalTokens", 0),
                            )
        except GeneratorExit:
            logger.warning(f"{self} generator was closed prematurely")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            # Stop TTFB if no output ever arrived, then always close the response.
            if ttfb_tracking:
                await self.stop_ttfb_metrics()
            await self.stop_processing_metrics()
            await self.push_frame(LLMFullResponseEndFrame())

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate performance metrics.

        Returns:
            True as this service supports metrics generation.
        """
        return True

    async def _report_usage_metrics(
        self, prompt_tokens: int, completion_tokens: int, total_tokens: int
    ):
        """Wrap token counts in an ``LLMTokenUsage`` and pass it to ``start_llm_usage_metrics``.

        Args:
            prompt_tokens: Number of input tokens reported for the call.
            completion_tokens: Number of output tokens reported for the call.
            total_tokens: Total number of tokens reported for the call.
        """
        tokens = LLMTokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        await self.start_llm_usage_metrics(tokens)