Files
FusionAGI/fusionagi/core/orchestrator.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

311 lines
12 KiB
Python

"""Orchestrator: task lifecycle, agent registry, wiring to event bus and state."""
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from fusionagi.schemas.task import Task, TaskState, TaskPriority, VALID_TASK_TRANSITIONS
from fusionagi.schemas.messages import AgentMessageEnvelope
from fusionagi.core.event_bus import EventBus
from fusionagi.core.state_manager import StateManager
from fusionagi._logger import logger
# Single source of truth: the transition table is defined in fusionagi.schemas.task
# as VALID_TASK_TRANSITIONS. It is re-exported from this module under the name
# VALID_STATE_TRANSITIONS for backward compatibility; both names are bound to the
# same object, so no copy is made.
VALID_STATE_TRANSITIONS = VALID_TASK_TRANSITIONS
class InvalidStateTransitionError(Exception):
    """Error signalling that a requested task state change is not permitted.

    The task id and both states involved are stored on the instance as
    ``task_id``, ``from_state`` and ``to_state`` so handlers can inspect them.
    """

    def __init__(self, task_id: str, from_state: TaskState, to_state: TaskState) -> None:
        """Build the error message from the rejected transition and record its parts."""
        message = f"Invalid state transition for task {task_id}: {from_state.value} -> {to_state.value}"
        super().__init__(message)
        self.task_id = task_id
        self.from_state = from_state
        self.to_state = to_state
@runtime_checkable
class AgentProtocol(Protocol):
    """Structural interface for objects that can receive message envelopes.

    Because the protocol is runtime-checkable, ``isinstance(obj, AgentProtocol)``
    succeeds for any object that exposes the members declared below; no
    inheritance is required.
    """

    # Identifier attribute the agent exposes.
    identity: str

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Process ``envelope`` and return a reply envelope, or None for no reply."""
        ...
class TaskGraphEntry(BaseModel):
    """Plan-cache record kept per task: holds the task's plan and metadata."""

    # None until a plan has been stored for the task.
    plan: dict[str, Any] | None = Field(
        default=None,
        description="Stored plan for the task",
    )
class Orchestrator:
    """
    Global task lifecycle and agent coordination.

    Holds the per-task plan cache, the agent registry (including parent ->
    sub-agent links) and references to the event bus and state manager.

    Task state lifecycle: ``submit_task`` creates a task in PENDING. Callers or
    supervisors must call ``set_task_state`` to transition it to ACTIVE,
    COMPLETED, FAILED, or CANCELLED. The orchestrator validates transitions
    against the shared transition table (``VALID_STATE_TRANSITIONS``).

    Valid transitions (as defined by that table; repeated here for convenience):
        PENDING -> ACTIVE, CANCELLED
        ACTIVE -> COMPLETED, FAILED, CANCELLED
        FAILED -> PENDING (retry), CANCELLED
        COMPLETED -> (terminal)
        CANCELLED -> (terminal)
    """

    def __init__(
        self,
        event_bus: EventBus,
        state_manager: StateManager,
        validate_transitions: bool = True,
    ) -> None:
        """
        Initialize the orchestrator.

        Args:
            event_bus: Event bus for publishing events.
            state_manager: State manager for task state.
            validate_transitions: If True, validate state transitions (default True).
        """
        self._event_bus = event_bus
        self._state = state_manager
        self._validate_transitions = validate_transitions
        self._agents: dict[str, AgentProtocol | Any] = {}  # agent_id -> agent instance
        self._sub_agents: dict[str, list[str]] = {}  # parent_id -> [child_id]
        self._task_plans: dict[str, TaskGraphEntry] = {}  # task_id -> plan/metadata per task
        # Long-lived pool backing route_message_async; released by shutdown().
        self._async_executor: ThreadPoolExecutor = ThreadPoolExecutor(
            max_workers=8, thread_name_prefix="orch_async"
        )

    # ------------------------------------------------------------------
    # Agent registry
    # ------------------------------------------------------------------

    def register_agent(self, agent_id: str, agent: Any) -> None:
        """Register an agent by id for routing and assignment (replaces any existing entry)."""
        self._agents[agent_id] = agent
        logger.info("Agent registered", extra={"agent_id": agent_id})

    def unregister_agent(self, agent_id: str) -> None:
        """Remove an agent from the registry and from any parent's sub-agent list.

        The agent's own child list (if it was a parent) is dropped as well; the
        children themselves stay registered. Unknown ids are ignored.
        """
        self._agents.pop(agent_id, None)
        self._sub_agents.pop(agent_id, None)
        # Iterate over a snapshot because entries are reassigned inside the loop.
        for parent_id, children in list(self._sub_agents.items()):
            if agent_id in children:
                self._sub_agents[parent_id] = [c for c in children if c != agent_id]
        logger.info("Agent unregistered", extra={"agent_id": agent_id})

    def register_sub_agent(self, parent_id: str, child_id: str, agent: Any) -> None:
        """Register a sub-agent under a parent; child can be delegated sub-tasks.

        Registering the same child under the same parent again replaces the agent
        instance but does not add a second entry to the parent's child list.
        """
        self._agents[child_id] = agent
        children = self._sub_agents.setdefault(parent_id, [])
        # Guard against duplicates so get_sub_agents never lists a child twice.
        if child_id not in children:
            children.append(child_id)
        logger.info("Sub-agent registered", extra={"parent_id": parent_id, "child_id": child_id})

    def get_sub_agents(self, parent_id: str) -> list[str]:
        """Return a copy of the list of child agent ids for a parent (empty if none)."""
        return list(self._sub_agents.get(parent_id, []))

    def get_agent(self, agent_id: str) -> Any | None:
        """Return registered agent by id or None."""
        return self._agents.get(agent_id)

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the async executor used for route_message_async.

        Call when the orchestrator is no longer needed.

        Args:
            wait: Passed through to the executor's ``shutdown``.
        """
        self._async_executor.shutdown(wait=wait)
        logger.debug("Orchestrator async executor shut down", extra={"wait": wait})

    # ------------------------------------------------------------------
    # Task lifecycle
    # ------------------------------------------------------------------

    def submit_task(
        self,
        goal: str,
        constraints: list[str] | None = None,
        priority: TaskPriority = TaskPriority.NORMAL,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Create a PENDING task, store it, and publish ``task_created``.

        Args:
            goal: Goal text for the task.
            constraints: Optional constraints; None is stored as an empty list.
            priority: Task priority.
            metadata: Optional metadata; None is stored as an empty dict.

        Returns:
            The generated task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            goal=goal,
            constraints=constraints or [],
            priority=priority,
            state=TaskState.PENDING,
            metadata=metadata or {},
        )
        self._state.set_task(task)
        self._task_plans[task_id] = TaskGraphEntry()
        logger.info(
            "Task created",
            # Only the first 200 characters of the goal are logged.
            extra={"task_id": task_id, "goal": goal[:200] if goal else ""},
        )
        self._event_bus.publish(
            "task_created",
            {"task_id": task_id, "goal": goal, "constraints": task.constraints},
        )
        return task_id

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current state of a task or None if unknown."""
        return self._state.get_task_state(task_id)

    def get_task(self, task_id: str) -> Task | None:
        """Return full task or None."""
        return self._state.get_task(task_id)

    def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None:
        """Store ``plan`` for a task; ids without a plan entry are silently ignored."""
        if task_id in self._task_plans:
            self._task_plans[task_id].plan = plan

    def get_task_plan(self, task_id: str) -> dict[str, Any] | None:
        """Return stored plan for a task or None."""
        entry = self._task_plans.get(task_id)
        return entry.plan if entry else None

    def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None:
        """
        Update task state with transition validation and publish ``task_state_changed``.

        Setting a task to the state it is already in is accepted; the state is
        written and the event is published in that case too.

        Args:
            task_id: The task identifier.
            state: The new state to transition to.
            force: If True, skip transition validation (use with caution).

        Raises:
            InvalidStateTransitionError: If the transition is not allowed and force=False.
            ValueError: If task_id is unknown.
        """
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            raise ValueError(f"Unknown task: {task_id}")
        if not force and self._validate_transitions:
            allowed = VALID_TASK_TRANSITIONS.get(current_state, set())
            if state not in allowed and state != current_state:
                raise InvalidStateTransitionError(task_id, current_state, state)
        self._state.set_task_state(task_id, state)
        logger.debug(
            "Task state set",
            extra={
                "task_id": task_id,
                "from_state": current_state.value,
                "to_state": state.value,
            },
        )
        self._event_bus.publish(
            "task_state_changed",
            {"task_id": task_id, "from_state": current_state.value, "to_state": state.value},
        )

    def can_transition(self, task_id: str, state: TaskState) -> bool:
        """Check if a state transition is valid without performing it.

        Returns False for unknown tasks and True when ``state`` equals the
        current state.
        """
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            return False
        if state == current_state:
            return True
        allowed = VALID_TASK_TRANSITIONS.get(current_state, set())
        return state in allowed

    # ------------------------------------------------------------------
    # Message routing
    # ------------------------------------------------------------------

    def _dispatch(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Log, publish ``message_received``, and hand the envelope to its recipient.

        Shared implementation behind route_message and route_message_return.
        The event is published even when no agent is registered for the recipient.

        Returns:
            Whatever the recipient's ``handle_message`` returns, or None when the
            recipient is not registered or has no ``handle_message`` attribute.
        """
        recipient = envelope.message.recipient
        intent = envelope.message.intent
        task_id = envelope.task_id or ""
        logger.info(
            "Message routed",
            extra={"task_id": task_id, "recipient": recipient, "intent": intent},
        )
        agent = self._agents.get(recipient)
        self._event_bus.publish(
            "message_received",
            {
                "task_id": envelope.task_id,
                "recipient": recipient,
                "intent": intent,
            },
        )
        if agent is not None and hasattr(agent, "handle_message"):
            return agent.handle_message(envelope)
        return None

    def route_message(self, envelope: AgentMessageEnvelope) -> None:
        """
        Deliver an envelope to the recipient agent and publish message_received.

        Does not route the agent's response; use route_message_return to get and
        optionally re-route the response envelope.
        """
        self._dispatch(envelope)

    def route_message_return(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        Deliver an envelope to the recipient agent and return the response envelope, if any.

        Use this when the caller needs to handle or re-route the agent's response.
        """
        return self._dispatch(envelope)

    def route_messages_batch(
        self,
        envelopes: list[AgentMessageEnvelope],
    ) -> list[AgentMessageEnvelope | None]:
        """
        Route multiple messages; return responses in same order.

        Envelopes are dispatched concurrently on a temporary thread pool of at
        most 32 workers. An empty input returns an empty list without creating
        a pool. If any dispatch raises, the exception propagates to the caller
        after the pool has finished its remaining work.
        """
        # ThreadPoolExecutor rejects max_workers=0, so handle the empty batch here.
        if not envelopes:
            return []
        with ThreadPoolExecutor(max_workers=min(len(envelopes), 32)) as pool:
            # Executor.map yields results in input order regardless of completion order.
            return list(pool.map(self.route_message_return, envelopes))

    def route_message_async(
        self,
        envelope: AgentMessageEnvelope,
        callback: Callable[[AgentMessageEnvelope | None], None] | None = None,
    ) -> Any:
        """
        Route message in background; optionally invoke callback with response.

        Args:
            envelope: The envelope to deliver.
            callback: Optional callable invoked with the response once routing
                finishes. Any exception raised while obtaining the result or
                running the callback is logged and not re-raised.

        Returns:
            The ``concurrent.futures.Future`` for the background dispatch.
        """
        from concurrent import futures

        def run() -> AgentMessageEnvelope | None:
            return self.route_message_return(envelope)

        future = self._async_executor.submit(run)
        if callback:

            def done(f: futures.Future) -> None:
                try:
                    callback(f.result())
                except Exception:
                    # Done-callbacks must not raise; log and swallow instead.
                    logger.exception("Async route callback failed")

            future.add_done_callback(done)
        return future