"""Supervisor agent: drives the orchestration loop with parallel dispatch.

Coordinates Planner -> Reasoner -> Executor flow. Supports:
- Parallel step execution (independent steps run concurrently)
- Pooled executor routing (load-balanced across N executors)
- Batch task processing (multiple tasks in flight)
"""

from __future__ import annotations

from typing import Any, Callable, TYPE_CHECKING

from fusionagi.agents.base_agent import BaseAgent
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
from fusionagi.schemas.plan import Plan
from fusionagi.planning import ready_steps, get_step
from fusionagi.multi_agent.parallel import execute_steps_parallel, execute_steps_parallel_wave
from fusionagi._logger import logger

if TYPE_CHECKING:
    from fusionagi.core.orchestrator import Orchestrator


class SupervisorAgent(BaseAgent):
    """
    Supervisor: drives plan-execute loop with multi-agent accelerations.

    Features:
    - Parallel step execution via ``execute_steps_parallel_wave``
      (``parallel_mode=True``), or a sequential fallback (``parallel_mode=False``)
    - Integration with Orchestrator for message routing
    - Executor addressed by a single agent id, which may be a pooled executor
      registered under that id
    """

    def __init__(
        self,
        identity: str = "supervisor",
        orchestrator: Orchestrator | None = None,
        planner_id: str = "planner",
        reasoner_id: str = "reasoner",
        executor_id: str = "executor",
        parallel_mode: bool = True,
        max_parallel_workers: int | None = None,
    ) -> None:
        """
        Args:
            identity: Supervisor agent id.
            orchestrator: Orchestrator for routing (required for full loop).
            planner_id: Registered planner agent id.
            reasoner_id: Registered reasoner agent id.
            executor_id: Registered executor (or pool) agent id.
            parallel_mode: If True, use wave-based parallel step execution;
                otherwise execute ready steps one at a time.
            max_parallel_workers: Cap on concurrent step executions (passed
                through to the parallel executor; unused in sequential mode).
        """
        super().__init__(
            identity=identity,
            role="Supervisor",
            objective="Coordinate plan-execute loop with parallel dispatch",
            memory_access=True,
            tool_permissions=[],
        )
        self._orch = orchestrator
        self._planner_id = planner_id
        # Stored for configuration; not referenced by the methods of this class.
        self._reasoner_id = reasoner_id
        self._executor_id = executor_id
        self._parallel_mode = parallel_mode
        self._max_workers = max_parallel_workers

    def _route(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Route ``envelope`` via the orchestrator and return its response.

        Returns None when no orchestrator is configured.
        """
        if not self._orch:
            return None
        return self._orch.route_message_return(envelope)

    def _execute_step(self, task_id: str, step_id: str, plan: Plan, sender: str) -> Any:
        """Send an ``execute_step`` request for one step to the executor.

        The whole plan is serialized into the payload alongside ``step_id``.
        Returns the executor's response envelope, or None if routing is
        unavailable. Also used as the per-step callback for the parallel
        executor.
        """
        envelope = AgentMessageEnvelope(
            message=AgentMessage(
                sender=sender,
                recipient=self._executor_id,
                intent="execute_step",
                payload={"step_id": step_id, "plan": plan.to_dict()},
            ),
            task_id=task_id,
        )
        return self._route(envelope)

    def _request_plan(self, task_id: str, goal: str, constraints: list[str]) -> Plan | None:
        """Ask the planner for a plan for ``goal`` under ``constraints``.

        Returns None when there is no response, the response has an empty
        payload, or the payload carries no ``"plan"`` entry.
        """
        envelope = AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=self._planner_id,
                intent="plan_request",
                payload={"goal": goal, "constraints": constraints},
            ),
            task_id=task_id,
        )
        resp = self._route(envelope)
        if not resp or not resp.message.payload:
            return None
        plan_dict = resp.message.payload.get("plan")
        if not plan_dict:
            return None
        return Plan.from_dict(plan_dict)

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        On run_task or execute_plan: get plan, execute steps (parallel or sequential).

        Messages with any other intent are ignored (returns None). The plan is
        taken from ``payload["plan"]`` when present; otherwise it is requested
        from the planner using ``payload["goal"]`` and ``payload["constraints"]``.

        Returns:
            A ``run_completed`` response on success, or a ``run_failed``
            response whose payload carries an ``"error"`` description.
        """
        intent = envelope.message.intent
        if intent not in ("run_task", "execute_plan"):
            return None

        payload = envelope.message.payload or {}
        task_id = envelope.task_id or ""
        goal = payload.get("goal", "")
        # A sender may pass ``"constraints": None`` explicitly; normalize so the
        # planner always receives a list as _request_plan's signature promises.
        constraints = payload.get("constraints") or []
        plan_dict = payload.get("plan")

        logger.info(
            "Supervisor handling run_task",
            extra={"task_id": task_id, "intent": intent, "parallel": self._parallel_mode},
        )

        if not self._orch:
            return envelope.create_response(
                "run_failed",
                payload={"error": "No orchestrator configured"},
            )

        # A caller-supplied plan takes precedence over asking the planner.
        if plan_dict:
            plan = Plan.from_dict(plan_dict)
        else:
            plan = self._request_plan(task_id, goal, constraints)
        if not plan:
            return envelope.create_response(
                "run_failed",
                payload={"error": "Failed to get plan"},
            )

        if self._parallel_mode:
            return self._run_parallel(envelope, task_id, plan)
        return self._run_sequential(envelope, task_id, plan)

    def _run_parallel(
        self, envelope: AgentMessageEnvelope, task_id: str, plan: Plan
    ) -> AgentMessageEnvelope | None:
        """Execute ``plan`` with the wave-based parallel executor and build the reply.

        Any failed step makes the whole run ``run_failed``; the reply then lists
        every step's success flag and error.
        """
        results = execute_steps_parallel_wave(
            self._execute_step,
            task_id,
            plan,
            sender=self.identity,
            max_workers=self._max_workers,
        )
        successes = sum(1 for r in results if r.success)
        failures = [r for r in results if not r.success]
        if failures:
            return envelope.create_response(
                "run_failed",
                payload={
                    "error": f"Step(s) failed: {[f.step_id for f in failures]}",
                    "results": [
                        {"step_id": r.step_id, "success": r.success, "error": r.error}
                        for r in results
                    ],
                },
            )
        return envelope.create_response(
            "run_completed",
            payload={
                "steps_completed": successes,
                "results": [{"step_id": r.step_id, "result": r.result} for r in results],
            },
        )

    def _run_sequential(
        self, envelope: AgentMessageEnvelope, task_id: str, plan: Plan
    ) -> AgentMessageEnvelope | None:
        """Execute ready steps one at a time until none remain ready.

        A step counts as done only when the executor replies with intent
        ``step_done``; the first step that does not fails the whole run.
        """
        completed: set[str] = set()
        # NOTE(review): the loop ends as soon as ready_steps returns nothing, so a
        # plan whose remaining steps can never become ready still reports
        # run_completed (with a smaller steps_completed) — confirm this is intended.
        while ready := ready_steps(plan, completed):
            step_id = ready[0]
            resp = self._execute_step(task_id, step_id, plan, self.identity)
            if resp and hasattr(resp, "message") and resp.message.intent == "step_done":
                completed.add(step_id)
            else:
                return envelope.create_response(
                    "run_failed",
                    payload={"error": f"Step {step_id} failed", "step_id": step_id},
                )
        return envelope.create_response(
            "run_completed",
            payload={"steps_completed": len(completed)},
        )