# Source code for pipecat.services.xai.llm

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Grok LLM service implementation using OpenAI-compatible interface.

This module provides a service for interacting with Grok's API through an
OpenAI-compatible interface, including specialized token usage tracking
and context aggregation functionality.
"""

from dataclasses import dataclass

from loguru import logger

from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.openai.base_llm import BaseOpenAILLMService
from pipecat.services.openai.llm import (
    OpenAILLMService,
)


@dataclass
class GrokLLMSettings(BaseOpenAILLMService.Settings):
    """Runtime-updatable settings for GrokLLMService.

    Declares no fields of its own: every setting is inherited unchanged from
    ``BaseOpenAILLMService.Settings``. It exists so the Grok service has its
    own ``Settings`` type.
    """
class GrokLLMService(OpenAILLMService):
    """A service for interacting with Grok's API using the OpenAI-compatible interface.

    This service extends OpenAILLMService to connect to Grok's API endpoint while
    maintaining full compatibility with OpenAI's interface and functionality.

    Includes specialized token usage tracking that accumulates metrics while a
    context is being processed and reports the final totals once at the end.
    """

    Settings = GrokLLMSettings
    _settings: Settings

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "https://api.x.ai/v1",
        model: str | None = None,
        settings: Settings | None = None,
        **kwargs,
    ):
        """Initialize the GrokLLMService with API key and model.

        Args:
            api_key: The API key for accessing Grok's API.
            base_url: The base URL for Grok API. Defaults to "https://api.x.ai/v1".
            model: The model identifier to use. Defaults to "grok-3".

                .. deprecated:: 0.0.105
                    Use ``settings=GrokLLMService.Settings(model=...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional keyword arguments passed to OpenAILLMService.
        """
        # 1. Start from the hardcoded defaults.
        default_settings = self.Settings(
            model="grok-3",
        )

        # 2. Apply the deprecated direct init argument (emits a deprecation warning).
        if model is not None:
            self._warn_init_param_moved_to_settings("model", "model")
            default_settings.model = model

        # 3. Apply the settings delta last so the canonical API always wins.
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(api_key=api_key, base_url=base_url, settings=default_settings, **kwargs)

        # Create every usage accumulator up front (including the optional
        # cache/reasoning counts) so the instance state is complete from
        # construction, not only after the first _process_context() call.
        self._reset_grok_usage_counters()
        # Gate for start_llm_usage_metrics(): accumulate only while a context
        # is actively being processed.
        self._is_processing = False

    def _reset_grok_usage_counters(self):
        """Reset all per-request token usage accumulators to their initial values."""
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._total_tokens = 0
        self._cache_read_input_tokens: int | None = None
        self._reasoning_tokens: int | None = None
        self._has_reported_prompt_tokens = False

    def create_client(self, api_key=None, base_url=None, **kwargs):
        """Create OpenAI-compatible client for Grok API endpoint.

        Args:
            api_key: The API key to use. If None, uses instance default.
            base_url: The base URL to use. If None, uses instance default.
            **kwargs: Additional arguments passed to client creation.

        Returns:
            The configured client instance for Grok API.
        """
        logger.debug(f"Creating Grok client with api {base_url}")
        return super().create_client(api_key, base_url, **kwargs)

    async def _process_context(self, context: LLMContext):
        """Process a context through the LLM and accumulate token usage metrics.

        This method overrides the parent class implementation to handle Grok's
        incremental token reporting style, accumulating the counts and reporting
        them once at the end of processing.

        Args:
            context: The context to process, containing messages and other
                information needed for the LLM interaction.
        """
        # Start every request from a clean slate.
        self._reset_grok_usage_counters()
        self._is_processing = True

        try:
            await super()._process_context(context)
        finally:
            self._is_processing = False
            # Report the accumulated usage exactly once. This also runs when
            # processing raised, so any usage seen so far is still reported.
            if self._prompt_tokens > 0 or self._completion_tokens > 0:
                self._total_tokens = self._prompt_tokens + self._completion_tokens
                tokens = LLMTokenUsage(
                    prompt_tokens=self._prompt_tokens,
                    completion_tokens=self._completion_tokens,
                    total_tokens=self._total_tokens,
                    cache_read_input_tokens=self._cache_read_input_tokens,
                    reasoning_tokens=self._reasoning_tokens,
                )
                # Call the parent implementation directly: this class's own
                # override only accumulates and would swallow the report.
                await super().start_llm_usage_metrics(tokens)

    async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
        """Accumulate token usage metrics during processing.

        This method intercepts the incremental token updates from Grok's API and
        accumulates them instead of passing each update to the metrics system.
        The final accumulated totals are reported at the end of processing.

        Args:
            tokens: The token usage metrics for the current chunk of processing,
                containing prompt_tokens, completion_tokens, and optional
                cached/reasoning tokens.
        """
        # Outside of _process_context() updates are ignored entirely.
        if not self._is_processing:
            return

        # Keep only the first non-zero prompt token count seen for this request.
        if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
            self._prompt_tokens = tokens.prompt_tokens
            self._has_reported_prompt_tokens = True

        # Completion counts are treated as running totals: keep the largest seen.
        if tokens.completion_tokens > self._completion_tokens:
            self._completion_tokens = tokens.completion_tokens

        # Capture cached & reasoning tokens (these typically only appear once per
        # request); the latest non-None value wins.
        if tokens.cache_read_input_tokens is not None:
            self._cache_read_input_tokens = tokens.cache_read_input_tokens
        if tokens.reasoning_tokens is not None:
            self._reasoning_tokens = tokens.reasoning_tokens