191 lines
6.0 KiB
Python
191 lines
6.0 KiB
Python
"""Agent pool: load-balanced routing for horizontal scaling.

Multiple executor (or other) agents behind a single logical endpoint.
Supports round-robin, least-busy, and random selection strategies.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import random
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable
|
|
|
|
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
@dataclass
class PooledAgent:
    """Bookkeeping record for one agent registered in a pool.

    Pairs the agent object with the counters the pool uses to balance load.
    """

    # Identifier the pool uses to address this agent.
    agent_id: str
    # The wrapped agent object itself (AgentProtocol).
    agent: Any
    # Number of dispatches currently being handled by this agent.
    in_flight: int = 0
    # Running count of dispatches ever routed to this agent.
    total_dispatched: int = 0
    # Monotonic-clock reading taken at creation unless supplied explicitly.
    last_used: float = field(default_factory=time.monotonic)
|
|
|
|
|
|
class AgentPool:
    """Pool of agents for load-balanced dispatch.

    Strategies:
    - round_robin: Rotate through agents in order.
    - least_busy: Prefer the agent with the lowest in_flight count; ties go
      to the agent with the oldest last_used timestamp.
    - random: Random selection (useful for load spreading).

    Any other strategy string falls through to least_busy.

    Pool bookkeeping (membership, counters, round-robin cursor) is guarded by
    a single ``threading.Lock``. The agent's own ``handle_message`` is called
    with the lock released, so a handler may dispatch to the pool again.
    """

    def __init__(
        self,
        strategy: str = "least_busy",
    ) -> None:
        """Create an empty pool that selects agents using *strategy*."""
        self._strategy = strategy
        self._agents: list[PooledAgent] = []
        self._round_robin_index = 0
        # Plain (non-reentrant) lock: helpers that expect the lock to be held
        # must never try to take it again.
        self._lock = threading.Lock()

    def add(self, agent_id: str, agent: Any) -> None:
        """Add an agent to the pool.

        An ``agent_id`` that is already registered is ignored; the existing
        entry and its counters are kept.
        """
        with self._lock:
            if any(p.agent_id == agent_id for p in self._agents):
                return
            self._agents.append(PooledAgent(agent_id=agent_id, agent=agent))
            pool_size = len(self._agents)
        # Log after releasing the lock so logging handlers cannot stall dispatch.
        logger.info(
            "Agent added to pool",
            extra={"agent_id": agent_id, "pool_size": pool_size},
        )

    def remove(self, agent_id: str) -> bool:
        """Remove an agent from the pool.

        Returns:
            True if an agent with ``agent_id`` was found and removed,
            False otherwise.
        """
        with self._lock:
            for i, p in enumerate(self._agents):
                if p.agent_id == agent_id:
                    self._agents.pop(i)
                    return True
            return False

    def _pick(self) -> PooledAgent | None:
        """Choose an agent per the strategy; the caller must hold ``_lock``."""
        if not self._agents:
            return None

        if self._strategy == "round_robin":
            idx = self._round_robin_index % len(self._agents)
            self._round_robin_index += 1
            return self._agents[idx]

        if self._strategy == "random":
            return random.choice(self._agents)

        # least_busy (default)
        return min(self._agents, key=lambda p: (p.in_flight, p.last_used))

    def _select(self) -> PooledAgent | None:
        """Select an agent based on strategy, or None if the pool is empty."""
        with self._lock:
            return self._pick()

    def _reserve(self) -> PooledAgent | None:
        """Select an agent and record the dispatch in one critical section.

        Choosing and incrementing ``in_flight`` under the same lock hold means
        a concurrent dispatch already sees this agent as busy, so least_busy
        cannot hand the same idle agent to two callers at once.
        """
        with self._lock:
            pooled = self._pick()
            if pooled is None:
                return None
            pooled.in_flight += 1
            pooled.total_dispatched += 1
            pooled.last_used = time.monotonic()
            return pooled

    def dispatch(
        self,
        envelope: AgentMessageEnvelope,
        on_complete: Callable[[str], None] | None = None,
        rewrite_recipient: bool = True,
    ) -> Any:
        """Dispatch to a pooled agent and return its response.

        Args:
            envelope: Message envelope to deliver.
            on_complete: Optional callback invoked with the selected agent's
                id once the dispatch has finished, whether it succeeded or
                raised (for async cleanup).
            rewrite_recipient: If true, a copy of the envelope is sent whose
                message recipient is the selected agent's id, so the agent
                receives it correctly.

        Returns:
            Whatever the agent's ``handle_message`` returns. None if the pool
            is empty or the selected agent has no ``handle_message``.

        Raises:
            Any exception raised while rewriting the envelope, by the agent's
            ``handle_message``, or by ``on_complete`` propagates to the
            caller; the in-flight count is released first.
        """
        pooled = self._reserve()
        if pooled is None:
            logger.error("Agent pool empty, cannot dispatch")
            return None

        # Everything after the reservation runs inside try/finally so that a
        # failure anywhere (including the envelope rewrite) releases the slot.
        try:
            if rewrite_recipient:
                msg = envelope.message
                envelope = AgentMessageEnvelope(
                    message=AgentMessage(
                        sender=msg.sender,
                        recipient=pooled.agent_id,
                        intent=msg.intent,
                        payload=msg.payload,
                        confidence=msg.confidence,
                        uncertainty=msg.uncertainty,
                        timestamp=msg.timestamp,
                    ),
                    task_id=envelope.task_id,
                    correlation_id=envelope.correlation_id,
                )

            agent = pooled.agent
            if hasattr(agent, "handle_message"):
                return agent.handle_message(envelope)

            logger.warning(
                "Pooled agent has no handle_message; message not delivered",
                extra={"agent_id": pooled.agent_id},
            )
            return None
        finally:
            with self._lock:
                pooled.in_flight = max(0, pooled.in_flight - 1)
            if on_complete is not None:
                on_complete(pooled.agent_id)

    def size(self) -> int:
        """Return pool size."""
        with self._lock:
            return len(self._agents)

    def stats(self) -> dict[str, Any]:
        """Return pool statistics for monitoring.

        The result holds the strategy name, the pool size and, per agent, its
        id, current in_flight count and total_dispatched count.
        """
        with self._lock:
            return {
                "strategy": self._strategy,
                "size": len(self._agents),
                "agents": [
                    {
                        "id": p.agent_id,
                        "in_flight": p.in_flight,
                        "total_dispatched": p.total_dispatched,
                    }
                    for p in self._agents
                ],
            }
|
|
|
|
|
|
class PooledExecutorRouter:
    """Routes execute_step messages to a pool of executors.

    Wraps multiple ExecutorAgent instances; orchestrator or supervisor
    sends to this router's identity, and it load-balances to the pool.
    """

    def __init__(
        self,
        identity: str = "executor_pool",
        pool: AgentPool | None = None,
    ) -> None:
        """Create a router.

        Args:
            identity: Name under which this router is addressed.
            pool: Pool to route into. When omitted, a new least_busy
                ``AgentPool`` is created.
        """
        self.identity = identity
        # Test against None rather than truthiness so a caller-supplied pool
        # is never swapped out just because it evaluates as falsy.
        self._pool = pool if pool is not None else AgentPool(strategy="least_busy")

    def add_executor(self, executor_id: str, executor: Any) -> None:
        """Add an executor to the pool."""
        self._pool.add(executor_id, executor)

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Route execute_step to a pooled executor and return its response.

        Messages with any other intent are not routed and yield None.
        """
        if envelope.message.intent != "execute_step":
            return None

        # The pool picks the executor; its response is returned unchanged.
        return self._pool.dispatch(envelope)

    def stats(self) -> dict[str, Any]:
        """Pool statistics."""
        return self._pool.stats()
|