321 lines
10 KiB
Python
321 lines
10 KiB
Python
"""Tests for planning modules."""
|
|
|
|
import pytest
|
|
|
|
from fusionagi.schemas.plan import Plan, PlanStep
|
|
from fusionagi.planning.graph import topological_order, next_step, get_step
|
|
from fusionagi.planning.strategies import linear_order, dependency_order, get_strategy
|
|
|
|
|
|
class TestPlanValidation:
    """Checks enforced when a ``Plan`` is constructed, plus its lookup helpers.

    Covers step-ID uniqueness, dependency references, cycle detection,
    fallback-path references, and the ``get_step`` / ``get_dependencies`` /
    ``get_dependents`` / ``topological_order`` methods on ``Plan``.
    """

    def test_basic_plan_creation(self):
        """Two independent steps are kept, in declaration order."""
        first = PlanStep(id="s1", description="First step")
        second = PlanStep(id="s2", description="Second step")

        plan = Plan(steps=[first, second])

        assert len(plan.steps) == 2
        assert plan.step_ids() == ["s1", "s2"]

    def test_plan_with_dependencies(self):
        """Dependencies naming existing steps are accepted and stored as given."""
        steps = [
            PlanStep(id="s1", description="First"),
            PlanStep(id="s2", description="Second", dependencies=["s1"]),
            PlanStep(id="s3", description="Third", dependencies=["s1", "s2"]),
        ]
        plan = Plan(steps=steps)

        third = plan.steps[2]
        assert third.dependencies == ["s1", "s2"]

    def test_invalid_dependency_reference(self):
        """A dependency on an unknown step ID makes construction fail."""
        with pytest.raises(ValueError, match="invalid dependencies"):
            steps = [
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s999"]),
            ]
            Plan(steps=steps)

    def test_duplicate_step_ids(self):
        """Two steps sharing one ID make construction fail."""
        with pytest.raises(ValueError, match="Duplicate step IDs"):
            steps = [
                PlanStep(id="s1", description="First"),
                PlanStep(id="s1", description="Duplicate"),
            ]
            Plan(steps=steps)

    def test_circular_dependency_detection(self):
        """Mutually dependent steps (s1 <-> s2) make construction fail."""
        with pytest.raises(ValueError, match="Circular dependencies"):
            steps = [
                PlanStep(id="s1", description="First", dependencies=["s2"]),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
            ]
            Plan(steps=steps)

    def test_fallback_path_validation(self):
        """Fallback paths may only name steps that exist in the plan."""
        # Every ID in the fallback path is a known step: accepted.
        accepted = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
            ],
            fallback_paths=[["s1", "s2"]],
        )
        assert len(accepted.fallback_paths) == 1

        # "s999" is not a step in the plan: rejected.
        with pytest.raises(ValueError, match="invalid step references"):
            Plan(
                steps=[PlanStep(id="s1", description="First")],
                fallback_paths=[["s1", "s999"]],
            )

    def test_plan_get_step(self):
        """``Plan.get_step`` finds a step by ID and yields ``None`` for unknown IDs."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
            ]
        )

        found = plan.get_step("s1")
        assert found is not None
        assert found.description == "First"

        missing = plan.get_step("nonexistent")
        assert missing is None

    def test_plan_get_dependencies(self):
        """``Plan.get_dependencies`` returns the steps a given step depends on."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
                PlanStep(id="s3", description="Third", dependencies=["s1", "s2"]),
            ]
        )

        prerequisites = plan.get_dependencies("s3")
        assert len(prerequisites) == 2
        prerequisite_ids = {step.id for step in prerequisites}
        assert prerequisite_ids == {"s1", "s2"}

    def test_plan_get_dependents(self):
        """``Plan.get_dependents`` returns the steps that depend on a given step."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
                PlanStep(id="s3", description="Third", dependencies=["s1"]),
            ]
        )

        downstream = plan.get_dependents("s1")
        assert len(downstream) == 2
        downstream_ids = {step.id for step in downstream}
        assert downstream_ids == {"s2", "s3"}

    def test_plan_topological_order(self):
        """``Plan.topological_order`` puts every step after its dependencies.

        The steps are deliberately declared out of order (s3 first), so the
        test fails if declaration order is simply echoed back.
        """
        plan = Plan(
            steps=[
                PlanStep(id="s3", description="Third", dependencies=["s1", "s2"]),
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
            ]
        )

        order = plan.topological_order()
        position = order.index

        # s1 has no prerequisites and both other steps depend on it.
        assert position("s1") < position("s2")
        assert position("s1") < position("s3")
        # s3 additionally depends on s2.
        assert position("s2") < position("s3")
|
|
|
|
|
|
class TestPlanGraph:
    """Tests for the free functions in ``fusionagi.planning.graph``."""

    def test_topological_order_simple(self):
        """A linear chain a -> b -> c is ordered along the chain."""
        plan = Plan(
            steps=[
                PlanStep(id="a", description="A"),
                PlanStep(id="b", description="B", dependencies=["a"]),
                PlanStep(id="c", description="C", dependencies=["b"]),
            ]
        )

        order = topological_order(plan)
        assert order == ["a", "b", "c"]

    def test_topological_order_parallel(self):
        """A diamond (root -> {a, b} -> final) places both branches in the middle.

        ``a`` and ``b`` do not depend on each other, so either may come first.
        Only their position between ``root`` and ``final`` is fixed.
        """
        plan = Plan(
            steps=[
                PlanStep(id="root", description="Root"),
                PlanStep(id="a", description="A", dependencies=["root"]),
                PlanStep(id="b", description="B", dependencies=["root"]),
                PlanStep(id="final", description="Final", dependencies=["a", "b"]),
            ]
        )

        order = topological_order(plan)

        # Each step must appear exactly once. Without this, the positional
        # checks below would also accept orderings with extra or repeated IDs.
        assert sorted(order) == ["a", "b", "final", "root"]
        # root must be first
        assert order[0] == "root"
        # final must be last
        assert order[-1] == "final"
        # a and b fill the two middle slots, in either order
        assert set(order[1:3]) == {"a", "b"}

    def test_get_step(self):
        """``get_step`` finds a step by ID and returns ``None`` for unknown IDs."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="Step 1"),
                PlanStep(id="s2", description="Step 2"),
            ]
        )

        step = get_step(plan, "s1")
        assert step is not None
        assert step.description == "Step 1"

        assert get_step(plan, "nonexistent") is None

    def test_next_step(self):
        """``next_step`` walks a chain s1 -> s2 -> s3 as steps are completed.

        It returns ``None`` once every step is in ``completed_step_ids``.
        """
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="Step 1"),
                PlanStep(id="s2", description="Step 2", dependencies=["s1"]),
                PlanStep(id="s3", description="Step 3", dependencies=["s2"]),
            ]
        )

        # Nothing completed yet: only s1 has no dependencies.
        assert next_step(plan, completed_step_ids=set()) == "s1"

        # After completing s1, s2 becomes available.
        assert next_step(plan, completed_step_ids={"s1"}) == "s2"

        # After completing s1 and s2, s3 becomes available.
        assert next_step(plan, completed_step_ids={"s1", "s2"}) == "s3"

        # Everything completed: nothing left to run.
        assert next_step(plan, completed_step_ids={"s1", "s2", "s3"}) is None
|
|
|
|
|
|
class TestPlanningStrategies:
    """Tests for the ordering strategies in ``fusionagi.planning.strategies``."""

    def test_linear_order(self):
        """Linear ordering returns step IDs in declaration order."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
                PlanStep(id="s3", description="Third"),
            ]
        )

        order = linear_order(plan)
        assert order == ["s1", "s2", "s3"]

    def test_dependency_order(self):
        """Dependency ordering places each step after the step it depends on.

        The steps are declared out of order (s3 first), so the test fails if
        declaration order is simply echoed back.
        """
        plan = Plan(
            steps=[
                PlanStep(id="s3", description="Third", dependencies=["s2"]),
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
            ]
        )

        order = dependency_order(plan)

        # Each step must appear exactly once. `.index` alone would not
        # catch duplicated or extra IDs.
        assert sorted(order) == ["s1", "s2", "s3"]
        assert order.index("s1") < order.index("s2")
        assert order.index("s2") < order.index("s3")

    def test_get_strategy(self):
        """``get_strategy`` maps strategy names to the strategy functions.

        Identity (``is``) is asserted rather than equality: the test wants the
        exact function objects. ``==`` could also be satisfied by an object
        that overrides ``__eq__``.
        """
        assert get_strategy("linear") is linear_order
        assert get_strategy("dependency") is dependency_order

        # Unknown strategy names fall back to dependency ordering.
        assert get_strategy("unknown") is dependency_order
|
|
|
|
|
|
class TestPlanSerialization:
    """Tests for ``Plan.to_dict`` and ``Plan.from_dict``."""

    def test_to_dict(self):
        """``to_dict`` exposes the steps and metadata as dict entries."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="Step 1", tool_name="tool1"),
            ],
            metadata={"key": "value"},
        )

        d = plan.to_dict()

        assert "steps" in d
        assert len(d["steps"]) == 1
        assert d["steps"][0]["id"] == "s1"
        assert d["metadata"]["key"] == "value"

    def test_from_dict(self):
        """``from_dict`` rebuilds steps (with dependencies) and metadata from a dict."""
        d = {
            "steps": [
                {"id": "s1", "description": "Step 1"},
                {"id": "s2", "description": "Step 2", "dependencies": ["s1"]},
            ],
            "metadata": {"source": "test"},
        }

        plan = Plan.from_dict(d)

        assert len(plan.steps) == 2
        assert plan.steps[1].dependencies == ["s1"]
        assert plan.metadata["source"] == "test"

    def test_roundtrip(self):
        """``from_dict(to_dict(plan))`` round-trips the plan's fields.

        Checked fields: step IDs, tool names, dependencies, fallback paths
        and metadata.
        """
        original = Plan(
            steps=[
                PlanStep(id="s1", description="First", tool_name="tool_a"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
            ],
            fallback_paths=[["s1", "s2"]],
            metadata={"version": 1},
        )

        d = original.to_dict()
        restored = Plan.from_dict(d)

        assert restored.step_ids() == original.step_ids()
        assert restored.steps[0].tool_name == "tool_a"
        # Dependencies and metadata must survive the round trip too. The
        # original test never checked them.
        assert restored.steps[1].dependencies == original.steps[1].dependencies
        assert restored.fallback_paths == original.fallback_paths
        assert restored.metadata == original.metadata
|