Files
FusionAGI/tests/test_multi_agent.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

228 lines
7.5 KiB
Python

"""Tests for multi-agent accelerations: parallel execution, pool, delegation."""
import pytest
from fusionagi.planning import ready_steps
from fusionagi.schemas.plan import Plan, PlanStep
from fusionagi.multi_agent import (
execute_steps_parallel,
ParallelStepResult,
AgentPool,
PooledExecutorRouter,
delegate_sub_tasks,
DelegationConfig,
SubTask,
)
from fusionagi.core import EventBus, StateManager, Orchestrator
from fusionagi.agents import ExecutorAgent, PlannerAgent
from fusionagi.tools import ToolRegistry
from fusionagi.adapters import StubAdapter
class TestReadySteps:
    """Exercise ``ready_steps``, which selects plan steps eligible for parallel dispatch."""

    def test_parallel_ready_steps(self):
        """Steps sharing the same satisfied dependency become ready together."""
        # Diamond-shaped plan: s1 fans out to s2/s3, which both feed s4.
        first = PlanStep(id="s1", description="First")
        second = PlanStep(id="s2", description="Second", dependencies=["s1"])
        third = PlanStep(id="s3", description="Third", dependencies=["s1"])
        fourth = PlanStep(id="s4", description="Fourth", dependencies=["s2", "s3"])
        plan = Plan(steps=[first, second, third, fourth])

        # Nothing completed yet: only the root step is runnable.
        assert ready_steps(plan, set()) == ["s1"]
        # Root completed: both branches unlock at once (order not asserted).
        assert set(ready_steps(plan, {"s1"})) == {"s2", "s3"}
        # Both branches completed: the join step unlocks.
        assert ready_steps(plan, {"s1", "s2", "s3"}) == ["s4"]
        # Everything completed: nothing left to run.
        assert ready_steps(plan, {"s1", "s2", "s3", "s4"}) == []
class TestAgentPool:
    """Test AgentPool and PooledExecutorRouter."""

    def test_pool_round_robin(self):
        """Round-robin selection rotates through agents in insertion order."""
        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        pool = AgentPool(strategy="round_robin")
        calls = []

        class FakeAgent:
            """Minimal agent that records which instance handled each message."""

            def __init__(self, aid):
                self.identity = aid

            def handle_message(self, env):
                calls.append(self.identity)
                return None

        for aid in ("a1", "a2", "a3"):
            pool.add(aid, FakeAgent(aid))

        env = AgentMessageEnvelope(
            message=AgentMessage(sender="x", recipient="pool", intent="execute_step", payload={}),
            task_id="t1",
        )
        # Two full rotations: each agent is hit twice, in the order added.
        for _ in range(6):
            pool.dispatch(env)
        assert calls == ["a1", "a2", "a3", "a1", "a2", "a3"]

    def test_pool_least_busy(self):
        """A least-busy pool dispatches every message and counts it in its stats."""
        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        pool = AgentPool(strategy="least_busy")

        class IdleAgent:
            """Agent that returns immediately with no reply."""

            def __init__(self, aid):
                self.identity = aid

            def handle_message(self, env):
                return None

        pool.add("slow1", IdleAgent("slow1"))
        pool.add("slow2", IdleAgent("slow2"))
        env = AgentMessageEnvelope(
            message=AgentMessage(sender="x", recipient="pool", intent="test", payload={}),
        )
        # The dispatches below are sequential. dispatch() appears to run the agent
        # synchronously (the round-robin test above relies on that), so no
        # agent has work in flight when the next message arrives. Sleeping
        # inside the agent would only slow the test down.
        # NOTE(review): this does not verify that the strategy prefers the
        # agent with the fewest in-flight messages. It also does not verify that
        # both agents are used, which depends on AgentPool's tie-breaking.
        # It only checks that every dispatch is counted.
        pool.dispatch(env)
        pool.dispatch(env)
        stats = pool.stats()
        assert stats["size"] == 2
        assert sum(a["total_dispatched"] for a in stats["agents"]) == 2

    def test_pooled_executor_router(self):
        """PooledExecutorRouter registers each added executor in its pool."""
        registry = ToolRegistry()
        state = StateManager()
        # Both executors deliberately share one registry and one state manager.
        exec1 = ExecutorAgent(identity="exec1", registry=registry, state_manager=state)
        exec2 = ExecutorAgent(identity="exec2", registry=registry, state_manager=state)
        router = PooledExecutorRouter(identity="executor_pool")
        router.add_executor("exec1", exec1)
        router.add_executor("exec2", exec2)
        assert router.stats()["size"] == 2
class TestDelegation:
    """Test sub-task delegation."""

    def test_delegate_sub_tasks_parallel(self):
        """Delegation hands every sub-task to the delegate and collects all results.

        NOTE(review): with ``max_parallel=3`` the sub-tasks may run concurrently,
        but this test does not prove they actually overlapped in time.
        """
        from fusionagi.multi_agent.delegation import SubTaskResult

        invoked = []

        def delegate_fn(st):
            # list.append is atomic in CPython, so recording here is safe even
            # if delegate_sub_tasks calls the delegate from worker threads.
            invoked.append(st.sub_task_id)
            return SubTaskResult(
                sub_task_id=st.sub_task_id,
                success=True,
                result=f"done-{st.sub_task_id}",
                agent_id="agent1",
            )

        tasks = [
            SubTask("t1", "Goal 1"),
            SubTask("t2", "Goal 2"),
            SubTask("t3", "Goal 3"),
        ]
        config = DelegationConfig(max_parallel=3)
        results = delegate_sub_tasks(tasks, delegate_fn, config)

        assert len(results) == 3
        assert all(r.success for r in results)
        assert {r.sub_task_id for r in results} == {"t1", "t2", "t3"}
        # Each sub-task reached the delegate exactly once.
        assert sorted(invoked) == ["t1", "t2", "t3"]
class TestParallelExecution:
    """Test parallel step execution."""

    def test_execute_steps_parallel(self):
        """Only steps whose dependencies are satisfied are executed, and all succeed."""
        # Import once up front rather than inside execute_fn, which may be
        # invoked from worker threads.
        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        executed = []

        def execute_fn(task_id, step_id, plan, sender):
            executed.append(step_id)
            return AgentMessageEnvelope(
                message=AgentMessage(
                    sender="executor",
                    recipient=sender,
                    intent="step_done",
                    payload={"step_id": step_id, "result": f"ok-{step_id}"},
                ),
                task_id=task_id,
            )

        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
                PlanStep(id="s3", description="Third", dependencies=["s1"]),
            ]
        )
        # s1 is already complete, so s2 and s3 are both ready.
        results = execute_steps_parallel(
            execute_fn, "task1", plan, completed_step_ids={"s1"}, max_workers=4
        )
        assert len(results) == 2
        assert {r.step_id for r in results} == {"s2", "s3"}
        assert all(r.success for r in results)
        # The already-completed step must not be re-executed.
        assert sorted(executed) == ["s2", "s3"]
class TestOrchestratorBatchRouting:
    """Verify that the orchestrator routes a batch of envelopes in one call."""

    def test_route_messages_batch(self):
        """Responses come back one per envelope, in the same order as the input."""
        from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

        class EchoAgent:
            """Replies to the sender, echoing the payload's ``n`` value as ``orig``."""

            identity = "echo"

            def handle_message(self, env):
                incoming = env.message
                reply = AgentMessage(
                    sender="echo",
                    recipient=incoming.sender,
                    intent="echo_reply",
                    payload={"orig": incoming.payload.get("n")},
                )
                return AgentMessageEnvelope(message=reply)

        orch = Orchestrator(EventBus(), StateManager())
        orch.register_agent("echo", EchoAgent())

        def make_request(n):
            return AgentMessageEnvelope(
                message=AgentMessage(sender="c", recipient="echo", intent="test", payload={"n": n}),
            )

        envelopes = [make_request(n) for n in range(5)]
        responses = orch.route_messages_batch(envelopes)

        assert len(responses) == 5
        for expected, response in zip(range(5), responses):
            assert response is not None
            assert response.message.payload["orig"] == expected