311 lines
12 KiB
Python
311 lines
12 KiB
Python
"""Orchestrator: task lifecycle, agent registry, wiring to event bus and state."""
|
|
|
|
import uuid
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Any, Callable, Protocol, runtime_checkable
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from fusionagi.schemas.task import Task, TaskState, TaskPriority, VALID_TASK_TRANSITIONS
|
|
from fusionagi.schemas.messages import AgentMessageEnvelope
|
|
|
|
from fusionagi.core.event_bus import EventBus
|
|
from fusionagi.core.state_manager import StateManager
|
|
from fusionagi._logger import logger
|
|
|
|
# Single source of truth: re-export from schemas for backward compatibility
|
|
VALID_STATE_TRANSITIONS = VALID_TASK_TRANSITIONS  # alias, not a copy: both names are bound to the same object
|
|
|
|
|
|
class InvalidStateTransitionError(Exception):
    """Raised when an invalid state transition is attempted.

    Attributes:
        task_id: Identifier of the task whose transition was rejected.
        from_state: State the task was in when the transition was requested.
        to_state: State that was requested.
    """

    def __init__(self, task_id: str, from_state: TaskState, to_state: TaskState) -> None:
        """Store the transition details and build the human-readable message.

        Args:
            task_id: Identifier of the task.
            from_state: Current state of the task; its ``.value`` is used in the message.
            to_state: Requested state; its ``.value`` is used in the message.
        """
        self.task_id = task_id
        self.from_state = from_state
        self.to_state = to_state
        super().__init__(
            f"Invalid state transition for task {task_id}: {from_state.value} -> {to_state.value}"
        )

    def __reduce__(self) -> tuple[Any, ...]:
        """Make the exception copyable and picklable.

        The default ``Exception.__reduce__`` replays ``self.args`` (only the
        formatted message) into ``__init__``, which requires three arguments,
        so ``copy.copy`` / ``pickle`` would fail with ``TypeError``. Rebuild
        from the original constructor arguments instead.
        """
        return (type(self), (self.task_id, self.from_state, self.to_state))
|
|
|
|
|
|
@runtime_checkable
class AgentProtocol(Protocol):
    """Structural interface for anything the orchestrator can deliver messages to.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, AgentProtocol)``
    can be used to test whether an object exposes these members.
    """

    # Identifier of the agent.
    identity: str

    def handle_message(
        self,
        envelope: AgentMessageEnvelope,
    ) -> AgentMessageEnvelope | None:
        """Process one incoming envelope.

        Args:
            envelope: The message envelope addressed to this agent.

        Returns:
            A response envelope, or ``None`` when the agent has nothing to reply.
        """
        ...
|
|
|
|
|
|
class TaskGraphEntry(BaseModel):
    """Plan-cache record kept per task (plan and related metadata)."""

    # None until a plan has been stored for the task.
    plan: dict[str, Any] | None = Field(
        default=None,
        description="Stored plan for the task",
    )
|
|
|
|
|
|
class Orchestrator:
    """
    Global task lifecycle and agent coordination.

    Holds the agent registry, the parent -> sub-agent index, a per-task plan cache,
    and the event bus and state manager it was constructed with.

    Task state lifecycle: submit_task creates PENDING. Callers/supervisors must call
    set_task_state to transition to ACTIVE, COMPLETED, FAILED, or CANCELLED. The
    orchestrator validates state transitions according to VALID_TASK_TRANSITIONS
    (exported from this module as VALID_STATE_TRANSITIONS as well).

    Valid transitions:
        PENDING -> ACTIVE, CANCELLED
        ACTIVE -> COMPLETED, FAILED, CANCELLED
        FAILED -> PENDING (retry), CANCELLED
        COMPLETED -> (terminal)
        CANCELLED -> (terminal)
    """

    def __init__(
        self,
        event_bus: EventBus,
        state_manager: StateManager,
        validate_transitions: bool = True,
    ) -> None:
        """
        Initialize the orchestrator.

        Args:
            event_bus: Event bus for publishing events.
            state_manager: State manager for task state.
            validate_transitions: If True, validate state transitions (default True).
        """
        self._event_bus = event_bus
        self._state = state_manager
        self._validate_transitions = validate_transitions
        self._agents: dict[str, AgentProtocol | Any] = {}  # agent_id -> agent instance
        self._sub_agents: dict[str, list[str]] = {}  # parent_id -> [child_id]
        self._task_plans: dict[str, TaskGraphEntry] = {}  # task_id -> plan/metadata per task
        # Dedicated pool for route_message_async; released by shutdown().
        self._async_executor: ThreadPoolExecutor = ThreadPoolExecutor(
            max_workers=8, thread_name_prefix="orch_async"
        )

    # ------------------------------------------------------------------
    # Agent registry
    # ------------------------------------------------------------------

    def register_agent(self, agent_id: str, agent: Any) -> None:
        """Register an agent by id for routing and assignment.

        An existing registration under the same id is replaced.
        """
        self._agents[agent_id] = agent
        logger.info("Agent registered", extra={"agent_id": agent_id})

    def unregister_agent(self, agent_id: str) -> None:
        """Remove an agent from the registry and from any parent's sub-agent list.

        The agent's own list of children (if it was a parent) is dropped as well;
        the child agents themselves stay registered. Unknown ids are ignored.
        """
        self._agents.pop(agent_id, None)
        self._sub_agents.pop(agent_id, None)
        for parent_id, children in list(self._sub_agents.items()):
            if agent_id in children:
                self._sub_agents[parent_id] = [c for c in children if c != agent_id]
        logger.info("Agent unregistered", extra={"agent_id": agent_id})

    def register_sub_agent(self, parent_id: str, child_id: str, agent: Any) -> None:
        """Register a sub-agent under a parent; child can be delegated sub-tasks.

        Registering the same child under the same parent again replaces the agent
        instance but does not add a second entry to the parent's child list.
        """
        self._agents[child_id] = agent
        children = self._sub_agents.setdefault(parent_id, [])
        if child_id not in children:
            children.append(child_id)
        logger.info("Sub-agent registered", extra={"parent_id": parent_id, "child_id": child_id})

    def get_sub_agents(self, parent_id: str) -> list[str]:
        """Return list of child agent ids for a parent (a copy; empty if none)."""
        return list(self._sub_agents.get(parent_id, []))

    def get_agent(self, agent_id: str) -> Any | None:
        """Return registered agent by id or None."""
        return self._agents.get(agent_id)

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the async executor used for route_message_async.

        Call when the orchestrator is no longer needed.

        Args:
            wait: Passed through to the executor's ``shutdown``.
        """
        self._async_executor.shutdown(wait=wait)
        logger.debug("Orchestrator async executor shut down", extra={"wait": wait})

    # ------------------------------------------------------------------
    # Task lifecycle
    # ------------------------------------------------------------------

    def submit_task(
        self,
        goal: str,
        constraints: list[str] | None = None,
        priority: TaskPriority = TaskPriority.NORMAL,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Create a task and publish task_created; returns task_id.

        The task is stored in the state manager in the PENDING state under a
        freshly generated UUID4 string, and an empty plan entry is created for it.

        Args:
            goal: Goal text of the task.
            constraints: Optional constraints; None is stored as an empty list.
            priority: Task priority.
            metadata: Optional metadata; None is stored as an empty dict.

        Returns:
            The id of the new task.
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            goal=goal,
            constraints=constraints or [],
            priority=priority,
            state=TaskState.PENDING,
            metadata=metadata or {},
        )
        self._state.set_task(task)
        self._task_plans[task_id] = TaskGraphEntry()
        logger.info(
            "Task created",
            # Only the first 200 characters of the goal go into the log record.
            extra={"task_id": task_id, "goal": goal[:200] if goal else ""},
        )
        self._event_bus.publish(
            "task_created",
            {"task_id": task_id, "goal": goal, "constraints": task.constraints},
        )
        return task_id

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current state of a task or None if unknown."""
        return self._state.get_task_state(task_id)

    def get_task(self, task_id: str) -> Task | None:
        """Return full task or None."""
        return self._state.get_task(task_id)

    def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None:
        """Store plan in task plans for a task.

        Does nothing when the task has no plan entry (entries are created by
        submit_task).
        """
        entry = self._task_plans.get(task_id)
        if entry is not None:
            entry.plan = plan

    def get_task_plan(self, task_id: str) -> dict[str, Any] | None:
        """Return stored plan for a task or None."""
        entry = self._task_plans.get(task_id)
        return entry.plan if entry else None

    @staticmethod
    def _transition_allowed(current: TaskState, target: TaskState) -> bool:
        """Return True if ``current -> target`` is permitted.

        Staying in the same state always counts as permitted.
        """
        if target == current:
            return True
        return target in VALID_TASK_TRANSITIONS.get(current, set())

    def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None:
        """
        Update task state with transition validation.

        The new state is written to the state manager and a task_state_changed
        event is published, including when the new state equals the current one.

        Args:
            task_id: The task identifier.
            state: The new state to transition to.
            force: If True, skip transition validation (use with caution).

        Raises:
            InvalidStateTransitionError: If the transition is not allowed and force=False.
            ValueError: If task_id is unknown.
        """
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            raise ValueError(f"Unknown task: {task_id}")

        # Validation is skipped both per call (force) and per instance
        # (validate_transitions=False at construction).
        must_validate = self._validate_transitions and not force
        if must_validate and not self._transition_allowed(current_state, state):
            raise InvalidStateTransitionError(task_id, current_state, state)

        self._state.set_task_state(task_id, state)
        logger.debug(
            "Task state set",
            extra={
                "task_id": task_id,
                "from_state": current_state.value,
                "to_state": state.value,
            },
        )
        self._event_bus.publish(
            "task_state_changed",
            {"task_id": task_id, "from_state": current_state.value, "to_state": state.value},
        )

    def can_transition(self, task_id: str, state: TaskState) -> bool:
        """Check if a state transition is valid without performing it.

        Returns False for an unknown task. The instance's validate_transitions
        setting is not consulted.
        """
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            return False
        return self._transition_allowed(current_state, state)

    # ------------------------------------------------------------------
    # Message routing
    # ------------------------------------------------------------------

    def _deliver(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Log, publish message_received, then hand the envelope to its recipient.

        The event is published whether or not a recipient is registered. When no
        agent is registered under the recipient id, or the registered object has
        no ``handle_message`` attribute, nothing is delivered and None is returned.
        """
        recipient = envelope.message.recipient
        intent = envelope.message.intent
        logger.info(
            "Message routed",
            extra={"task_id": envelope.task_id or "", "recipient": recipient, "intent": intent},
        )
        agent = self._agents.get(recipient)
        self._event_bus.publish(
            "message_received",
            {
                "task_id": envelope.task_id,
                "recipient": recipient,
                "intent": intent,
            },
        )
        if agent is not None and hasattr(agent, "handle_message"):
            return agent.handle_message(envelope)
        return None

    def route_message(self, envelope: AgentMessageEnvelope) -> None:
        """
        Deliver an envelope to the recipient agent and publish message_received.

        Does not route the agent's response; use route_message_return to get and
        optionally re-route the response envelope.
        """
        self._deliver(envelope)

    def route_message_return(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        Deliver an envelope to the recipient agent and return the response envelope, if any.

        Use this when the caller needs to handle or re-route the agent's response.
        """
        return self._deliver(envelope)

    def route_messages_batch(
        self,
        envelopes: list[AgentMessageEnvelope],
    ) -> list[AgentMessageEnvelope | None]:
        """
        Route multiple messages; return responses in same order.

        Uses concurrent execution for parallel dispatch (at most 32 worker threads).
        An empty input returns an empty list. If any delivery raises, the exception
        propagates after all submitted deliveries have finished.
        """
        # ThreadPoolExecutor rejects max_workers=0, so an empty batch must not
        # reach the pool at all.
        if not envelopes:
            return []

        with ThreadPoolExecutor(max_workers=min(len(envelopes), 32)) as pool:
            pending = [pool.submit(self.route_message_return, env) for env in envelopes]
            # Collecting in submission order keeps results aligned with the input.
            return [fut.result() for fut in pending]

    def route_message_async(
        self,
        envelope: AgentMessageEnvelope,
        callback: Callable[[AgentMessageEnvelope | None], None] | None = None,
    ) -> Any:
        """
        Route message in background; optionally invoke callback with response.

        Args:
            envelope: The envelope to route via route_message_return.
            callback: Optional callable invoked with the response (possibly None)
                once routing has finished. It is not invoked if routing raised;
                that exception is logged and remains available from the future.
                Exceptions raised by the callback itself are logged and swallowed.

        Returns:
            Future for non-blocking await.
        """
        from concurrent.futures import Future

        future = self._async_executor.submit(self.route_message_return, envelope)
        if callback is None:
            return future

        def _on_done(fut: Future) -> None:
            # Keep routing failures and callback failures apart so the log
            # does not blame the callback for an error it never saw.
            try:
                response = fut.result()
            except Exception:
                logger.exception("Async route failed; callback not invoked")
                return
            try:
                callback(response)
            except Exception:
                logger.exception("Async route callback failed")

        future.add_done_callback(_on_done)
        return future
|