135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
"""Versioned thought states: snapshots, rollback, branching."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from fusionagi.memory.scratchpad import ThoughtState
|
|
from fusionagi.reasoning.tot import ThoughtNode
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
def _new_version_id() -> str:
    """Return a fresh snapshot identifier: ``v_`` followed by 12 random hex chars."""
    return "v_" + uuid.uuid4().hex[:12]


@dataclass
class ThoughtStateSnapshot:
    """Point-in-time capture of a reasoning session (thought tree + scratchpad).

    Each instance gets its own identifier, creation time and metadata dict.
    """

    # Unique id of the form "v_<12 hex chars>", generated per instance.
    version_id: str = field(default_factory=_new_version_id)
    # Dict form of the thought tree (None / empty when no tree was captured).
    tree_state: dict[str, Any] | None = None
    # Scratchpad state captured alongside the tree.
    scratchpad_state: ThoughtState | None = None
    # Creation time from time.monotonic(): usable for ordering within one
    # process, not a wall-clock timestamp.
    timestamp: float = field(default_factory=time.monotonic)
    # Free-form caller annotations; a new dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
def _serialize_tree(node: ThoughtNode | None) -> dict[str, Any]:
    """Serialize a ThoughtNode subtree into a plain nested dict.

    The mutable attributes ``trace``, ``unit_refs`` and ``metadata`` are
    shallow-copied, so the returned dict does not share containers with the
    live node. Without this, edits made to the working tree after a snapshot
    is taken would silently change the saved snapshot.

    Args:
        node: Root of the subtree to serialize, or None.

    Returns:
        A dict holding the node's fields plus a ``children`` list of
        recursively serialized child dicts; ``{}`` when ``node`` is None.
    """
    if node is None:
        return {}
    return {
        "node_id": node.node_id,
        "parent_id": node.parent_id,
        "thought": node.thought,
        # Copy containers so the serialized form is decoupled from the live tree.
        "trace": list(node.trace),
        "score": node.score,
        "depth": node.depth,
        "unit_refs": list(node.unit_refs),
        "metadata": dict(node.metadata),
        "children": [_serialize_tree(child) for child in node.children],
    }
|
|
|
|
|
|
def _deserialize_tree(data: dict[str, Any]) -> ThoughtNode | None:
    """Rebuild a ThoughtNode subtree from a dict produced by ``_serialize_tree``.

    The list/dict fields (``trace``, ``unit_refs``, ``metadata``) are copied,
    so the returned nodes never share mutable containers with ``data``. For
    example, editing a tree loaded from a stored snapshot must not rewrite
    that snapshot.

    Args:
        data: Serialized node dict; missing keys fall back to defaults.

    Returns:
        The reconstructed root node, or None when ``data`` is empty.
    """
    if not data:
        return None
    node = ThoughtNode(
        node_id=data.get("node_id", ""),
        parent_id=data.get("parent_id"),
        thought=data.get("thought", ""),
        # Copied like unit_refs/metadata so the node does not alias ``data``.
        trace=list(data.get("trace", [])),
        score=float(data.get("score", 0)),
        depth=int(data.get("depth", 0)),
        unit_refs=list(data.get("unit_refs", [])),
        metadata=dict(data.get("metadata", {})),
    )
    for child_data in data.get("children", []):
        child = _deserialize_tree(child_data)
        # Only empty child dicts yield None; compare explicitly so a real node
        # is never dropped because of how it evaluates in a boolean context.
        if child is not None:
            node.children.append(child)
    return node
|
|
|
|
|
|
class ThoughtVersioning:
    """In-memory store of thought-state snapshots with rollback and branching.

    Snapshots live in a dict keyed by version id. When more than
    ``max_snapshots`` are held, the one with the smallest timestamp is
    evicted. Scratchpads and metadata are copied when a snapshot is saved and
    again when it is loaded. As a result, the caller's working objects and the
    objects handed back by a load never share those containers with the
    stored snapshot.
    """

    def __init__(self, max_snapshots: int = 50) -> None:
        """Create an empty snapshot store.

        Args:
            max_snapshots: Number of snapshots retained before the oldest
                (smallest timestamp) is evicted on save.
        """
        self._snapshots: dict[str, ThoughtStateSnapshot] = {}
        self._max_snapshots = max_snapshots

    @staticmethod
    def _copy_scratchpad(scratchpad: ThoughtState | None) -> ThoughtState | None:
        """Return an independent copy of ``scratchpad``, or None if it is None.

        The list/dict fields are shallow-copied, so either side can be
        mutated without affecting the other.
        """
        if scratchpad is None:
            return None
        return ThoughtState(
            hypotheses=list(scratchpad.hypotheses),
            partial_conclusions=list(scratchpad.partial_conclusions),
            discarded_paths=list(scratchpad.discarded_paths),
            metadata=dict(scratchpad.metadata),
        )

    def save_snapshot(
        self,
        tree: ThoughtNode | None,
        scratchpad: ThoughtState | None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Capture ``tree`` and ``scratchpad`` as a new snapshot.

        Args:
            tree: Root of the thought tree to capture, or None.
            scratchpad: Scratchpad to capture, or None. A copy is stored, so
                later edits to the caller's scratchpad do not alter the snapshot.
            metadata: Optional annotations; copied into the snapshot.

        Returns:
            The version id of the newly stored snapshot.
        """
        snapshot = ThoughtStateSnapshot(
            tree_state=_serialize_tree(tree) if tree is not None else {},
            scratchpad_state=self._copy_scratchpad(scratchpad),
            metadata=dict(metadata) if metadata else {},
        )
        self._snapshots[snapshot.version_id] = snapshot
        if len(self._snapshots) > self._max_snapshots:
            # Evict the snapshot with the earliest timestamp.
            oldest = min(self._snapshots, key=lambda vid: self._snapshots[vid].timestamp)
            del self._snapshots[oldest]
        logger.debug("Thought snapshot saved", extra={"version_id": snapshot.version_id})
        return snapshot.version_id

    def load_snapshot(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Load a snapshot as ``(tree, scratchpad)``.

        The tree is rebuilt from its serialized form on every call, and the
        scratchpad is returned as a copy, so editing either result leaves the
        stored snapshot's scratchpad untouched.

        Returns:
            ``(None, None)`` when ``version_id`` is unknown (or was evicted).
        """
        snap = self._snapshots.get(version_id)
        if snap is None:
            return None, None
        tree = _deserialize_tree(snap.tree_state) if snap.tree_state else None
        return tree, self._copy_scratchpad(snap.scratchpad_state)

    def list_snapshots(self) -> list[dict[str, Any]]:
        """List stored snapshots as dicts with version_id, timestamp and metadata.

        Metadata is returned as a copy so callers cannot mutate stored snapshots.
        """
        return [
            {
                "version_id": v.version_id,
                "timestamp": v.timestamp,
                "metadata": dict(v.metadata),
            }
            for v in self._snapshots.values()
        ]

    def rollback_to(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Return the state stored under ``version_id`` (alias for load_snapshot)."""
        return self.load_snapshot(version_id)

    def branch_from(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Start a new branch from a snapshot.

        ``load_snapshot`` already returns a newly deserialized tree and a
        copied scratchpad, so its result can be edited as the new branch.
        """
        return self.load_snapshot(version_id)
|