Files
FusionAGI/fusionagi/schemas/plan.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

202 lines
7.2 KiB
Python

"""Plan schema: steps with ids, dependencies, optional fallback paths with validation."""
from collections import deque
from typing import Any

from pydantic import BaseModel, Field, field_validator, model_validator
class PlanStep(BaseModel):
    """
    One unit of work inside a plan.

    A step is addressed by ``id``; other steps refer to it through their own
    ``dependencies`` list. It may optionally name a tool to invoke together
    with the keyword arguments for that invocation.

    Validation:
    - ``id`` and ``description`` must contain at least one non-whitespace
      character (``min_length=1`` rejects the empty string; the field
      validator below additionally rejects whitespace-only values).
    """

    id: str = Field(..., min_length=1, description="Step identifier")
    description: str = Field(..., min_length=1, description="What to do")
    # Ids of other steps in the same plan; cross-references are checked by
    # the plan-level validator, not here.
    dependencies: list[str] = Field(default_factory=list, description="Step ids that must complete first")
    tool_name: str | None = Field(default=None, description="Optional tool to invoke")
    tool_args: dict[str, Any] = Field(default_factory=dict, description="Optional tool arguments")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Extra data")

    @field_validator("id", "description")
    @classmethod
    def validate_non_whitespace(cls, v: str) -> str:
        """
        Reject strings consisting solely of whitespace.

        Returns the value unchanged (no stripping) when it passes.

        Raises:
            ValueError: if ``v`` is blank after stripping whitespace.
        """
        if v.strip() == "":
            raise ValueError("Field cannot be empty or whitespace")
        return v
class Plan(BaseModel):
    """
    Plan graph: steps and optional fallback paths.

    Each step's ``dependencies`` name other steps (by id) that must complete
    before it runs. ``fallback_paths`` are alternative sequences of step ids.

    Validation (performed by ``validate_plan`` when the model is validated,
    e.g. on construction):
    - No duplicate step IDs
    - All dependency references must be valid step IDs
    - All fallback path references must be valid step IDs
    - No circular dependencies

    NOTE: no ``model_config`` is set here, so pydantic's default applies and
    mutating fields after construction is not re-validated.
    """

    steps: list[PlanStep] = Field(default_factory=list, description="Ordered steps")
    fallback_paths: list[list[str]] = Field(default_factory=list, description="Alternative step sequences")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Plan-level metadata")

    @model_validator(mode="after")
    def validate_plan(self) -> "Plan":
        """
        Validate the entire plan structure.

        Returns:
            ``self`` unchanged when the plan is valid.

        Raises:
            ValueError: on duplicate step ids, unknown dependency references,
                unknown fallback-path references, or circular dependencies
                (checked in that order; pydantic surfaces it as a
                ``ValidationError``).
        """
        step_ids = {s.id for s in self.steps}

        # Duplicate step IDs: report each offending id once, in first-seen order.
        if len(step_ids) != len(self.steps):
            seen: set[str] = set()
            duplicates: list[str] = []
            for s in self.steps:
                if s.id in seen and s.id not in duplicates:
                    duplicates.append(s.id)
                seen.add(s.id)
            raise ValueError(f"Duplicate step IDs: {duplicates}")

        # Every dependency must name a step in this plan.
        for step in self.steps:
            invalid_deps = [d for d in step.dependencies if d not in step_ids]
            if invalid_deps:
                raise ValueError(
                    f"Step '{step.id}' has invalid dependencies: {invalid_deps}"
                )

        # Every fallback path entry must name a step in this plan.
        for i, path in enumerate(self.fallback_paths):
            invalid_refs = [ref for ref in path if ref not in step_ids]
            if invalid_refs:
                raise ValueError(
                    f"Fallback path {i} has invalid step references: {invalid_refs}"
                )

        # Cycle detection runs last: it relies on ids being unique and all
        # references resolving.
        cycles = self._find_cycles()
        if cycles:
            raise ValueError(f"Circular dependencies detected: {cycles}")
        return self

    def _find_cycles(self) -> list[list[str]]:
        """
        Find circular dependencies using an iterative depth-first search.

        Edges point from a step to each of its dependencies. Each back edge
        found yields one cycle, reported as the path from the repeated step
        back to itself (e.g. ``["a", "b", "a"]``; a self-dependency gives
        ``["a", "a"]``). The walk keeps an explicit stack instead of
        recursing, so long dependency chains cannot hit Python's recursion
        limit. Finished nodes are unwound properly, so only genuine cycles
        are reported.

        Returns:
            The detected cycles; an empty list when the graph is acyclic.
        """
        # Deduplicate dependencies so a repeated edge is not reported twice.
        graph: dict[str, list[str]] = {
            s.id: list(dict.fromkeys(s.dependencies)) for s in self.steps
        }
        cycles: list[list[str]] = []
        finished: set[str] = set()  # nodes whose dependencies are fully explored

        for root in graph:
            if root in finished:
                continue
            # path/on_path describe the current DFS branch. pending holds, for
            # each node on the path, an iterator over its remaining dependencies.
            path: list[str] = [root]
            on_path: set[str] = {root}
            pending = [iter(graph[root])]
            while pending:
                descended = False
                for dep in pending[-1]:
                    if dep in on_path:
                        # Back edge: dep is an ancestor on the current branch.
                        cycles.append(path[path.index(dep):] + [dep])
                    elif dep not in finished and dep in graph:
                        path.append(dep)
                        on_path.add(dep)
                        pending.append(iter(graph[dep]))
                        descended = True
                        break
                if not descended:
                    # All dependencies of path[-1] are explored: unwind one level.
                    pending.pop()
                    node = path.pop()
                    on_path.remove(node)
                    finished.add(node)
        return cycles

    def step_ids(self) -> list[str]:
        """Return step ids in plan order."""
        return [s.id for s in self.steps]

    def get_step(self, step_id: str) -> PlanStep | None:
        """Return the step whose id is ``step_id``, or ``None`` if absent."""
        for step in self.steps:
            if step.id == step_id:
                return step
        return None

    def get_dependencies(self, step_id: str) -> list[PlanStep]:
        """
        Return the steps that ``step_id`` directly depends on.

        Results are in plan order; an unknown ``step_id`` yields ``[]``.
        """
        step = self.get_step(step_id)
        if not step:
            return []
        return [s for s in self.steps if s.id in step.dependencies]

    def get_dependents(self, step_id: str) -> list[PlanStep]:
        """Return the steps that list ``step_id`` as a direct dependency, in plan order."""
        return [s for s in self.steps if step_id in s.dependencies]

    def topological_order(self) -> list[str]:
        """
        Return step IDs in topological order (dependencies first).

        Uses Kahn's algorithm with a FIFO queue, so steps that become ready at
        the same time keep their relative plan order. Any step that never
        becomes ready is appended at the end in plan order. That can only
        happen if the plan was mutated after validation to contain a cycle or
        a dangling dependency.
        """
        # In-degree counts every listed dependency, including repeats; the
        # dependents map below records repeats too, so the counts balance.
        in_degree = {s.id: len(s.dependencies) for s in self.steps}
        dependents: dict[str, list[str]] = {s.id: [] for s in self.steps}
        for step in self.steps:
            for dep in step.dependencies:
                if dep in dependents:
                    dependents[dep].append(step.id)

        # deque gives O(1) pops from the front (list.pop(0) is O(n)).
        queue = deque(sid for sid, deg in in_degree.items() if deg == 0)
        result: list[str] = []
        while queue:
            node = queue.popleft()
            result.append(node)
            for dependent in dependents.get(node, []):
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    queue.append(dependent)

        # Set lookup keeps this step linear instead of O(n^2).
        emitted = set(result)
        result.extend(sid for sid in in_degree if sid not in emitted)
        return result

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize for message payload / state.

        The returned containers are copies: steps are dumped via
        ``model_dump()``, each fallback path is a new list, and ``metadata``
        is a new dict. Mutating the payload therefore cannot silently change
        this validated plan. The metadata copy is shallow; nested metadata
        values are still shared.
        """
        return {
            "steps": [s.model_dump() for s in self.steps],
            "fallback_paths": [list(path) for path in self.fallback_paths],
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Plan":
        """
        Deserialize from dict. Steps may be dicts (validated) or PlanStep instances.

        Missing keys, or keys explicitly set to ``None``, fall back to empty
        defaults. The resulting plan is validated by ``validate_plan``.

        Raises:
            TypeError: if ``d`` is not a dict, or a step is neither a dict nor
                a ``PlanStep``.
        """
        if not isinstance(d, dict):
            raise TypeError(f"Plan.from_dict expects dict, got {type(d).__name__}")

        raw_steps = d.get("steps")
        if raw_steps is None:
            raw_steps = []
        steps: list[PlanStep] = []
        for s in raw_steps:
            if isinstance(s, PlanStep):
                steps.append(s)
            elif isinstance(s, dict):
                steps.append(PlanStep.model_validate(s))
            else:
                raise TypeError(f"Step must be dict or PlanStep, got {type(s).__name__}")

        fallback_paths = d.get("fallback_paths")
        metadata = d.get("metadata")
        return cls(
            steps=steps,
            fallback_paths=[] if fallback_paths is None else fallback_paths,
            metadata={} if metadata is None else metadata,
        )