"""Parallel step execution: run independent plan steps concurrently.
Multi-agent acceleration: steps with satisfied dependencies and no mutual
dependencies are dispatched in parallel to maximize throughput.
"""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable, Protocol
from fusionagi.schemas.plan import Plan
from fusionagi.planning import ready_steps, get_step
from fusionagi._logger import logger
@dataclass
class ParallelStepResult:
    """Result of a single step execution in parallel batch."""

    # Identifier of the plan step this result describes.
    step_id: str
    # True when the step completed (a "step_done" envelope or a raw result).
    success: bool
    # Result value extracted from the executor's response, if any.
    result: Any = None
    # Human-readable failure description when success is False.
    error: str | None = None
    envelope: Any = None  # AgentMessageEnvelope from executor
class ExecuteStepsCallback(Protocol):
    """Protocol for executing a single step (e.g. via orchestrator).

    NOTE: the original code decorated this Protocol with ``@dataclass``.
    That is removed here: a Protocol is a structural (duck-typed) interface
    that cannot be instantiated, so dataclass-generated ``__init__``/``__eq__``
    are meaningless for it and can interfere with the Protocol machinery's
    instantiation guard.
    """

    def __call__(
        self,
        task_id: str,
        step_id: str,
        plan: Plan,
        sender: str = "supervisor",
    ) -> Any:
        """Execute one step; return response envelope or result."""
        ...
def execute_steps_parallel(
    execute_fn: Callable[[str, str, Plan, str], Any],
    task_id: str,
    plan: Plan,
    completed_step_ids: set[str],
    sender: str = "supervisor",
    max_workers: int | None = None,
) -> list[ParallelStepResult]:
    """
    Run every currently-ready step of the plan concurrently.

    Args:
        execute_fn: Function (task_id, step_id, plan, sender) -> response.
        task_id: Task identifier.
        plan: The plan containing steps.
        completed_step_ids: Steps already completed.
        sender: Sender identity for execute messages.
        max_workers: Max parallel workers (default: unbounded for ready steps).

    Returns:
        List of ParallelStepResult, one per step attempted (completion order).
    """
    pending = ready_steps(plan, completed_step_ids)
    if not pending:
        return []

    # A missing or non-positive limit means one worker per ready step.
    pool_size = max_workers if max_workers and max_workers > 0 else len(pending)

    def _run_single(step_id: str) -> ParallelStepResult:
        # Exceptions are converted into failed results so one crashing step
        # never takes down the rest of the batch.
        try:
            response = execute_fn(task_id, step_id, plan, sender)
            if response is None:
                return ParallelStepResult(step_id=step_id, success=False, error="No response")
            if not hasattr(response, "message"):
                # Plain return value, no envelope wrapper: treat as success.
                return ParallelStepResult(step_id=step_id, success=True, result=response)
            # Response is an AgentMessageEnvelope with intent step_done/step_failed.
            msg = response.message
            if msg.intent == "step_done":
                payload = msg.payload or {}
                return ParallelStepResult(
                    step_id=step_id,
                    success=True,
                    result=payload.get("result"),
                    envelope=response,
                )
            failure = msg.payload.get("error", "Unknown failure") if msg.payload else "Unknown"
            return ParallelStepResult(
                step_id=step_id,
                success=False,
                error=failure,
                envelope=response,
            )
        except Exception as exc:
            logger.exception("Parallel step execution failed", extra={"step_id": step_id})
            return ParallelStepResult(step_id=step_id, success=False, error=str(exc))

    results: list[ParallelStepResult] = []
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        futures = [pool.submit(_run_single, sid) for sid in pending]
        for done in as_completed(futures):
            results.append(done.result())
    logger.info(
        "Parallel step batch completed",
        extra={"task_id": task_id, "steps": pending, "results": len(results)},
    )
    return results
def execute_steps_parallel_wave(
    execute_fn: Callable[[str, str, Plan, str], Any],
    task_id: str,
    plan: Plan,
    sender: str = "supervisor",
    max_workers: int | None = None,
) -> list[ParallelStepResult]:
    """
    Drive the plan to completion in successive waves.

    Each wave dispatches every dependency-satisfied step in parallel via
    ``execute_steps_parallel``; steps that succeed unlock the next wave.
    Iteration stops when no steps remain ready, or immediately after the
    first failed step is recorded.

    Returns combined results from all waves.
    """
    finished: set[str] = set()
    collected: list[ParallelStepResult] = []
    while True:
        wave = execute_steps_parallel(
            execute_fn, task_id, plan, finished, sender, max_workers
        )
        if not wave:
            # Nothing ready: either the plan is done or it is blocked.
            return collected
        for outcome in wave:
            collected.append(outcome)
            if not outcome.success:
                # On failure, stop the wave (caller can retry or handle)
                logger.warning("Step failed in wave, stopping", extra={"step_id": outcome.step_id})
                return collected
            finished.add(outcome.step_id)