98 lines
2.8 KiB
Python
98 lines
2.8 KiB
Python
"""Sub-task delegation: fan-out to sub-agents, fan-in of results.

Enables hierarchical multi-agent workflows: a supervisor decomposes a task into
sub-tasks, delegates them to specialized sub-agents, and aggregates the results.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable
|
|
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
@dataclass
class SubTask:
    """One unit of work that a supervisor hands off to a sub-agent.

    Attributes are listed in constructor (positional) order.
    """

    # Identifier the caller assigns to this sub-task.
    sub_task_id: str
    # What the sub-agent is asked to accomplish.
    goal: str
    # Optional free-form constraints; each instance gets its own fresh list.
    constraints: list[str] = field(default_factory=list)
    # Optional extra data; each instance gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class SubTaskResult:
    """Outcome reported for a single delegated sub-task.

    Only ``sub_task_id`` and ``success`` are required; the remaining
    attributes default to ``None``.
    """

    # Identifier of the sub-task this result belongs to.
    sub_task_id: str
    # Whether the sub-task succeeded.
    success: bool
    # Arbitrary payload produced by the sub-agent, if any.
    result: Any = None
    # Error description when the sub-task failed, if provided.
    error: str | None = None
    # Identifier of the agent that handled the sub-task, if known.
    agent_id: str | None = None
|
|
|
|
|
|
@dataclass
class DelegationConfig:
    """Knobs controlling how sub-tasks are fanned out."""

    # Upper bound on concurrently running sub-tasks.
    max_parallel: int = 4
    # Optional time limit in seconds; None means no limit.
    timeout_seconds: float | None = None
    # When True, stop on the first failed sub-task.
    fail_fast: bool = False
|
|
|
|
|
|
def delegate_sub_tasks(
    sub_tasks: list[SubTask],
    delegate_fn: Callable[[SubTask], SubTaskResult],
    config: DelegationConfig | None = None,
) -> list[SubTaskResult]:
    """
    Fan-out: delegate sub-tasks to sub-agents in parallel.

    Each sub-task is passed to ``delegate_fn`` on a thread pool with at most
    ``config.max_parallel`` workers.

    Args:
        sub_tasks: List of sub-tasks to delegate. Duplicate ``sub_task_id``
            values are allowed; results are tracked by position.
        delegate_fn: (SubTask) -> SubTaskResult. Typically calls orchestrator
            to submit task and route to sub-agent, then wait for completion.
        config: Delegation behavior; ``None`` means ``DelegationConfig()``.
            ``fail_fast``: after the first result with ``success=False``,
            stop collecting and cancel sub-tasks that have not started yet.
            ``timeout_seconds``: deadline for collecting results. On expiry
            a warning is logged, sub-tasks that have not started are
            cancelled, and the call returns without waiting for sub-tasks
            that are still running (they finish in the background).

    Returns:
        Results of the sub-tasks that completed, in the same relative order
        as ``sub_tasks``. If fail_fast or a timeout cuts collection short,
        sub-tasks without a result are omitted, so the list may be shorter
        than ``sub_tasks``.

    Raises:
        Any exception raised by ``delegate_fn`` propagates to the caller;
        sub-tasks that have not started yet are cancelled first.
    """
    # Function-scope import keeps the module import block unchanged. This is
    # the exception type as_completed() raises when its timeout expires.
    from concurrent.futures import TimeoutError as FuturesTimeoutError

    cfg = config or DelegationConfig()
    # Slots are indexed by position, not sub_task_id, so duplicate ids cannot
    # overwrite each other's results.
    slots: list[SubTaskResult | None] = [None] * len(sub_tasks)

    executor = ThreadPoolExecutor(max_workers=cfg.max_parallel)
    future_to_index: dict[object, int] = {}
    wait_for_running = True
    try:
        for position, st in enumerate(sub_tasks):
            future_to_index[executor.submit(delegate_fn, st)] = position

        completed = as_completed(future_to_index, timeout=cfg.timeout_seconds)
        while True:
            # Guard only the wait itself. A TimeoutError raised *by*
            # delegate_fn (surfaced through future.result()) must propagate
            # rather than being mistaken for the delegation deadline.
            try:
                future = next(completed)
            except StopIteration:
                break
            except FuturesTimeoutError:
                logger.warning(
                    "Delegation timed out",
                    extra={"timeout_seconds": cfg.timeout_seconds},
                )
                # Do not block past the deadline on still-running sub-tasks.
                wait_for_running = False
                break

            result = future.result()
            slots[future_to_index[future]] = result
            if cfg.fail_fast and not result.success:
                logger.warning("Delegation fail_fast on failure", extra={"sub_task_id": result.sub_task_id})
                break
    finally:
        # Cancel sub-tasks that have not started; running ones cannot be
        # interrupted. Without this, shutting the pool down would still run
        # every queued sub-task, which defeats fail_fast and the timeout.
        for pending in future_to_index:
            pending.cancel()
        executor.shutdown(wait=wait_for_running)

    return [r for r in slots if r is not None]
|
|
|
|
|
|
def aggregate_sub_task_results(
    results: list[SubTaskResult],
    aggregator: Callable[[list[SubTaskResult]], Any],
) -> Any:
    """
    Fan-in: reduce a batch of sub-task results to a single value.

    Args:
        results: The results to combine, e.g. the list returned by
            ``delegate_sub_tasks``. Passed to ``aggregator`` as-is.
        aggregator: Callable invoked once with ``results``.

    Returns:
        Whatever ``aggregator`` returns.
    """
    combined = aggregator(results)
    return combined
|