"""Parallel step execution: run independent plan steps concurrently. Multi-agent acceleration: steps with satisfied dependencies and no mutual dependencies are dispatched in parallel to maximize throughput. """ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from typing import Any, Callable, Protocol from fusionagi.schemas.plan import Plan from fusionagi.planning import ready_steps, get_step from fusionagi._logger import logger @dataclass class ParallelStepResult: """Result of a single step execution in parallel batch.""" step_id: str success: bool result: Any = None error: str | None = None envelope: Any = None # AgentMessageEnvelope from executor @dataclass class ExecuteStepsCallback(Protocol): """Protocol for executing a single step (e.g. via orchestrator).""" def __call__( self, task_id: str, step_id: str, plan: Plan, sender: str = "supervisor", ) -> Any: """Execute one step; return response envelope or result.""" ... def execute_steps_parallel( execute_fn: Callable[[str, str, Plan, str], Any], task_id: str, plan: Plan, completed_step_ids: set[str], sender: str = "supervisor", max_workers: int | None = None, ) -> list[ParallelStepResult]: """ Execute all ready steps in parallel. Args: execute_fn: Function (task_id, step_id, plan, sender) -> response. task_id: Task identifier. plan: The plan containing steps. completed_step_ids: Steps already completed. sender: Sender identity for execute messages. max_workers: Max parallel workers (default: unbounded for ready steps). Returns: List of ParallelStepResult, one per step attempted. """ ready = ready_steps(plan, completed_step_ids) if not ready: return [] results: list[ParallelStepResult] = [] workers = max_workers if max_workers and max_workers > 0 else len(ready) def run_one(step_id: str) -> ParallelStepResult: try: response = execute_fn(task_id, step_id, plan, sender) if response is None: return ParallelStepResult(step_id=step_id, success=False, error="No response") # Response may be AgentMessageEnvelope with intent step_done/step_failed if hasattr(response, "message"): msg = response.message if msg.intent == "step_done": payload = msg.payload or {} return ParallelStepResult( step_id=step_id, success=True, result=payload.get("result"), envelope=response, ) return ParallelStepResult( step_id=step_id, success=False, error=msg.payload.get("error", "Unknown failure") if msg.payload else "Unknown", envelope=response, ) return ParallelStepResult(step_id=step_id, success=True, result=response) except Exception as e: logger.exception("Parallel step execution failed", extra={"step_id": step_id}) return ParallelStepResult(step_id=step_id, success=False, error=str(e)) with ThreadPoolExecutor(max_workers=workers) as executor: future_to_step = {executor.submit(run_one, sid): sid for sid in ready} for future in as_completed(future_to_step): results.append(future.result()) logger.info( "Parallel step batch completed", extra={"task_id": task_id, "steps": ready, "results": len(results)}, ) return results def execute_steps_parallel_wave( execute_fn: Callable[[str, str, Plan, str], Any], task_id: str, plan: Plan, sender: str = "supervisor", max_workers: int | None = None, ) -> list[ParallelStepResult]: """ Execute plan in waves: each wave runs all ready steps in parallel, then advances to the next wave when deps are satisfied. Returns combined results from all waves. """ completed: set[str] = set() all_results: list[ParallelStepResult] = [] while True: batch = execute_steps_parallel( execute_fn, task_id, plan, completed, sender, max_workers ) if not batch: break for r in batch: all_results.append(r) if r.success: completed.add(r.step_id) else: # On failure, stop the wave (caller can retry or handle) logger.warning("Step failed in wave, stopping", extra={"step_id": r.step_id}) return all_results return all_results