"""Tests for planning modules."""

import pytest

from fusionagi.schemas.plan import Plan, PlanStep
from fusionagi.planning.graph import topological_order, next_step, get_step
from fusionagi.planning.strategies import linear_order, dependency_order, get_strategy


class TestPlanValidation:
    """Test Plan schema validation."""

    def test_basic_plan_creation(self):
        """Test creating a basic plan."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First step"),
                PlanStep(id="s2", description="Second step"),
            ]
        )
        assert len(plan.steps) == 2
        assert plan.step_ids() == ["s1", "s2"]

    def test_plan_with_dependencies(self):
        """Test plan with valid dependencies."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
                PlanStep(id="s3", description="Third", dependencies=["s1", "s2"]),
            ]
        )
        # Check every step that declared dependencies, not just the last one.
        assert plan.steps[1].dependencies == ["s1"]
        assert plan.steps[2].dependencies == ["s1", "s2"]

    def test_invalid_dependency_reference(self):
        """Test that invalid dependency references raise error."""
        with pytest.raises(ValueError, match="invalid dependencies"):
            Plan(
                steps=[
                    PlanStep(id="s1", description="First"),
                    PlanStep(id="s2", description="Second", dependencies=["s999"]),
                ]
            )

    def test_duplicate_step_ids(self):
        """Test that duplicate step IDs raise error."""
        with pytest.raises(ValueError, match="Duplicate step IDs"):
            Plan(
                steps=[
                    PlanStep(id="s1", description="First"),
                    PlanStep(id="s1", description="Duplicate"),
                ]
            )

    def test_circular_dependency_detection(self):
        """Test that circular dependencies are detected."""
        with pytest.raises(ValueError, match="Circular dependencies"):
            Plan(
                steps=[
                    PlanStep(id="s1", description="First", dependencies=["s2"]),
                    PlanStep(id="s2", description="Second", dependencies=["s1"]),
                ]
            )

    def test_fallback_path_validation(self):
        """Test fallback path reference validation."""
        # Valid fallback paths
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
            ],
            fallback_paths=[["s1", "s2"]],
        )
        assert len(plan.fallback_paths) == 1

        # Invalid fallback path reference
        with pytest.raises(ValueError, match="invalid step references"):
            Plan(
                steps=[PlanStep(id="s1", description="First")],
                fallback_paths=[["s1", "s999"]],
            )

    def test_plan_get_step(self):
        """Test get_step helper."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
            ]
        )
        step = plan.get_step("s1")
        assert step is not None
        assert step.description == "First"
        assert plan.get_step("nonexistent") is None

    def test_plan_get_dependencies(self):
        """Test get_dependencies helper."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
                PlanStep(id="s3", description="Third", dependencies=["s1", "s2"]),
            ]
        )
        deps = plan.get_dependencies("s3")
        assert len(deps) == 2
        assert {d.id for d in deps} == {"s1", "s2"}

    def test_plan_get_dependents(self):
        """Test get_dependents helper."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
                PlanStep(id="s3", description="Third", dependencies=["s1"]),
            ]
        )
        dependents = plan.get_dependents("s1")
        assert len(dependents) == 2
        assert {d.id for d in dependents} == {"s2", "s3"}

    def test_plan_topological_order(self):
        """Test plan's topological_order method."""
        plan = Plan(
            steps=[
                PlanStep(id="s3", description="Third", dependencies=["s1", "s2"]),
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
            ]
        )
        order = plan.topological_order()
        # Every step appears exactly once; relative-position checks alone
        # would not catch a dropped or duplicated step.
        assert sorted(order) == ["s1", "s2", "s3"]
        # s1 must come before s2 and s3
        assert order.index("s1") < order.index("s2")
        assert order.index("s1") < order.index("s3")
        # s2 must come before s3
        assert order.index("s2") < order.index("s3")


class TestPlanGraph:
    """Test planning graph functions."""

    def test_topological_order_simple(self):
        """Test simple topological ordering."""
        plan = Plan(
            steps=[
                PlanStep(id="a", description="A"),
                PlanStep(id="b", description="B", dependencies=["a"]),
                PlanStep(id="c", description="C", dependencies=["b"]),
            ]
        )
        order = topological_order(plan)
        assert order == ["a", "b", "c"]

    def test_topological_order_parallel(self):
        """Test topological order with parallel steps."""
        plan = Plan(
            steps=[
                PlanStep(id="root", description="Root"),
                PlanStep(id="a", description="A", dependencies=["root"]),
                PlanStep(id="b", description="B", dependencies=["root"]),
                PlanStep(id="final", description="Final", dependencies=["a", "b"]),
            ]
        )
        order = topological_order(plan)
        # Every step appears exactly once.
        assert sorted(order) == sorted(["root", "a", "b", "final"])
        # root must be first
        assert order[0] == "root"
        # final must be last
        assert order[-1] == "final"
        # a and b may run in either order, but both sit between root and final
        assert set(order[1:3]) == {"a", "b"}

    def test_get_step(self):
        """Test get_step function."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="Step 1"),
                PlanStep(id="s2", description="Step 2"),
            ]
        )
        step = get_step(plan, "s1")
        assert step is not None
        assert step.description == "Step 1"
        assert get_step(plan, "nonexistent") is None

    def test_next_step(self):
        """Test next_step function."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="Step 1"),
                PlanStep(id="s2", description="Step 2", dependencies=["s1"]),
                PlanStep(id="s3", description="Step 3", dependencies=["s2"]),
            ]
        )
        # First call with no completed steps - s1 has no deps
        step_id = next_step(plan, completed_step_ids=set())
        assert step_id == "s1"

        # After completing s1 - s2 is available
        step_id = next_step(plan, completed_step_ids={"s1"})
        assert step_id == "s2"

        # After completing s1, s2 - s3 is available
        step_id = next_step(plan, completed_step_ids={"s1", "s2"})
        assert step_id == "s3"

        # All completed
        step_id = next_step(plan, completed_step_ids={"s1", "s2", "s3"})
        assert step_id is None


class TestPlanningStrategies:
    """Test planning strategy functions."""

    def test_linear_order(self):
        """Test linear ordering strategy."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second"),
                PlanStep(id="s3", description="Third"),
            ]
        )
        order = linear_order(plan)
        assert order == ["s1", "s2", "s3"]

    def test_dependency_order(self):
        """Test dependency ordering strategy."""
        plan = Plan(
            steps=[
                PlanStep(id="s3", description="Third", dependencies=["s2"]),
                PlanStep(id="s1", description="First"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
            ]
        )
        order = dependency_order(plan)
        # Every step appears exactly once.
        assert sorted(order) == ["s1", "s2", "s3"]
        assert order.index("s1") < order.index("s2")
        assert order.index("s2") < order.index("s3")

    def test_get_strategy(self):
        """Test strategy getter."""
        # The getter should hand back the strategy functions themselves,
        # so compare by identity rather than equality.
        assert get_strategy("linear") is linear_order
        assert get_strategy("dependency") is dependency_order

        # Unknown strategy defaults to dependency
        assert get_strategy("unknown") is dependency_order


class TestPlanSerialization:
    """Test Plan serialization."""

    def test_to_dict(self):
        """Test serialization to dict."""
        plan = Plan(
            steps=[
                PlanStep(id="s1", description="Step 1", tool_name="tool1"),
            ],
            metadata={"key": "value"},
        )
        d = plan.to_dict()
        assert "steps" in d
        assert len(d["steps"]) == 1
        assert d["steps"][0]["id"] == "s1"
        assert d["metadata"]["key"] == "value"

    def test_from_dict(self):
        """Test deserialization from dict."""
        d = {
            "steps": [
                {"id": "s1", "description": "Step 1"},
                {"id": "s2", "description": "Step 2", "dependencies": ["s1"]},
            ],
            "metadata": {"source": "test"},
        }
        plan = Plan.from_dict(d)
        assert len(plan.steps) == 2
        assert plan.steps[1].dependencies == ["s1"]
        assert plan.metadata["source"] == "test"

    def test_roundtrip(self):
        """Test serialization roundtrip."""
        original = Plan(
            steps=[
                PlanStep(id="s1", description="First", tool_name="tool_a"),
                PlanStep(id="s2", description="Second", dependencies=["s1"]),
            ],
            fallback_paths=[["s1", "s2"]],
            metadata={"version": 1},
        )
        d = original.to_dict()
        restored = Plan.from_dict(d)
        assert restored.step_ids() == original.step_ids()
        assert restored.steps[0].tool_name == "tool_a"
        # Dependencies and metadata must survive the round trip too.
        assert restored.steps[1].dependencies == ["s1"]
        assert restored.metadata == original.metadata
        assert restored.fallback_paths == original.fallback_paths