228 lines
7.5 KiB
Python
228 lines
7.5 KiB
Python
|
|
"""Tests for multi-agent accelerations: parallel execution, pool, delegation."""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from fusionagi.planning import ready_steps
|
||
|
|
from fusionagi.schemas.plan import Plan, PlanStep
|
||
|
|
from fusionagi.multi_agent import (
|
||
|
|
execute_steps_parallel,
|
||
|
|
ParallelStepResult,
|
||
|
|
AgentPool,
|
||
|
|
PooledExecutorRouter,
|
||
|
|
delegate_sub_tasks,
|
||
|
|
DelegationConfig,
|
||
|
|
SubTask,
|
||
|
|
)
|
||
|
|
from fusionagi.core import EventBus, StateManager, Orchestrator
|
||
|
|
from fusionagi.agents import ExecutorAgent, PlannerAgent
|
||
|
|
from fusionagi.tools import ToolRegistry
|
||
|
|
from fusionagi.adapters import StubAdapter
|
||
|
|
|
||
|
|
|
||
|
|
class TestReadySteps:
    """Verify that ready_steps reports which plan steps may be dispatched together."""

    def test_parallel_ready_steps(self):
        """Siblings that share one completed dependency become ready at the same time."""
        # Diamond-shaped plan: s1 fans out to s2 and s3, which both feed into s4.
        diamond = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
                PlanStep(id="s3", description="Third", dependencies=["s1"]),
                PlanStep(id="s4", description="Fourth", dependencies=["s2", "s3"]),
            ]
        )

        # Nothing done yet: only the root step is runnable.
        assert ready_steps(diamond, set()) == ["s1"]
        # Root done: both branches unlock together (their order is not asserted).
        assert set(ready_steps(diamond, {"s1"})) == {"s2", "s3"}
        # Both branches done: the join step becomes runnable.
        assert ready_steps(diamond, {"s1", "s2", "s3"}) == ["s4"]
        # Every step done: nothing is left to run.
        assert ready_steps(diamond, {"s1", "s2", "s3", "s4"}) == []
|
||
|
|
|
||
|
|
|
||
|
|
class TestAgentPool:
    """Exercise AgentPool selection strategies and the PooledExecutorRouter wrapper."""

    def test_pool_round_robin(self):
        """Round-robin hands successive dispatches to agents in insertion order, wrapping around."""
        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        pool = AgentPool(strategy="round_robin")
        seen = []

        class FakeAgent:
            def __init__(self, aid):
                self.identity = aid

            def handle_message(self, env):
                # Record which pool member received the envelope.
                seen.append(self.identity)
                return None

        for agent_id in ("a1", "a2", "a3"):
            pool.add(agent_id, FakeAgent(agent_id))

        envelope = AgentMessageEnvelope(
            message=AgentMessage(sender="x", recipient="pool", intent="execute_step", payload={}),
            task_id="t1",
        )
        for _ in range(6):
            pool.dispatch(envelope)

        # Two full rotations over the three agents.
        assert seen == ["a1", "a2", "a3"] * 2

    def test_pool_least_busy(self):
        """A least-busy pool accounts for every sequential dispatch in its stats.

        NOTE(review): dispatches here are back-to-back, not concurrent, so this
        does not exercise the "fewest in-flight" preference itself. It only
        checks pool size and total dispatch bookkeeping.
        """
        import time

        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        pool = AgentPool(strategy="least_busy")

        class SlowAgent:
            def __init__(self, aid):
                self.identity = aid

            def handle_message(self, env):
                time.sleep(0.05)
                return None

        for agent_id in ("slow1", "slow2"):
            pool.add(agent_id, SlowAgent(agent_id))

        envelope = AgentMessageEnvelope(
            message=AgentMessage(sender="x", recipient="pool", intent="test", payload={}),
        )
        # Two sequential dispatches. Which agent receives each one is not asserted.
        pool.dispatch(envelope)
        pool.dispatch(envelope)

        summary = pool.stats()
        assert summary["size"] == 2
        dispatched = sum(entry["total_dispatched"] for entry in summary["agents"])
        assert dispatched == 2

    def test_pooled_executor_router(self):
        """PooledExecutorRouter reports every executor registered with it."""
        registry = ToolRegistry()
        state = StateManager()
        # Both executors share one registry and one state manager.
        executors = {
            name: ExecutorAgent(identity=name, registry=registry, state_manager=state)
            for name in ("exec1", "exec2")
        }

        router = PooledExecutorRouter(identity="executor_pool")
        for name, executor in executors.items():
            router.add_executor(name, executor)

        assert router.stats()["size"] == 2
|
||
|
|
|
||
|
|
|
||
|
|
class TestDelegation:
    """Test sub-task delegation."""

    def test_delegate_sub_tasks_parallel(self):
        """Delegation hands each sub-task to the delegate once and returns one result per sub-task."""
        from fusionagi.multi_agent.delegation import SubTaskResult

        # IDs of the sub-tasks the delegate was actually invoked with.
        delegated = []

        def delegate_fn(st: SubTask) -> SubTaskResult:
            # delegate_sub_tasks may call this concurrently (max_parallel=3).
            # list.append is atomic in CPython, so recording from workers is safe.
            delegated.append(st.sub_task_id)
            return SubTaskResult(
                sub_task_id=st.sub_task_id,
                success=True,
                result=f"done-{st.sub_task_id}",
                agent_id="agent1",
            )

        tasks = [
            SubTask("t1", "Goal 1"),
            SubTask("t2", "Goal 2"),
            SubTask("t3", "Goal 3"),
        ]
        config = DelegationConfig(max_parallel=3)
        results = delegate_sub_tasks(tasks, delegate_fn, config)

        assert len(results) == 3
        assert all(r.success for r in results)
        # Each result belongs to a distinct sub-task and carries the delegate's output.
        assert {r.sub_task_id: r.result for r in results} == {
            "t1": "done-t1",
            "t2": "done-t2",
            "t3": "done-t3",
        }
        # Every sub-task reached the delegate exactly once (no drops, no duplicates).
        assert sorted(delegated) == ["t1", "t2", "t3"]
|
||
|
|
|
||
|
|
|
||
|
|
class TestParallelExecution:
    """Test parallel step execution."""

    def test_execute_steps_parallel(self):
        """Only the currently-ready steps are executed, each exactly once."""
        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        # Step IDs that execute_fn was actually invoked for. This is deliberately
        # distinct from the `completed_step_ids` argument passed in below.
        executed = []

        def execute_fn(task_id, step_id, plan, sender):
            executed.append(step_id)
            return AgentMessageEnvelope(
                message=AgentMessage(
                    sender="executor",
                    recipient=sender,
                    intent="step_done",
                    payload={"step_id": step_id, "result": f"ok-{step_id}"},
                ),
                task_id=task_id,
            )

        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
                PlanStep(id="s3", description="Third", dependencies=["s1"]),
            ]
        )
        # s2 and s3 are ready once s1 is marked complete.
        results = execute_steps_parallel(
            execute_fn, "task1", plan, completed_step_ids={"s1"}, max_workers=4
        )

        assert len(results) == 2
        assert {r.step_id for r in results} == {"s2", "s3"}
        assert all(r.success for r in results)
        # s1 was already complete, so it must not be re-run. s2 and s3 run once each.
        assert sorted(executed) == ["s2", "s3"]
|
||
|
|
|
||
|
|
|
||
|
|
class TestOrchestratorBatchRouting:
    """Orchestrator.route_messages_batch behaviour."""

    def test_route_messages_batch(self):
        """Responses come back aligned with the order of the submitted envelopes."""
        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        class EchoAgent:
            identity = "echo"

            def handle_message(self, env):
                # Reply to the sender, echoing the request's "n" marker as "orig".
                incoming = env.message
                reply = AgentMessage(
                    sender="echo",
                    recipient=incoming.sender,
                    intent="echo_reply",
                    payload={"orig": incoming.payload.get("n")},
                )
                return AgentMessageEnvelope(message=reply)

        orch = Orchestrator(EventBus(), StateManager())
        orch.register_agent("echo", EchoAgent())

        batch = [
            AgentMessageEnvelope(
                message=AgentMessage(sender="c", recipient="echo", intent="test", payload={"n": n}),
            )
            for n in range(5)
        ]
        responses = orch.route_messages_batch(batch)

        assert len(responses) == 5
        assert all(resp is not None for resp in responses)
        # The i-th response must answer the i-th request.
        assert [resp.message.payload["orig"] for resp in responses] == list(range(5))
|