#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Asyncio task management.
This module provides task management functionality, including both abstract
base classes and concrete implementations for managing asyncio tasks with
comprehensive monitoring and cleanup capabilities.
"""
import asyncio
import traceback
from abc import ABC, abstractmethod
from collections.abc import Coroutine, Sequence
from dataclasses import dataclass
from loguru import logger
@dataclass
class TaskManagerParams:
    """Settings handed to a task manager's ``setup()`` call.

    Parameters:
        loop: Event loop the task manager uses to create and run its tasks.
    """

    # The asyncio event loop used for task management.
    loop: asyncio.AbstractEventLoop
class BaseTaskManager(ABC):
    """Interface for asyncio task managers.

    Concrete subclasses provide creation, cancellation and tracking of asyncio
    tasks on a configured event loop.
    """

    @abstractmethod
    def setup(self, params: TaskManagerParams):
        """Configure the manager before it is used.

        Args:
            params: Configuration parameters for task management.
        """

    @abstractmethod
    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        """Return the event loop this manager works with.

        Returns:
            The asyncio event loop instance.
        """

    @abstractmethod
    def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
        """Create and schedule an asyncio Task that runs ``coroutine``.

        The new task is registered with the manager, so it is reported by
        ``current_tasks()``.

        Args:
            coroutine: The coroutine to run inside the task.
            name: Identifying name given to the task.

        Returns:
            The newly created task.
        """

    @abstractmethod
    async def cancel_task(self, task: asyncio.Task, timeout: float | None = None):
        """Cancel ``task`` and wait for it to finish, optionally bounded by ``timeout``.

        Once the task completes or fails it is removed from the manager's
        registered tasks.

        Args:
            task: The task to cancel.
            timeout: Optional number of seconds to wait for the cancellation.
        """

    @abstractmethod
    def current_tasks(self) -> Sequence[asyncio.Task]:
        """Return the tasks currently created/registered with the manager.

        Returns:
            Sequence of currently managed asyncio tasks.
        """
@dataclass
class TaskData:
    """Bookkeeping record a task manager keeps for each task it tracks.

    Parameters:
        task: The managed asyncio Task.
    """

    # The asyncio Task this record describes.
    task: asyncio.Task
class TaskManager(BaseTaskManager):
    """Concrete implementation of BaseTaskManager.

    Manages asyncio tasks, providing lifecycle management including creation,
    monitoring, cancellation, and cleanup. Tasks created through this manager
    are kept in an instance-level registry keyed by task name and are removed
    from it automatically when they finish.
    """

    def __init__(self) -> None:
        """Initialize the task manager with an empty task registry."""
        # Task name -> metadata for tasks created here that have not finished yet.
        self._tasks: dict[str, TaskData] = {}
        # Set by the first `setup()` call; None until then.
        self._params: TaskManagerParams | None = None

    def setup(self, params: TaskManagerParams):
        """Initialize the task manager with configuration parameters.

        Only the first call takes effect; later calls are ignored.

        Args:
            params: Configuration parameters for task management.
        """
        if self._params is None:
            self._params = params

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        """Get the event loop used by this task manager.

        Returns:
            The asyncio event loop instance.

        Raises:
            Exception: If the task manager has not been set up.
        """
        if self._params is None:
            raise Exception("TaskManager is not setup: unable to get event loop")
        return self._params.loop

    def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
        """Create and schedule a new asyncio Task that runs the given coroutine.

        The coroutine is wrapped so that cancellation is logged and re-raised,
        while any other ``Exception`` is logged and swallowed (the task then
        finishes with a ``None`` result). The task is registered with this
        manager and unregistered automatically by a done callback.

        Args:
            coroutine: The coroutine to be executed within the task.
            name: The name to assign to the task for identification.

        Returns:
            The created task object.

        Raises:
            Exception: If the task manager has not been set up. The given
                coroutine is closed in that case, since it will never run.
        """
        if self._params is None:
            # The coroutine will never be scheduled; close it so it is not
            # reported as "never awaited" when it is garbage collected.
            close = getattr(coroutine, "close", None)
            if close is not None:
                close()
            raise Exception("TaskManager is not setup: unable to get event loop")

        async def run_coroutine():
            try:
                return await coroutine
            except asyncio.CancelledError:
                logger.trace(f"{name}: task cancelled")
                # Re-raise the exception to ensure the task is cancelled.
                raise
            except Exception as e:
                # Report the innermost frame, where the exception was raised.
                tb = traceback.extract_tb(e.__traceback__)
                last = tb[-1]
                logger.error(f"{name} unexpected exception ({last.filename}:{last.lineno}): {e}")

        task = self._params.loop.create_task(run_coroutine())
        task.set_name(name)
        task.add_done_callback(self._task_done_handler)
        self._add_task(TaskData(task=task))
        logger.trace(f"{name}: task created")
        return task

    async def cancel_task(self, task: asyncio.Task, timeout: float | None = None):
        """Cancel the given asyncio Task and await its completion with an optional timeout.

        Tasks created by this manager are removed from the registry by their
        done callback once they finish.

        Args:
            task: The task to be cancelled.
            timeout: The optional timeout in seconds to wait for the task to
                cancel. A falsy value (``None`` or ``0``) waits without a timeout.
        """
        name = task.get_name()
        task.cancel()
        try:
            if timeout:
                await asyncio.wait_for(task, timeout=timeout)
            else:
                await task
        except asyncio.TimeoutError:
            # `asyncio.wait_for` raises `asyncio.TimeoutError`, which is only an
            # alias of the builtin `TimeoutError` from Python 3.11 onwards.
            logger.warning(f"{name}: timed out waiting for task to cancel")
        except asyncio.CancelledError:
            # Here we are sure the task is cancelled properly.
            # NOTE(review): this may also swallow a cancellation aimed at the
            # caller of `cancel_task` itself — confirm that is intended.
            pass
        except Exception as e:
            tb = traceback.extract_tb(e.__traceback__)
            last = tb[-1]
            logger.error(
                f"{name} unexpected exception while cancelling task ({last.filename}:{last.lineno}): {e}"
            )
        except BaseException as e:
            tb = traceback.extract_tb(e.__traceback__)
            last = tb[-1]
            logger.critical(
                f"{name} fatal base exception while cancelling task ({last.filename}:{last.lineno}): {e}"
            )
            raise

    def current_tasks(self) -> Sequence[asyncio.Task]:
        """Return the list of currently created/registered tasks.

        Returns:
            Sequence of currently managed asyncio tasks.
        """
        return [data.task for data in self._tasks.values()]

    def _add_task(self, task_data: TaskData):
        """Add a task to the internal registry, keyed by its task name.

        Registering a task whose name is already present replaces the earlier
        entry.

        Args:
            task_data: The task metadata.
        """
        self._tasks[task_data.task.get_name()] = task_data

    def _task_done_handler(self, task: asyncio.Task):
        """Handle task completion by removing the task from the registry.

        The entry is removed only if it still refers to ``task``. If a newer task
        was registered under the same name, the older task finishing must not
        unregister the newer, still-running one.

        Args:
            task: The completed asyncio task.
        """
        name = task.get_name()
        task_data = self._tasks.get(name)
        if task_data is not None and task_data.task is task:
            del self._tasks[name]
        else:
            logger.trace(f"{name}: unable to remove task data (already removed or replaced?)")