Source code for pipecat.services.aws.sagemaker.bidi_client

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""AWS SageMaker bidirectional streaming client.

This module provides a client for streaming bidirectional communication with
SageMaker endpoints using the HTTP/2 protocol. Supports sending audio, text,
and JSON data to SageMaker model endpoints and receiving streaming responses.
"""

import json
import os
from typing import Any

from loguru import logger

try:
    from aws_sdk_sagemaker_runtime_http2.client import SageMakerRuntimeHTTP2Client
    from aws_sdk_sagemaker_runtime_http2.config import Config, HTTPAuthSchemeResolver
    from aws_sdk_sagemaker_runtime_http2.models import (
        InvokeEndpointWithBidirectionalStreamInput,
        RequestPayloadPart,
        RequestStreamEventPayloadPart,
        ResponseStreamEvent,
    )
    from smithy_aws_core.auth.sigv4 import SigV4AuthScheme
    from smithy_aws_core.identity import EnvironmentCredentialsResolver
    from smithy_core.aio.eventstream import DuplexEventStream
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use SageMaker BiDi client, you need to `pip install pipecat-ai[sagemaker]`."
    )
    raise Exception(f"Missing module: {e}")


class SageMakerBidiClient:
    """Client for bidirectional streaming with AWS SageMaker endpoints.

    Handles the low-level HTTP/2 bidirectional streaming protocol used to
    communicate with SageMaker model endpoints. Provides methods for sending
    various data types (audio, text, JSON) and receiving streaming responses.

    Requests are signed with AWS SigV4 (service ``"sagemaker"``), and
    credentials are obtained through ``EnvironmentCredentialsResolver``.

    NOTE(review): earlier documentation stated that AWS CLI configuration and
    instance metadata are also consulted. The resolver's name suggests it only
    reads environment variables, so confirm this before relying on non-env
    credentials.

    Example::

        client = SageMakerBidiClient(
            endpoint_name="my-deepgram-endpoint",
            region="us-east-2",
            model_invocation_path="v1/listen",
            model_query_string="model=nova-3&language=en"
        )
        await client.start_session()
        await client.send_audio_chunk(audio_bytes)
        response = await client.receive_response()
        await client.close_session()
    """

    def __init__(
        self,
        endpoint_name: str,
        region: str,
        model_invocation_path: str = "",
        model_query_string: str = "",
    ):
        """Initialize the SageMaker BiDi client.

        No connection is opened here. The underlying HTTP/2 client is created
        lazily by :meth:`start_session`.

        Args:
            endpoint_name: Name of the SageMaker endpoint to connect to.
            region: AWS region where the endpoint is deployed.
            model_invocation_path: API path for the model invocation (e.g., "v1/listen").
            model_query_string: Query string parameters for the model (e.g., "model=nova-3").
        """
        self.endpoint_name = endpoint_name
        self.region = region
        self.model_invocation_path = model_invocation_path
        self.model_query_string = model_query_string
        # The bidirectional streaming runtime is addressed on port 8443.
        self.bidi_endpoint = f"https://runtime.sagemaker.{region}.amazonaws.com:8443"

        self._client: SageMakerRuntimeHTTP2Client | None = None
        # Duplex stream for the current session; set by start_session().
        self._stream: (
            DuplexEventStream[RequestStreamEventPayloadPart, ResponseStreamEvent, Any] | None
        ) = None
        # Output half of the duplex stream; set once start_session() succeeds.
        self._output_stream = None
        self._is_active = False

    def _initialize_client(self):
        """Create the SageMaker Runtime HTTP2 client with SigV4 authentication.

        Builds a ``Config`` pointing at :attr:`bidi_endpoint` and uses
        ``EnvironmentCredentialsResolver`` for credentials. If
        ``AWS_ACCESS_KEY_ID``/``AWS_SECRET_ACCESS_KEY`` are not both set in the
        environment, a warning is logged, but client creation still proceeds.
        """
        logger.debug(f"Initializing SageMaker BiDi client for region: {self.region}")
        logger.debug(f"Using endpoint URI: {self.bidi_endpoint}")

        # Only a diagnostic: credential resolution itself is left to the resolver.
        has_env_creds = bool(os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"))
        if not has_env_creds:
            logger.warning(
                "AWS credentials not found in environment variables. "
                "Attempting to use EnvironmentCredentialsResolver which will check "
                "AWS CLI configuration and instance metadata."
            )

        config = Config(
            endpoint_uri=self.bidi_endpoint,
            region=self.region,
            aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
            auth_scheme_resolver=HTTPAuthSchemeResolver(),
            auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="sagemaker")},
        )
        self._client = SageMakerRuntimeHTTP2Client(config=config)

    async def start_session(self):
        """Start a bidirectional streaming session with the SageMaker endpoint.

        Initializes the client if needed, opens the bidirectional stream, and
        waits for the output half of the stream. Must be called before sending
        or receiving data.

        If the stream was opened but the session could not be completed, its
        input stream is closed before the error is raised, so the connection
        is not leaked.

        Returns:
            The output stream for receiving responses (the second element of
            the stream's ``await_output()`` result).

        Raises:
            RuntimeError: If client initialization or connection fails. The
                original exception is attached as ``__cause__``.
        """
        stream = None
        try:
            if self._client is None:
                self._initialize_client()

            logger.debug(f"Starting BiDi session with endpoint: {self.endpoint_name}")
            logger.debug(f"Model invocation path: {self.model_invocation_path}")
            logger.debug(f"Model query string: {self.model_query_string}")

            stream_input = InvokeEndpointWithBidirectionalStreamInput(
                endpoint_name=self.endpoint_name,
                model_invocation_path=self.model_invocation_path,
                model_query_string=self.model_query_string,
            )

            stream = await self._client.invoke_endpoint_with_bidirectional_stream(stream_input)
            self._stream = stream
            self._is_active = True

            output = await self._stream.await_output()
            self._output_stream = output[1]

            logger.debug("BiDi session started successfully")
            return self._output_stream
        except Exception as e:
            logger.error(f"Failed to start BiDi session: {e}")
            self._is_active = False
            if stream is not None:
                # The stream was opened in this call but the session is unusable:
                # close its input half. close_session() would skip it (inactive).
                try:
                    await stream.input_stream.close()
                except Exception as close_error:
                    logger.warning(f"Error closing failed BiDi stream: {close_error}")
            raise RuntimeError(f"Failed to start SageMaker BiDi session: {e}") from e

    async def send_data(self, data_bytes: bytes, data_type: str | None = None):
        """Send a chunk of data to the stream.

        Generic method for sending any type of data to the SageMaker endpoint.
        Use the convenience methods (send_audio_chunk, send_text, send_json)
        for common data types.

        Args:
            data_bytes: Raw bytes to send.
            data_type: Optional data type header. Common values are "BINARY"
                for audio/binary data and "UTF8" for text/JSON data.

        Raises:
            RuntimeError: If the session is not active.
            Exception: Any error raised by the underlying stream while sending
                is logged and re-raised unchanged.
        """
        if not self._is_active or self._stream is None:
            raise RuntimeError("BiDi session not active")

        try:
            payload = RequestPayloadPart(bytes_=data_bytes, data_type=data_type)
            event = RequestStreamEventPayloadPart(value=payload)
            await self._stream.input_stream.send(event)
        except Exception as e:
            logger.error(f"Failed to send data: {e}")
            raise

    async def send_audio_chunk(self, audio_bytes: bytes):
        """Send a chunk of audio data to the stream with data type "BINARY".

        Args:
            audio_bytes: Raw audio bytes to send (e.g., PCM audio data).

        Raises:
            RuntimeError: If the session is not active.
            Exception: Any send error from the underlying stream, re-raised.
        """
        await self.send_data(audio_bytes, data_type="BINARY")

    async def send_text(self, text: str):
        """Send text to the stream, UTF-8 encoded, with data type "UTF8".

        Args:
            text: Text string to send.

        Raises:
            RuntimeError: If the session is not active.
            Exception: Any send error from the underlying stream, re-raised.
        """
        await self.send_data(text.encode("utf-8"), data_type="UTF8")

    async def send_json(self, data: dict):
        """Send a JSON-encoded message to the stream with data type "UTF8".

        Useful for control messages like KeepAlive or CloseStream. The
        dictionary is serialized with ``json.dumps`` and encoded as UTF-8.

        Args:
            data: Dictionary to send as JSON (e.g., {"type": "KeepAlive"}).

        Raises:
            RuntimeError: If the session is not active.
            Exception: Any send error from the underlying stream, re-raised.
        """
        await self.send_data(json.dumps(data).encode("utf-8"), data_type="UTF8")

    async def receive_response(self) -> ResponseStreamEvent | None:
        """Receive a response from the stream.

        Awaits the next event from the output stream.

        Returns:
            The response event containing payload data, or None if the stream
            is closed.

        Raises:
            RuntimeError: If the session is not active.
            Exception: Any receive error from the underlying stream is logged
                and re-raised unchanged.
        """
        if not self._is_active or self._output_stream is None:
            raise RuntimeError("BiDi session not active")

        try:
            return await self._output_stream.receive()
        except Exception as e:
            logger.error(f"Failed to receive response: {e}")
            raise

    async def close_session(self):
        """Close the bidirectional streaming session.

        Marks the session inactive and closes the input (send) half of the
        stream. Errors during close are logged as warnings, not raised. Calling
        this again while inactive is a no-op.
        """
        if not self._is_active:
            return

        logger.debug("Closing BiDi session...")
        self._is_active = False

        try:
            if self._stream is not None:
                await self._stream.input_stream.close()
            logger.debug("BiDi session closed successfully")
        except Exception as e:
            logger.warning(f"Error closing BiDi session: {e}")

    @property
    def is_active(self) -> bool:
        """Whether a session is currently active.

        Returns:
            True if session is active, False otherwise.
        """
        return self._is_active