192 lines
6.9 KiB
Python
192 lines
6.9 KiB
Python
"""Supervisor agent: drives the orchestration loop with parallel dispatch.
|
|
|
|
Coordinates Planner -> Reasoner -> Executor flow. Supports:
|
|
- Parallel step execution (independent steps run concurrently)
|
|
- Pooled executor routing (load-balanced across N executors)
|
|
- Batch task processing (multiple tasks in flight)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Callable, TYPE_CHECKING
|
|
|
|
from fusionagi.agents.base_agent import BaseAgent
|
|
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
|
from fusionagi.schemas.plan import Plan
|
|
from fusionagi.planning import ready_steps, get_step
|
|
from fusionagi.multi_agent.parallel import execute_steps_parallel, execute_steps_parallel_wave
|
|
from fusionagi._logger import logger
|
|
|
|
if TYPE_CHECKING:
|
|
from fusionagi.core.orchestrator import Orchestrator
|
|
|
|
|
|
class SupervisorAgent(BaseAgent):
    """
    Supervisor: drives the plan-execute loop with multi-agent accelerations.

    Features:
    - Execution mode chosen by ``parallel_mode``: wave-based parallel dispatch
      via ``execute_steps_parallel_wave``, or a one-step-at-a-time sequential
      fallback driven by ``ready_steps``.
    - All communication goes through an Orchestrator (``route_message_return``).
    - ``executor_id`` may name a single executor or a pooled executor.
    """

    def __init__(
        self,
        identity: str = "supervisor",
        orchestrator: Orchestrator | None = None,
        planner_id: str = "planner",
        reasoner_id: str = "reasoner",
        executor_id: str = "executor",
        parallel_mode: bool = True,
        max_parallel_workers: int | None = None,
    ) -> None:
        """
        Args:
            identity: Supervisor agent id.
            orchestrator: Orchestrator for routing. Without it,
                ``handle_message`` answers every run request with ``run_failed``.
            planner_id: Registered planner agent id.
            reasoner_id: Registered reasoner agent id (stored; not used by
                the methods in this class).
            executor_id: Registered executor (or pool) agent id.
            parallel_mode: If True, execute steps in parallel waves;
                otherwise execute them sequentially.
            max_parallel_workers: Passed as ``max_workers`` to the wave
                executor; ignored in sequential mode. NOTE(review): None
                presumably lets the wave executor pick its default — confirm.
        """
        super().__init__(
            identity=identity,
            role="Supervisor",
            objective="Coordinate plan-execute loop with parallel dispatch",
            memory_access=True,
            tool_permissions=[],
        )
        self._orch = orchestrator
        self._planner_id = planner_id
        self._reasoner_id = reasoner_id
        self._executor_id = executor_id
        self._parallel_mode = parallel_mode
        self._max_workers = max_parallel_workers

    def _route(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Route ``envelope`` through the orchestrator and return its response.

        Returns None when no orchestrator is configured.
        """
        if not self._orch:
            return None
        return self._orch.route_message_return(envelope)

    def _execute_step(self, task_id: str, step_id: str, plan: Plan, sender: str) -> Any:
        """Send an ``execute_step`` request for ``step_id`` to the executor.

        The whole plan is serialized into the payload (``plan.to_dict()``) so
        the executor can look the step up itself.

        Returns:
            The executor's response envelope, or None without an orchestrator.
        """
        envelope = AgentMessageEnvelope(
            message=AgentMessage(
                sender=sender,
                recipient=self._executor_id,
                intent="execute_step",
                payload={"step_id": step_id, "plan": plan.to_dict()},
            ),
            task_id=task_id,
        )
        return self._route(envelope)

    def _request_plan(self, task_id: str, goal: str, constraints: list[str]) -> Plan | None:
        """Ask the planner for a plan for ``goal`` under ``constraints``.

        Returns:
            The deserialized plan, or None if there is no response, the
            response payload is empty, or it carries no ``"plan"`` entry.
        """
        envelope = AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=self._planner_id,
                intent="plan_request",
                payload={"goal": goal, "constraints": constraints},
            ),
            task_id=task_id,
        )
        resp = self._route(envelope)
        if not resp or not resp.message.payload:
            return None
        plan_dict = resp.message.payload.get("plan")
        if not plan_dict:
            return None
        return Plan.from_dict(plan_dict)

    def _resolve_plan(self, task_id: str, payload: dict[str, Any]) -> Plan | None:
        """Use the plan embedded in ``payload`` if present, else ask the planner."""
        plan_dict = payload.get("plan")
        if plan_dict:
            return Plan.from_dict(plan_dict)
        # ``or`` (not a .get default) so an explicit None in the payload is
        # normalized too; _request_plan expects a str and a list.
        goal = payload.get("goal") or ""
        constraints = payload.get("constraints") or []
        return self._request_plan(task_id, goal, constraints)

    def _run_parallel(
        self, envelope: AgentMessageEnvelope, task_id: str, plan: Plan
    ) -> AgentMessageEnvelope:
        """Execute ``plan`` in parallel waves and build the run response.

        Per-step success is decided by ``execute_steps_parallel_wave``.
        NOTE(review): unlike the sequential path, this path does not itself
        check the executor's reply intent for ``step_done`` — confirm both
        modes agree on what counts as a failed step.
        """
        results = execute_steps_parallel_wave(
            self._execute_step,
            task_id,
            plan,
            sender=self.identity,
            max_workers=self._max_workers,
        )
        successes = sum(1 for r in results if r.success)
        failures = [r for r in results if not r.success]
        if failures:
            return envelope.create_response(
                "run_failed",
                payload={
                    "error": f"Step(s) failed: {[f.step_id for f in failures]}",
                    "results": [
                        {"step_id": r.step_id, "success": r.success, "error": r.error}
                        for r in results
                    ],
                },
            )
        return envelope.create_response(
            "run_completed",
            payload={
                "steps_completed": successes,
                "results": [{"step_id": r.step_id, "result": r.result} for r in results],
            },
        )

    def _run_sequential(
        self, envelope: AgentMessageEnvelope, task_id: str, plan: Plan
    ) -> AgentMessageEnvelope:
        """Execute ``plan`` one ready step at a time; stop at the first failure.

        A step counts as done only when the executor replies with intent
        ``step_done``.
        """
        completed: set[str] = set()
        while True:
            # Drop already-completed ids so a ready_steps result that repeats
            # them cannot make this loop re-dispatch the same step forever.
            pending = [s for s in ready_steps(plan, completed) if s not in completed]
            if not pending:
                break
            step_id = pending[0]
            resp = self._execute_step(task_id, step_id, plan, self.identity)
            if not (resp and hasattr(resp, "message") and resp.message.intent == "step_done"):
                return envelope.create_response(
                    "run_failed",
                    payload={"error": f"Step {step_id} failed", "step_id": step_id},
                )
            completed.add(step_id)

        return envelope.create_response(
            "run_completed",
            payload={"steps_completed": len(completed)},
        )

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        Handle ``run_task`` / ``execute_plan``: obtain a plan, then execute its
        steps (parallel waves or sequentially, per ``parallel_mode``).

        The plan comes from the payload's ``"plan"`` entry when present;
        otherwise it is requested from the planner using ``"goal"`` and
        ``"constraints"``.

        Returns:
            None for any other intent; otherwise a ``run_completed`` or
            ``run_failed`` response envelope.
        """
        intent = envelope.message.intent
        if intent not in ("run_task", "execute_plan"):
            return None

        payload = envelope.message.payload or {}
        task_id = envelope.task_id or ""

        logger.info(
            "Supervisor handling run_task",
            extra={"task_id": task_id, "parallel": self._parallel_mode, "intent": intent},
        )

        if not self._orch:
            return envelope.create_response(
                "run_failed",
                payload={"error": "No orchestrator configured"},
            )

        plan = self._resolve_plan(task_id, payload)
        if not plan:
            return envelope.create_response(
                "run_failed",
                payload={"error": "Failed to get plan"},
            )

        if self._parallel_mode:
            return self._run_parallel(envelope, task_id, plan)
        return self._run_sequential(envelope, task_id, plan)
|