# Source code for pipecat.utils.asyncio.task_manager

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Asyncio task management.

This module provides task management functionality. Includes both abstract base
classes and concrete implementations for managing asyncio tasks with
comprehensive monitoring and cleanup capabilities.
"""

import asyncio
import traceback
from abc import ABC, abstractmethod
from collections.abc import Coroutine, Sequence
from dataclasses import dataclass

from loguru import logger


@dataclass
class TaskManagerParams:
    """Settings handed to a task manager's ``setup`` method.

    Parameters:
        loop: Event loop on which the task manager schedules its tasks.
    """

    # Event loop that the task manager uses to create and run tasks.
    loop: asyncio.AbstractEventLoop
[docs] class BaseTaskManager(ABC): """Abstract base class for asyncio task management. Provides the interface for creating, monitoring, and managing asyncio tasks. """
[docs] @abstractmethod def setup(self, params: TaskManagerParams): """Initialize the task manager with configuration parameters. Args: params: Configuration parameters for task management. """ pass
[docs] @abstractmethod def get_event_loop(self) -> asyncio.AbstractEventLoop: """Get the event loop used by this task manager. Returns: The asyncio event loop instance. """ pass
[docs] @abstractmethod def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task: """Creates and schedules a new asyncio Task that runs the given coroutine. The task is added to a global set of created tasks. Args: coroutine: The coroutine to be executed within the task. name: The name to assign to the task for identification. Returns: The created task object. """ pass
[docs] @abstractmethod async def cancel_task(self, task: asyncio.Task, timeout: float | None = None): """Cancels the given asyncio Task and awaits its completion with an optional timeout. This function removes the task from the set of registered tasks upon completion or failure. Args: task: The task to be cancelled. timeout: The optional timeout in seconds to wait for the task to cancel. """ pass
[docs] @abstractmethod def current_tasks(self) -> Sequence[asyncio.Task]: """Returns the list of currently created/registered tasks. Returns: Sequence of currently managed asyncio tasks. """ pass
@dataclass
class TaskData:
    """Bookkeeping record kept for each task registered with a task manager.

    Parameters:
        task: The asyncio Task this record describes.
    """

    # The managed task itself.
    task: asyncio.Task
class TaskManager(BaseTaskManager):
    """Concrete implementation of BaseTaskManager.

    Manages asyncio tasks. Provides task lifecycle management including
    creation, monitoring, cancellation, and cleanup.

    Tasks are registered under their name when they are created and are
    unregistered by a done callback when they finish.
    """

    def __init__(self) -> None:
        """Initialize the task manager with an empty task registry."""
        # Registry of live tasks, keyed by task name.
        self._tasks: dict[str, TaskData] = {}
        self._params: TaskManagerParams | None = None

    def setup(self, params: TaskManagerParams):
        """Initialize the task manager with configuration parameters.

        Only the first call has an effect; later calls are ignored.

        Args:
            params: Configuration parameters for task management.
        """
        if self._params is None:
            self._params = params

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        """Get the event loop used by this task manager.

        Returns:
            The asyncio event loop instance.

        Raises:
            Exception: If the task manager is not properly set up.
        """
        if self._params is None:
            raise Exception("TaskManager is not setup: unable to get event loop")
        return self._params.loop

    def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
        """Create and schedule a new asyncio Task that runs the given coroutine.

        The task is registered with this manager under ``name`` and is removed
        from the registry automatically when it finishes.

        Exceptions raised by ``coroutine`` (other than cancellation) are
        logged and swallowed, so the task completes with a ``None`` result.

        Args:
            coroutine: The coroutine to be executed within the task.
            name: The name to assign to the task for identification.

        Returns:
            The created task object.

        Raises:
            Exception: If the task manager is not properly set up.
        """
        if self._params is None:
            # The coroutine will never run; close it so Python does not emit a
            # "coroutine ... was never awaited" RuntimeWarning for it.
            if asyncio.iscoroutine(coroutine):
                coroutine.close()
            raise Exception("TaskManager is not setup: unable to get event loop")

        async def run_coroutine():
            try:
                return await coroutine
            except asyncio.CancelledError:
                logger.trace(f"{name}: task cancelled")
                # Re-raise the exception to ensure the task is cancelled.
                raise
            except Exception as e:
                # Report where the error originated, then swallow it so the
                # task finishes without an exception.
                tb = traceback.extract_tb(e.__traceback__)
                last = tb[-1]
                logger.error(f"{name} unexpected exception ({last.filename}:{last.lineno}): {e}")

        task = self._params.loop.create_task(run_coroutine())
        task.set_name(name)
        # Registered before any done callback can run (callbacks are scheduled
        # on the loop), so the registry entry exists when it is removed.
        task.add_done_callback(self._task_done_handler)
        self._add_task(TaskData(task=task))
        logger.trace(f"{name}: task created")
        return task

    async def cancel_task(self, task: asyncio.Task, timeout: float | None = None):
        """Cancel the given asyncio Task and await its completion with an optional timeout.

        Tasks created through ``create_task`` are removed from the registry by
        their done callback once they finish.

        Args:
            task: The task to be cancelled.
            timeout: The optional timeout in seconds to wait for the task to
                cancel. A falsy value (``None`` or ``0``) waits without limit.

        Raises:
            BaseException: Any non-``Exception`` error other than
                ``asyncio.CancelledError`` (e.g. ``KeyboardInterrupt``) is
                logged and re-raised.
        """
        name = task.get_name()
        task.cancel()
        try:
            if timeout:
                await asyncio.wait_for(task, timeout=timeout)
            else:
                await task
        except (asyncio.TimeoutError, TimeoutError):
            # On Python 3.10 `asyncio.wait_for` raises `asyncio.TimeoutError`,
            # which is not the builtin `TimeoutError`; since 3.11 they are the same.
            logger.warning(f"{name}: timed out waiting for task to cancel")
        except asyncio.CancelledError:
            # Here we are sure the task was cancelled properly.
            # NOTE(review): this also swallows a cancellation of the *calling*
            # task while it waits here; confirm that is intended.
            pass
        except Exception as e:
            tb = traceback.extract_tb(e.__traceback__)
            last = tb[-1]
            logger.error(
                f"{name} unexpected exception while cancelling task ({last.filename}:{last.lineno}): {e}"
            )
        except BaseException as e:
            tb = traceback.extract_tb(e.__traceback__)
            last = tb[-1]
            logger.critical(
                f"{name} fatal base exception while cancelling task ({last.filename}:{last.lineno}): {e}"
            )
            raise

    def current_tasks(self) -> Sequence[asyncio.Task]:
        """Return the list of currently created/registered tasks.

        Returns:
            Sequence of currently managed asyncio tasks.
        """
        return [data.task for data in self._tasks.values()]

    def _add_task(self, task_data: TaskData):
        """Add a task to the internal registry under its name.

        If another task is already registered under the same name, it is
        replaced (and therefore no longer reported by ``current_tasks``).

        Args:
            task_data: The task metadata.
        """
        name = task_data.task.get_name()
        previous = self._tasks.get(name)
        if previous is not None and previous.task is not task_data.task:
            logger.trace(f"{name}: replacing a registered task with the same name")
        self._tasks[name] = task_data

    def _task_done_handler(self, task: asyncio.Task):
        """Handle task completion by removing the task from the registry.

        The entry is only removed if it still refers to ``task``: when a newer
        task was registered under the same name, that newer registration is
        left untouched.

        Args:
            task: The completed asyncio task.
        """
        name = task.get_name()
        task_data = self._tasks.get(name)
        if task_data is None:
            logger.trace(f"{name}: unable to remove task data (already removed?)")
            return
        if task_data.task is not task:
            logger.trace(f"{name}: registered task is a different task, not removing it")
            return
        del self._tasks[name]