Files
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

169 lines
6.4 KiB
Python

"""Self-correction: on failure, run reflection and optionally prepare retry with feedback."""
from typing import Any, Protocol
from fusionagi.schemas.task import TaskState
from fusionagi.schemas.recommendation import Recommendation, RecommendationKind
from fusionagi._logger import logger
class StateManagerLike(Protocol):
    """Structural interface for the state manager used by self-correction.

    Any object that provides these three methods satisfies the protocol;
    explicit inheritance is not required.
    """

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return the state recorded for ``task_id``, or None.

        None presumably means the task is unknown -- confirm against the
        concrete state manager.
        """
        ...

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Return the trace for ``task_id`` as a list of dict entries."""
        ...

    def get_task(self, task_id: str) -> Any:
        """Return the task object for ``task_id`` (its type is left open here)."""
        ...
class OrchestratorLike(Protocol):
    """Structural interface for the orchestrator used by self-correction.

    Covers reading and replacing a task's plan and changing a task's state,
    which is what preparing a retry needs.
    """

    def get_task_plan(self, task_id: str) -> dict[str, Any] | None:
        """Return the plan dict stored for ``task_id``, or None if there is none."""
        ...

    def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None:
        """Set the state of ``task_id`` to ``state``.

        The meaning of ``force`` is defined by the implementation; it
        presumably bypasses transition validation -- confirm against the
        concrete orchestrator.
        """
        ...

    def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None:
        """Store ``plan`` as the plan for ``task_id``."""
        ...
class CriticLike(Protocol):
"""Protocol for critic: handle_message with evaluate_request -> evaluation_ready."""
identity: str
def handle_message(self, envelope: Any) -> Any | None: ...
def run_reflection_on_failure(
    critic_agent: CriticLike,
    task_id: str,
    state_manager: StateManagerLike,
    orchestrator: OrchestratorLike,
) -> dict[str, Any] | None:
    """
    Ask the Critic to evaluate a failed task and return its evaluation.

    Builds an ``evaluate_request`` message carrying the task's trace and plan
    with the outcome fixed to ``"failed"``, and hands it to the critic.

    Args:
        critic_agent: Critic that receives the request; its ``identity`` is
            used as the message recipient.
        task_id: Identifier of the task to reflect on.
        state_manager: Source of the task's trace.
        orchestrator: Source of the task's plan.

    Returns:
        The ``"evaluation"`` entry of the reply payload (an empty dict when the
        reply has no such entry), or None when the critic gives no reply or
        replies with an intent other than ``evaluation_ready``.
    """
    # Function-scope import kept from the original; presumably avoids an
    # import cycle with the messages module -- confirm before hoisting.
    from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

    # Gather the evidence first, in the same order as before: trace, then plan.
    task_trace = state_manager.get_trace(task_id)
    task_plan = orchestrator.get_task_plan(task_id)

    request = AgentMessage(
        sender="self_correction",
        recipient=critic_agent.identity,
        intent="evaluate_request",
        payload={
            "outcome": "failed",
            "trace": task_trace,
            "plan": task_plan,
        },
    )
    reply = critic_agent.handle_message(
        AgentMessageEnvelope(message=request, task_id=task_id)
    )

    # Only a truthy reply with the expected intent carries an evaluation.
    if reply and reply.message.intent == "evaluation_ready":
        return reply.message.payload.get("evaluation", {})
    return None
class SelfCorrectionLoop:
    """
    Self-correction: on failed tasks, run Critic reflection and optionally
    prepare retry by transitioning FAILED -> PENDING and storing correction context.

    Retry attempts are counted per task in memory on this instance. The
    ``max_retries_per_task`` budget is checked by ``suggest_retry``;
    ``prepare_retry`` does not re-check it when given an explicit context.
    """

    # Upper bound on the number of recommendations produced per reflection.
    _MAX_RECOMMENDATIONS = 10

    def __init__(
        self,
        state_manager: StateManagerLike,
        orchestrator: OrchestratorLike,
        critic_agent: CriticLike,
        max_retries_per_task: int = 2,
    ) -> None:
        """
        Initialize the self-correction loop.

        Args:
            state_manager: State manager for task state and traces.
            orchestrator: Orchestrator for plan and state transitions.
            critic_agent: Critic agent for evaluate_request -> evaluation_ready.
            max_retries_per_task: Maximum retries to suggest per task (default 2).
        """
        self._state = state_manager
        self._orchestrator = orchestrator
        self._critic = critic_agent
        self._max_retries = max_retries_per_task
        # task_id -> number of retries already prepared by this instance.
        self._retry_counts: dict[str, int] = {}

    @staticmethod
    def _as_list(value: Any) -> list[Any]:
        """Coerce a critic-supplied ``suggestions`` value into a list.

        Critic output is free-form: None becomes an empty list, a single
        string becomes a one-item list (blank strings are dropped) instead of
        being iterated character by character, lists/tuples are copied, and
        any other single value is wrapped.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value] if value.strip() else []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _score_of(evaluation: dict[str, Any]) -> float:
        """Return the evaluation's ``score`` as a float.

        A missing, None, or non-numeric score is treated as 0.0, matching the
        default the comparison in ``suggest_retry`` has always used for a
        missing score.
        """
        raw = evaluation.get("score", 0)
        if raw is None:
            return 0.0
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0.0

    def suggest_retry(self, task_id: str) -> tuple[bool, dict[str, Any]]:
        """
        For a failed task, run reflection and decide whether to suggest retry.

        A retry is suggested when the evaluation contains at least one
        suggestion or its score is below 0.5.

        Args:
            task_id: Identifier of the task to inspect.

        Returns:
            ``(should_retry, correction_context)``. The context is empty when
            the task is not FAILED, the retry budget is used up, or reflection
            produced no evaluation; otherwise it holds the evaluation, the
            normalized suggestions, the error analysis and the retry number
            this attempt would be.
        """
        if self._state.get_task_state(task_id) != TaskState.FAILED:
            return False, {}

        retries = self._retry_counts.get(task_id, 0)
        if retries >= self._max_retries:
            logger.info(
                "Self-correction: max retries reached",
                extra={"task_id": task_id, "retries": retries},
            )
            return False, {}

        evaluation = run_reflection_on_failure(
            self._critic, task_id, self._state, self._orchestrator,
        )
        if not evaluation:
            return False, {}

        suggestions = self._as_list(evaluation.get("suggestions"))
        error_analysis = evaluation.get("error_analysis", [])
        # The score is only consulted when there are no suggestions.
        should_retry = bool(suggestions) or self._score_of(evaluation) < 0.5
        context = {
            "evaluation": evaluation,
            "suggestions": suggestions,
            "error_analysis": error_analysis,
            "retry_count": retries + 1,
        }
        return should_retry, context

    def prepare_retry(self, task_id: str, correction_context: dict[str, Any] | None = None) -> None:
        """
        Transition task from FAILED to PENDING and store correction context in plan.

        If correction_context is None, runs suggest_retry to obtain it and
        does nothing when no retry is suggested. An explicitly supplied
        context is stored as given, without consulting the retry budget.

        Args:
            task_id: Identifier of the failed task.
            correction_context: Context to store under the plan's
                ``_correction_context`` key, or None to derive it.
        """
        if self._state.get_task_state(task_id) != TaskState.FAILED:
            logger.warning(
                "Self-correction: prepare_retry called for non-failed task",
                extra={"task_id": task_id},
            )
            return

        if correction_context is None:
            ok, correction_context = self.suggest_retry(task_id)
            if not ok:
                return

        # Work on a shallow copy so the dict returned by the orchestrator is
        # not mutated in place.
        plan = dict(self._orchestrator.get_task_plan(task_id) or {})
        plan["_correction_context"] = correction_context
        self._orchestrator.set_task_plan(task_id, plan)
        self._orchestrator.set_task_state(task_id, TaskState.PENDING, force=True)

        retry_count = self._retry_counts.get(task_id, 0) + 1
        self._retry_counts[task_id] = retry_count
        logger.info(
            "Self-correction: prepared retry",
            extra={"task_id": task_id, "retry_count": retry_count},
        )

    def correction_recommendations(self, task_id: str) -> list[Recommendation]:
        """
        For a failed task, run reflection and return structured recommendations.

        Args:
            task_id: Identifier of the task to reflect on. The task state is
                not checked here; reflection always reports the outcome as
                failed.

        Returns:
            One NEXT_ACTION recommendation per suggestion, capped at the first
            ``_MAX_RECOMMENDATIONS`` suggestions; an empty list when reflection
            yields no evaluation.
        """
        evaluation = run_reflection_on_failure(
            self._critic, task_id, self._state, self._orchestrator,
        )
        if not evaluation:
            return []

        suggestions = self._as_list(evaluation.get("suggestions"))
        error_analysis = evaluation.get("error_analysis", [])
        return [
            Recommendation(
                kind=RecommendationKind.NEXT_ACTION,
                title=f"Correction suggestion {number}",
                description=item if isinstance(item, str) else str(item),
                payload={"raw": item, "error_analysis": error_analysis},
                source_task_id=task_id,
                priority=7,
            )
            for number, item in enumerate(
                suggestions[: self._MAX_RECOMMENDATIONS], start=1
            )
        ]