202 lines
7.2 KiB
Python
202 lines
7.2 KiB
Python
"""Plan schema: steps with ids, dependencies, optional fallback paths with validation."""
|
|
|
|
from collections import deque
from typing import Any

from pydantic import BaseModel, Field, field_validator, model_validator
|
|
|
|
|
|
class PlanStep(BaseModel):
    """
    One unit of work inside a plan.

    Validation:
    - id and description must be at least one character long (``min_length=1``)
    - id and description must not consist solely of whitespace
    """

    # Identity and human-readable intent of the step (both required).
    id: str = Field(..., min_length=1, description="Step identifier")
    description: str = Field(..., min_length=1, description="What to do")
    # Ids of other steps; defaults to an empty list.
    dependencies: list[str] = Field(default_factory=list, description="Step ids that must complete first")
    # Optional tool invocation details.
    tool_name: str | None = Field(default=None, description="Optional tool to invoke")
    tool_args: dict[str, Any] = Field(default_factory=dict, description="Optional tool arguments")
    # Free-form extra data attached to the step.
    metadata: dict[str, Any] = Field(default_factory=dict, description="Extra data")

    @field_validator("id", "description")
    @classmethod
    def validate_non_whitespace(cls, v: str) -> str:
        """
        Reject values made up only of whitespace.

        The value is returned exactly as given (it is not stripped).

        Raises:
            ValueError: if ``v`` is empty once surrounding whitespace is removed.
        """
        if v.strip():
            return v
        raise ValueError("Field cannot be empty or whitespace")
|
|
|
|
|
|
class Plan(BaseModel):
    """
    Plan graph: steps and optional fallback paths.

    Validation (run after field validation, in this order):
    - No duplicate step IDs
    - All dependency references must be valid step IDs
    - All fallback path references must be valid step IDs
    - No circular dependencies
    """

    steps: list[PlanStep] = Field(default_factory=list, description="Ordered steps")
    fallback_paths: list[list[str]] = Field(default_factory=list, description="Alternative step sequences")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Plan-level metadata")

    @model_validator(mode="after")
    def validate_plan(self) -> "Plan":
        """
        Validate the entire plan structure.

        Returns:
            The plan itself when every check passes.

        Raises:
            ValueError: on duplicate step ids, a dependency or fallback-path
                entry that names no step, or a circular dependency.
        """
        known_ids = {s.id for s in self.steps}

        # Check for duplicate step IDs: a set smaller than the list means at
        # least one id repeats; walk the list to name the offenders.
        if len(known_ids) != len(self.steps):
            seen: set[str] = set()
            duplicates: list[str] = []
            for s in self.steps:
                if s.id in seen:
                    duplicates.append(s.id)
                seen.add(s.id)
            raise ValueError(f"Duplicate step IDs: {duplicates}")

        # Check all dependency references are valid
        for step in self.steps:
            invalid_deps = [d for d in step.dependencies if d not in known_ids]
            if invalid_deps:
                raise ValueError(
                    f"Step '{step.id}' has invalid dependencies: {invalid_deps}"
                )

        # Check all fallback path references are valid
        for i, path in enumerate(self.fallback_paths):
            invalid_refs = [ref for ref in path if ref not in known_ids]
            if invalid_refs:
                raise ValueError(
                    f"Fallback path {i} has invalid step references: {invalid_refs}"
                )

        # Check for circular dependencies
        cycles = self._find_cycles()
        if cycles:
            raise ValueError(f"Circular dependencies detected: {cycles}")

        return self

    def _find_cycles(self) -> list[list[str]]:
        """
        Find circular dependencies in the plan graph using DFS.

        Edges run from a step to each of its dependencies. The search is
        iterative, so a long dependency chain cannot exhaust the interpreter's
        recursion limit, and the "currently on the DFS path" bookkeeping is
        always unwound, so only genuine cycles are reported.

        Returns:
            One entry per back edge found. Each entry lists the step ids along
            the cycle, starting and ending with the same id (a self-dependency
            on ``a`` yields ``["a", "a"]``). Empty when the graph is acyclic.
        """
        # Build adjacency list
        graph: dict[str, list[str]] = {s.id: list(s.dependencies) for s in self.steps}

        cycles: list[list[str]] = []
        finished: set[str] = set()  # nodes whose whole subtree has been explored

        for root in graph:
            if root in finished:
                continue

            path: list[str] = [root]  # current DFS path, root first
            on_path: set[str] = {root}  # same nodes as `path`, for O(1) membership
            # One neighbor iterator per node on `path`; each resumes where it
            # stopped once the child it descended into has finished.
            stack = [iter(graph[root])]

            while stack:
                for neighbor in stack[-1]:
                    if neighbor in on_path:
                        # Back edge: neighbor is an ancestor on the current path.
                        cycles.append(path[path.index(neighbor):] + [neighbor])
                    elif neighbor not in finished:
                        # Unvisited node: descend. Ids that name no step have
                        # no outgoing edges.
                        path.append(neighbor)
                        on_path.add(neighbor)
                        stack.append(iter(graph.get(neighbor, ())))
                        break
                else:
                    # All neighbors handled: retire the node and backtrack.
                    stack.pop()
                    done = path.pop()
                    on_path.discard(done)
                    finished.add(done)

        return cycles

    def step_ids(self) -> list[str]:
        """Return step ids in order."""
        return [s.id for s in self.steps]

    def get_step(self, step_id: str) -> PlanStep | None:
        """Get a step by ID, or ``None`` when no step has that id."""
        for step in self.steps:
            if step.id == step_id:
                return step
        return None

    def get_dependencies(self, step_id: str) -> list[PlanStep]:
        """
        Get all dependency steps for a given step.

        Returns:
            The steps named in the given step's ``dependencies``, in plan
            order. Empty when ``step_id`` matches no step.
        """
        step = self.get_step(step_id)
        if step is None:
            return []
        wanted = set(step.dependencies)  # O(1) membership per candidate step
        return [s for s in self.steps if s.id in wanted]

    def get_dependents(self, step_id: str) -> list[PlanStep]:
        """Get all steps that depend on the given step, in plan order."""
        return [s for s in self.steps if step_id in s.dependencies]

    def topological_order(self) -> list[str]:
        """
        Return step IDs in topological order (dependencies first).

        Uses Kahn's algorithm with a FIFO queue, so steps that become ready at
        the same time keep their relative plan order.
        """
        # Build in-degree map
        in_degree = {s.id: len(s.dependencies) for s in self.steps}
        # Build adjacency list (reverse direction for dependents)
        dependents: dict[str, list[str]] = {s.id: [] for s in self.steps}
        for step in self.steps:
            for dep in step.dependencies:
                if dep in dependents:
                    dependents[dep].append(step.id)

        # Start with nodes that have no dependencies
        queue = deque(sid for sid, deg in in_degree.items() if deg == 0)
        result: list[str] = []

        while queue:
            node = queue.popleft()
            result.append(node)
            for dependent in dependents[node]:
                in_degree[dependent] -= 1
                if in_degree[dependent] == 0:
                    queue.append(dependent)

        # Add any remaining nodes (would indicate cycles or dangling
        # dependencies; construction-time validation rejects both, so this is
        # only reachable if the plan was mutated afterwards).
        emitted = set(result)
        result.extend(sid for sid in in_degree if sid not in emitted)

        return result

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize for message payload / state.

        The returned containers are copies: each step is dumped to a new dict,
        and ``fallback_paths`` / ``metadata`` are shallow-copied so editing the
        payload's lists or top-level metadata keys does not alter this plan.
        """
        return {
            "steps": [s.model_dump() for s in self.steps],
            "fallback_paths": [list(path) for path in self.fallback_paths],
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Plan":
        """
        Deserialize from dict. Steps may be dicts (validated) or PlanStep instances.

        Missing ``steps``, ``fallback_paths`` and ``metadata`` keys default to
        empty containers.

        Raises:
            TypeError: if ``d`` is not a dict, or a step is neither a dict nor
                a ``PlanStep``.
        """
        if not isinstance(d, dict):
            raise TypeError(f"Plan.from_dict expects dict, got {type(d).__name__}")
        raw_steps = d.get("steps", [])
        steps: list[PlanStep] = []
        for s in raw_steps:
            if isinstance(s, PlanStep):
                steps.append(s)
            elif isinstance(s, dict):
                steps.append(PlanStep.model_validate(s))
            else:
                raise TypeError(f"Step must be dict or PlanStep, got {type(s).__name__}")
        return cls(
            steps=steps,
            fallback_paths=d.get("fallback_paths", []),
            metadata=d.get("metadata", {}),
        )
|