169 lines
6.4 KiB
Python
169 lines
6.4 KiB
Python
"""Self-correction: on failure, run reflection and optionally prepare retry with feedback."""
|
|
|
|
from typing import Any, Protocol
|
|
|
|
from fusionagi.schemas.task import TaskState
|
|
from fusionagi.schemas.recommendation import Recommendation, RecommendationKind
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
class StateManagerLike(Protocol):
    """Protocol for state manager: get task state, trace, task.

    Structural (duck-typed) interface: any object exposing these three
    read-only lookups satisfies it; no inheritance is required.
    """

    # Current lifecycle state for task_id, or None (e.g. unknown task).
    def get_task_state(self, task_id: str) -> TaskState | None: ...

    # Recorded trace entries for the task, as a list of dicts; this is
    # what run_reflection_on_failure feeds the Critic as reflection input.
    def get_trace(self, task_id: str) -> list[dict[str, Any]]: ...

    # The task object itself; shape is project-defined (typed Any here).
    def get_task(self, task_id: str) -> Any: ...
|
|
|
|
|
|
class OrchestratorLike(Protocol):
    """Protocol for orchestrator: get plan, set state (for retry).

    Structural interface covering only the plan/state operations the
    self-correction loop needs when re-queuing a failed task.
    """

    # Current plan for the task as a dict, or None when no plan exists.
    def get_task_plan(self, task_id: str) -> dict[str, Any] | None: ...

    # Transition the task to `state`; force=True is used here for the
    # FAILED -> PENDING retry transition (presumably it bypasses normal
    # transition guards — confirm against the orchestrator implementation).
    def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None: ...

    # Replace the stored plan for the task.
    def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None: ...
|
|
|
|
|
|
class CriticLike(Protocol):
    """Protocol for critic: handle_message with evaluate_request -> evaluation_ready."""

    # Agent identity string; used as the recipient of evaluate_request messages.
    identity: str

    # Process a message envelope; for an "evaluate_request" the caller
    # expects a reply envelope whose message intent is "evaluation_ready"
    # (or None when the critic does not answer).
    def handle_message(self, envelope: Any) -> Any | None: ...
|
|
|
|
|
|
def run_reflection_on_failure(
    critic_agent: CriticLike,
    task_id: str,
    state_manager: StateManagerLike,
    orchestrator: OrchestratorLike,
) -> dict[str, Any] | None:
    """
    Run reflection (Critic evaluation) for a failed task.

    Builds an ``evaluate_request`` envelope carrying the task's trace and
    plan, hands it to the critic, and unwraps the reply.

    Returns:
        The evaluation dict from an ``evaluation_ready`` reply (possibly
        empty), or None when the critic gives no reply or replies with a
        different intent.
    """
    # Imported lazily to avoid a module-level import cycle with the schemas.
    from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

    request_payload = {
        "outcome": "failed",
        "trace": state_manager.get_trace(task_id),
        "plan": orchestrator.get_task_plan(task_id),
    }
    request = AgentMessageEnvelope(
        message=AgentMessage(
            sender="self_correction",
            recipient=critic_agent.identity,
            intent="evaluate_request",
            payload=request_payload,
        ),
        task_id=task_id,
    )

    reply = critic_agent.handle_message(request)
    if not reply or reply.message.intent != "evaluation_ready":
        return None
    return reply.message.payload.get("evaluation", {})
|
|
|
|
|
|
class SelfCorrectionLoop:
    """
    Drive the self-correction cycle for failed tasks.

    On failure the Critic is asked to reflect on the task's trace and
    plan; the resulting evaluation is used either to re-queue the task
    (FAILED -> PENDING with the correction context stored on the plan)
    or to emit structured recommendations.
    """

    def __init__(
        self,
        state_manager: StateManagerLike,
        orchestrator: OrchestratorLike,
        critic_agent: CriticLike,
        max_retries_per_task: int = 2,
    ) -> None:
        """
        Initialize the self-correction loop.

        Args:
            state_manager: State manager for task state and traces.
            orchestrator: Orchestrator for plan and state transitions.
            critic_agent: Critic agent for evaluate_request -> evaluation_ready.
            max_retries_per_task: Maximum retries to suggest per task (default 2).
        """
        self._state = state_manager
        self._orchestrator = orchestrator
        self._critic = critic_agent
        self._max_retries = max_retries_per_task
        # Per-task retry tally; incremented once per prepare_retry call.
        self._retry_counts: dict[str, int] = {}

    def suggest_retry(self, task_id: str) -> tuple[bool, dict[str, Any]]:
        """
        Decide whether a failed task should be retried.

        Runs Critic reflection and returns ``(should_retry, context)``.
        Yields ``(False, {})`` when the task is not FAILED, the retry
        budget is exhausted, or reflection produces no evaluation.
        """
        if self._state.get_task_state(task_id) != TaskState.FAILED:
            return False, {}

        attempts_so_far = self._retry_counts.get(task_id, 0)
        if attempts_so_far >= self._max_retries:
            logger.info(
                "Self-correction: max retries reached",
                extra={"task_id": task_id, "retries": attempts_so_far},
            )
            return False, {}

        evaluation = run_reflection_on_failure(
            self._critic, task_id, self._state, self._orchestrator,
        )
        if not evaluation:
            return False, {}

        fixes = evaluation.get("suggestions", [])
        analysis = evaluation.get("error_analysis", [])
        # Retry when the critic offered fixes, or scored the attempt poorly.
        low_score = evaluation.get("score", 0) < 0.5
        correction_context = {
            "evaluation": evaluation,
            "suggestions": fixes,
            "error_analysis": analysis,
            "retry_count": attempts_so_far + 1,
        }
        return bool(fixes) or low_score, correction_context

    def prepare_retry(self, task_id: str, correction_context: dict[str, Any] | None = None) -> None:
        """
        Transition a FAILED task back to PENDING with correction context.

        The context is stored under a reserved key on a copy of the plan.
        When no context is supplied, suggest_retry is consulted first and
        the retry is abandoned if it declines.
        """
        if self._state.get_task_state(task_id) != TaskState.FAILED:
            logger.warning("Self-correction: prepare_retry called for non-failed task", extra={"task_id": task_id})
            return

        if correction_context is None:
            approved, correction_context = self.suggest_retry(task_id)
            if not approved:
                return

        # Copy the plan so the stored version is the only one mutated.
        updated_plan = dict(self._orchestrator.get_task_plan(task_id) or {})
        updated_plan["_correction_context"] = correction_context
        self._orchestrator.set_task_plan(task_id, updated_plan)
        # force=True: FAILED -> PENDING is not a normal forward transition.
        self._orchestrator.set_task_state(task_id, TaskState.PENDING, force=True)

        self._retry_counts[task_id] = self._retry_counts.get(task_id, 0) + 1
        logger.info("Self-correction: prepared retry", extra={"task_id": task_id, "retry_count": self._retry_counts[task_id]})

    def correction_recommendations(self, task_id: str) -> list[Recommendation]:
        """For a failed task, run reflection and return structured recommendations."""
        evaluation = run_reflection_on_failure(
            self._critic, task_id, self._state, self._orchestrator,
        )
        if not evaluation:
            return []

        analysis = evaluation.get("error_analysis", [])
        # Cap at ten suggestions; each becomes a NEXT_ACTION recommendation.
        return [
            Recommendation(
                kind=RecommendationKind.NEXT_ACTION,
                title=f"Correction suggestion {index + 1}",
                description=item if isinstance(item, str) else str(item),
                payload={"raw": item, "error_analysis": analysis},
                source_task_id=task_id,
                priority=7,
            )
            for index, item in enumerate(evaluation.get("suggestions", [])[:10])
        ]
|