145 lines
4.7 KiB
Python
145 lines
4.7 KiB
Python
|
|
"""Parallel step execution: run independent plan steps concurrently.
|
||
|
|
|
||
|
|
Multi-agent acceleration: steps with satisfied dependencies and no mutual
|
||
|
|
dependencies are dispatched in parallel to maximize throughput.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Any, Callable, Protocol
|
||
|
|
|
||
|
|
from fusionagi.schemas.plan import Plan
|
||
|
|
from fusionagi.planning import ready_steps, get_step
|
||
|
|
from fusionagi._logger import logger
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ParallelStepResult:
    """Result of a single step execution in parallel batch.

    One instance is produced per step attempted by
    ``execute_steps_parallel``, whether the step succeeded, failed, or
    raised.
    """

    # Identifier of the plan step that was attempted.
    step_id: str
    # True when the step completed (executor reported "step_done" or
    # returned a bare non-envelope result).
    success: bool
    # On success: the "result" value from the envelope payload, or the raw
    # response when the executor returned a plain (non-envelope) object.
    result: Any = None
    # Human-readable failure reason; populated only when success is False.
    error: str | None = None
    envelope: Any = None  # AgentMessageEnvelope from executor
|
||
|
|
|
||
|
|
|
||
|
|
class ExecuteStepsCallback(Protocol):
    """Protocol for executing a single step (e.g. via orchestrator).

    Structural interface: any callable accepting
    ``(task_id, step_id, plan, sender)`` and returning a response envelope
    (or raw result) satisfies it — e.g. a bound orchestrator method.

    NOTE: the stray ``@dataclass`` decorator was removed. A Protocol is an
    interface, not a data container; ``@dataclass`` generated an
    ``__init__`` that overwrote Protocol's instantiation guard (making the
    protocol class instantiable) and added misleading ``__eq__``/``__repr__``.
    """

    def __call__(
        self,
        task_id: str,
        step_id: str,
        plan: Plan,
        sender: str = "supervisor",
    ) -> Any:
        """Execute one step; return response envelope or result."""
        ...
|
||
|
|
|
||
|
|
|
||
|
|
def execute_steps_parallel(
    execute_fn: Callable[[str, str, Plan, str], Any],
    task_id: str,
    plan: Plan,
    completed_step_ids: set[str],
    sender: str = "supervisor",
    max_workers: int | None = None,
) -> list[ParallelStepResult]:
    """
    Run every currently-ready step of *plan* concurrently.

    Args:
        execute_fn: Function (task_id, step_id, plan, sender) -> response.
        task_id: Task identifier.
        plan: The plan containing steps.
        completed_step_ids: Steps already completed.
        sender: Sender identity for execute messages.
        max_workers: Max parallel workers (default: one per ready step).

    Returns:
        List of ParallelStepResult, one per step attempted. Order follows
        completion (``as_completed``), not submission.
    """
    ready = ready_steps(plan, completed_step_ids)
    if not ready:
        return []

    # One thread per ready step unless the caller caps the pool.
    pool_size = max_workers if max_workers and max_workers > 0 else len(ready)

    def _attempt(step_id: str) -> ParallelStepResult:
        # Any exception — from the executor call or from inspecting the
        # response — is converted into a failed result, never raised.
        try:
            response = execute_fn(task_id, step_id, plan, sender)
            if response is None:
                return ParallelStepResult(step_id=step_id, success=False, error="No response")
            if not hasattr(response, "message"):
                # Plain result (not an AgentMessageEnvelope): count as success.
                return ParallelStepResult(step_id=step_id, success=True, result=response)
            # Envelope path: intent step_done means success, anything else fails.
            inner = response.message
            if inner.intent == "step_done":
                data = inner.payload or {}
                return ParallelStepResult(
                    step_id=step_id,
                    success=True,
                    result=data.get("result"),
                    envelope=response,
                )
            reason = (
                inner.payload.get("error", "Unknown failure") if inner.payload else "Unknown"
            )
            return ParallelStepResult(
                step_id=step_id,
                success=False,
                error=reason,
                envelope=response,
            )
        except Exception as exc:
            logger.exception("Parallel step execution failed", extra={"step_id": step_id})
            return ParallelStepResult(step_id=step_id, success=False, error=str(exc))

    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        pending = [pool.submit(_attempt, sid) for sid in ready]
        # _attempt never raises, so future.result() is safe to call bare.
        outcomes = [done.result() for done in as_completed(pending)]

    logger.info(
        "Parallel step batch completed",
        extra={"task_id": task_id, "steps": ready, "results": len(outcomes)},
    )
    return outcomes
|
||
|
|
|
||
|
|
|
||
|
|
def execute_steps_parallel_wave(
    execute_fn: Callable[[str, str, Plan, str], Any],
    task_id: str,
    plan: Plan,
    sender: str = "supervisor",
    max_workers: int | None = None,
) -> list[ParallelStepResult]:
    """
    Execute plan in waves: each wave runs all ready steps in parallel,
    then advances to the next wave when deps are satisfied.

    Stops after the first wave containing a failed step (caller can retry
    or handle), but every attempted step's result — including the rest of
    the failing wave — is kept in the returned list.

    Args:
        execute_fn: Function (task_id, step_id, plan, sender) -> response.
        task_id: Task identifier.
        plan: The plan containing steps.
        sender: Sender identity for execute messages.
        max_workers: Max parallel workers per wave.

    Returns combined results from all waves.
    """
    completed: set[str] = set()
    all_results: list[ParallelStepResult] = []

    while True:
        batch = execute_steps_parallel(
            execute_fn, task_id, plan, completed, sender, max_workers
        )
        if not batch:
            # No ready steps remain: plan exhausted (or blocked on deps).
            break

        # Record every attempted result BEFORE deciding whether to stop.
        # (Previously the loop returned on the first failed item, silently
        # dropping already-computed results from the same batch.)
        all_results.extend(batch)
        for r in batch:
            if r.success:
                completed.add(r.step_id)

        failures = [r for r in batch if not r.success]
        if failures:
            # On failure, stop the wave (caller can retry or handle)
            logger.warning(
                "Step failed in wave, stopping", extra={"step_id": failures[0].step_id}
            )
            return all_results

    return all_results
|