"""Orchestrator: task lifecycle, agent registry, wiring to event bus and state."""

import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from fusionagi.schemas.task import Task, TaskState, TaskPriority, VALID_TASK_TRANSITIONS
from fusionagi.schemas.messages import AgentMessageEnvelope
from fusionagi.core.event_bus import EventBus
from fusionagi.core.state_manager import StateManager
from fusionagi._logger import logger

# Single source of truth: re-export from schemas for backward compatibility
VALID_STATE_TRANSITIONS = VALID_TASK_TRANSITIONS

# Upper bound on the number of worker threads route_messages_batch will spawn.
_MAX_BATCH_WORKERS = 32


class InvalidStateTransitionError(Exception):
    """Raised when an invalid state transition is attempted."""

    def __init__(self, task_id: str, from_state: TaskState, to_state: TaskState) -> None:
        self.task_id = task_id
        self.from_state = from_state
        self.to_state = to_state
        super().__init__(
            f"Invalid state transition for task {task_id}: {from_state.value} -> {to_state.value}"
        )


@runtime_checkable
class AgentProtocol(Protocol):
    """Protocol for agents that can handle messages."""

    identity: str

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Handle an incoming message and optionally return a response."""
        ...


class TaskGraphEntry(BaseModel):
    """Per-task plan/metadata storage (plan cache)."""

    plan: dict[str, Any] | None = Field(default=None, description="Stored plan for the task")


class Orchestrator:
    """
    Global task lifecycle and agent coordination; holds task plans, event bus, state, agent registry.

    Task state lifecycle: submit_task creates PENDING. Callers/supervisors must call
    set_task_state to transition to ACTIVE, COMPLETED, FAILED, or CANCELLED.
    The orchestrator validates state transitions according to VALID_STATE_TRANSITIONS
    (an alias of VALID_TASK_TRANSITIONS from fusionagi.schemas.task, which is authoritative).

    Valid transitions (as documented; the schema table wins if they ever differ):
        PENDING -> ACTIVE, CANCELLED
        ACTIVE -> COMPLETED, FAILED, CANCELLED
        FAILED -> PENDING (retry), CANCELLED
        COMPLETED -> (terminal)
        CANCELLED -> (terminal)

    Call shutdown() when the orchestrator is no longer needed to release the
    background executor used by route_message_async.
    """

    def __init__(
        self,
        event_bus: EventBus,
        state_manager: StateManager,
        validate_transitions: bool = True,
    ) -> None:
        """
        Initialize the orchestrator.

        Args:
            event_bus: Event bus for publishing events.
            state_manager: State manager for task state.
            validate_transitions: If True, validate state transitions (default True).
        """
        self._event_bus = event_bus
        self._state = state_manager
        self._validate_transitions = validate_transitions
        self._agents: dict[str, AgentProtocol | Any] = {}  # agent_id -> agent instance
        self._sub_agents: dict[str, list[str]] = {}  # parent_id -> [child_id]
        self._task_plans: dict[str, TaskGraphEntry] = {}  # task_id -> plan/metadata per task
        self._async_executor: ThreadPoolExecutor = ThreadPoolExecutor(
            max_workers=8, thread_name_prefix="orch_async"
        )

    # ------------------------------------------------------------------ agents

    def register_agent(self, agent_id: str, agent: Any) -> None:
        """Register an agent by id for routing and assignment (replaces any existing entry)."""
        self._agents[agent_id] = agent
        logger.info("Agent registered", extra={"agent_id": agent_id})

    def unregister_agent(self, agent_id: str) -> None:
        """
        Remove an agent from the registry and from any parent's sub-agent list.

        The agent's own parent -> children mapping is dropped, but its children
        stay registered as agents in their own right.
        """
        self._agents.pop(agent_id, None)
        self._sub_agents.pop(agent_id, None)
        for parent_id, children in list(self._sub_agents.items()):
            if agent_id in children:
                self._sub_agents[parent_id] = [c for c in children if c != agent_id]
        logger.info("Agent unregistered", extra={"agent_id": agent_id})

    def register_sub_agent(self, parent_id: str, child_id: str, agent: Any) -> None:
        """
        Register a sub-agent under a parent; child can be delegated sub-tasks.

        The child is also registered as a routable agent. Registering the same
        child twice under one parent does not list it twice.
        """
        self._agents[child_id] = agent
        children = self._sub_agents.setdefault(parent_id, [])
        if child_id not in children:
            children.append(child_id)
        logger.info("Sub-agent registered", extra={"parent_id": parent_id, "child_id": child_id})

    def get_sub_agents(self, parent_id: str) -> list[str]:
        """Return a copy of the list of child agent ids for a parent (empty if none)."""
        return list(self._sub_agents.get(parent_id, []))

    def get_agent(self, agent_id: str) -> Any | None:
        """Return registered agent by id or None."""
        return self._agents.get(agent_id)

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the async executor used for route_message_async. Call when the orchestrator is no longer needed."""
        self._async_executor.shutdown(wait=wait)
        logger.debug("Orchestrator async executor shut down", extra={"wait": wait})

    # ------------------------------------------------------------------- tasks

    def submit_task(
        self,
        goal: str,
        constraints: list[str] | None = None,
        priority: TaskPriority = TaskPriority.NORMAL,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Create a PENDING task, store it, and publish ``task_created``.

        Args:
            goal: Free-text goal of the task.
            constraints: Optional list of constraints (defaults to an empty list).
            priority: Task priority (default NORMAL).
            metadata: Optional metadata dict (defaults to an empty dict).

        Returns:
            The newly generated task_id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            goal=goal,
            constraints=constraints or [],
            priority=priority,
            state=TaskState.PENDING,
            metadata=metadata or {},
        )
        self._state.set_task(task)
        self._task_plans[task_id] = TaskGraphEntry()
        logger.info(
            "Task created",
            # Truncate to keep log records bounded for very long goals.
            extra={"task_id": task_id, "goal": goal[:200] if goal else ""},
        )
        self._event_bus.publish(
            "task_created",
            {"task_id": task_id, "goal": goal, "constraints": task.constraints},
        )
        return task_id

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current state of a task or None if unknown."""
        return self._state.get_task_state(task_id)

    def get_task(self, task_id: str) -> Task | None:
        """Return full task or None."""
        return self._state.get_task(task_id)

    def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None:
        """
        Store a plan for a task.

        Only tasks created through this orchestrator's submit_task have a plan
        slot; for any other task_id the call is ignored (with a warning log).
        """
        entry = self._task_plans.get(task_id)
        if entry is None:
            logger.warning("set_task_plan ignored for unknown task", extra={"task_id": task_id})
            return
        entry.plan = plan

    def get_task_plan(self, task_id: str) -> dict[str, Any] | None:
        """Return stored plan for a task, or None if the task is unknown or has no plan."""
        entry = self._task_plans.get(task_id)
        return entry.plan if entry else None

    @staticmethod
    def _is_allowed_transition(current: TaskState, target: TaskState) -> bool:
        """Return True if ``target`` is reachable from ``current``; self-transitions are always allowed."""
        return target == current or target in VALID_TASK_TRANSITIONS.get(current, set())

    def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None:
        """
        Update task state with transition validation and publish ``task_state_changed``.

        Setting a task to its current state is accepted (and still publishes the event).
        Validation is skipped when ``force`` is True or the orchestrator was built
        with ``validate_transitions=False``.

        Args:
            task_id: The task identifier.
            state: The new state to transition to.
            force: If True, skip transition validation (use with caution).

        Raises:
            InvalidStateTransitionError: If the transition is not allowed and force=False.
            ValueError: If task_id is unknown.
        """
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            raise ValueError(f"Unknown task: {task_id}")
        if (
            not force
            and self._validate_transitions
            and not self._is_allowed_transition(current_state, state)
        ):
            raise InvalidStateTransitionError(task_id, current_state, state)
        self._state.set_task_state(task_id, state)
        logger.debug(
            "Task state set",
            extra={
                "task_id": task_id,
                "from_state": current_state.value,
                "to_state": state.value,
            },
        )
        self._event_bus.publish(
            "task_state_changed",
            {"task_id": task_id, "from_state": current_state.value, "to_state": state.value},
        )

    def can_transition(self, task_id: str, state: TaskState) -> bool:
        """
        Check if a state transition is valid without performing it.

        Returns False for unknown tasks. This always consults the transition
        table, even when the orchestrator was built with validate_transitions=False.
        """
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            return False
        return self._is_allowed_transition(current_state, state)

    # ---------------------------------------------------------------- routing

    def _deliver(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        Log, publish ``message_received``, and hand the envelope to its recipient.

        Returns the recipient's response, or None when the recipient is not
        registered or has no ``handle_message`` method.
        """
        recipient = envelope.message.recipient
        intent = envelope.message.intent
        logger.info(
            "Message routed",
            extra={"task_id": envelope.task_id or "", "recipient": recipient, "intent": intent},
        )
        agent = self._agents.get(recipient)
        self._event_bus.publish(
            "message_received",
            {
                "task_id": envelope.task_id,
                "recipient": recipient,
                "intent": intent,
            },
        )
        if agent is not None and hasattr(agent, "handle_message"):
            return agent.handle_message(envelope)
        return None

    def route_message(self, envelope: AgentMessageEnvelope) -> None:
        """
        Deliver an envelope to the recipient agent and publish message_received.

        Does not route the agent's response; use route_message_return to get
        and optionally re-route the response envelope.
        """
        self._deliver(envelope)

    def route_message_return(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        Deliver an envelope to the recipient agent and return the response envelope, if any.

        Use this when the caller needs to handle or re-route the agent's response.
        """
        return self._deliver(envelope)

    def route_messages_batch(
        self,
        envelopes: list[AgentMessageEnvelope],
    ) -> list[AgentMessageEnvelope | None]:
        """
        Route multiple messages; return responses in same order.

        Uses concurrent execution for parallel dispatch (at most
        _MAX_BATCH_WORKERS threads). An empty input returns an empty list.
        If any routing call raises, the exception propagates to the caller.
        """
        if not envelopes:
            # ThreadPoolExecutor rejects max_workers=0, so short-circuit empty batches.
            return []
        workers = min(len(envelopes), _MAX_BATCH_WORKERS)
        with ThreadPoolExecutor(max_workers=workers) as ex:
            # Executor.map yields results in input order, so responses line up with envelopes.
            return list(ex.map(self.route_message_return, envelopes))

    def route_message_async(
        self,
        envelope: AgentMessageEnvelope,
        callback: Callable[[AgentMessageEnvelope | None], None] | None = None,
    ) -> Future[AgentMessageEnvelope | None]:
        """
        Route message in background; optionally invoke callback with response.

        Returns the Future for non-blocking await. If routing raises or the
        future is cancelled, the callback is not invoked; the failure is logged
        and remains available on the returned Future.
        """
        future = self._async_executor.submit(self.route_message_return, envelope)
        if callback is not None:

            def _on_done(f: Future[AgentMessageEnvelope | None]) -> None:
                if f.cancelled():
                    return
                try:
                    response = f.result()
                except Exception:
                    # Report routing failures as such, not as callback failures.
                    logger.exception("Async route failed")
                    return
                try:
                    callback(response)
                except Exception:
                    logger.exception("Async route callback failed")

            future.add_done_callback(_on_done)
        return future