39 lines
1.4 KiB
Python
39 lines
1.4 KiB
Python
"""Rollouts: simulate plan before executing."""
|
|
|
|
from typing import Any, Callable, Protocol
|
|
|
|
from fusionagi.schemas.plan import Plan
|
|
from fusionagi.schemas.world_model import StateTransition
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
class WorldModelLike(Protocol):
    """Structural interface for a model that can simulate a single action.

    Any object exposing a compatible ``predict`` method satisfies this
    protocol; no inheritance from this class is required.
    """

    def predict(
        self,
        state: dict[str, Any],
        action: str,
        action_args: dict[str, Any],
    ) -> StateTransition:
        """Predict the transition produced by applying ``action`` to ``state``.

        Args:
            state: Current simulated state.
            action: Name of the action to apply.
            action_args: Arguments for the action.

        Returns:
            The predicted ``StateTransition``.
        """
        ...
|
|
|
|
|
|
def run_rollout(
    plan: Plan,
    initial_state: dict[str, Any],
    world_model: WorldModelLike,
    step_action_fn: Callable[[str, dict], str] | None = None,
    *,
    min_confidence: float = 0.3,
) -> tuple[bool, list[StateTransition], dict[str, Any]]:
    """
    Simulate ``plan`` step by step in ``world_model`` without executing it.

    Each step becomes an action name and a set of arguments:
    - The action is ``step.tool_name``, or ``"unknown"`` when that is unset.
    - The arguments are ``step.tool_args``, or ``{}``.

    These are passed to ``world_model.predict`` together with the current
    simulated state. The predicted ``to_state`` becomes the state for the
    next step.

    Args:
        plan: Plan whose ``steps`` are simulated in order.
        initial_state: Starting state. It is shallow-copied, so the caller's
            top-level dict is never rebound or mutated by this function.
        world_model: Object implementing
            ``predict(state, action, action_args) -> StateTransition``.
        step_action_fn: Optional override called as
            ``step_action_fn(step_id, step_dict)``. Its return value is used
            as the action name instead of ``step.tool_name``. Here
            ``step_dict`` is ``step.model_dump()``.
        min_confidence: Keyword-only. The rollout stops as soon as a
            predicted transition has ``confidence`` strictly below this
            value. Defaults to 0.3.

    Returns:
        ``(success, transitions, final_state)``.

        ``success`` is False when a transition fell below
        ``min_confidence``. In that case ``transitions`` ends with that
        low-confidence transition, and ``final_state`` is its predicted
        ``to_state``.

        An empty plan yields ``(True, [], <copy of initial_state>)``.
    """
    state = dict(initial_state)
    transitions: list[StateTransition] = []

    for step in plan.steps:
        action = step.tool_name or "unknown"
        # Shallow copy so a world model that mutates its arguments cannot
        # alter the plan being simulated.
        action_args = dict(step.tool_args or {})
        if step_action_fn is not None:
            action = step_action_fn(step.id, step.model_dump())

        trans = world_model.predict(state, action, action_args)
        transitions.append(trans)
        # Copy so later steps do not share (and mutate) the transition's dict.
        state = dict(trans.to_state)

        if trans.confidence < min_confidence:
            logger.warning(
                "Rollout low confidence",
                extra={
                    "step_id": step.id,
                    "confidence": trans.confidence,
                    "min_confidence": min_confidence,
                },
            )
            return False, transitions, state

    logger.info("Rollout completed", extra={"steps": len(transitions)})
    return True, transitions, state
|