"""Scheduler: think vs act, tool selection, retry logic, fallback modes for AGI.""" from enum import Enum from typing import Any, Callable from fusionagi._logger import logger class SchedulerMode(str, Enum): """Whether to think (reason) or act (tool) next.""" THINK = "think" ACT = "act" class FallbackMode(str, Enum): """Fallback when primary path fails.""" RETRY = "retry" SIMPLIFY_PLAN = "simplify_plan" HUMAN_HANDOFF = "human_handoff" ABORT = "abort" class Scheduler: """ Decides think vs act, tool selection policy, retry/backoff, fallback. Callers (e.g. Supervisor) query next_action() and record outcomes. """ def __init__( self, default_mode: SchedulerMode = SchedulerMode.ACT, max_retries_per_step: int = 2, fallback_sequence: list[FallbackMode] | None = None, ) -> None: self._default_mode = default_mode self._max_retries = max_retries_per_step self._fallback_sequence = fallback_sequence or [ FallbackMode.RETRY, FallbackMode.SIMPLIFY_PLAN, FallbackMode.HUMAN_HANDOFF, FallbackMode.ABORT, ] self._retry_counts: dict[str, int] = {} # step_key -> count self._fallback_index: dict[str, int] = {} # task_id -> index into fallback_sequence def next_mode( self, task_id: str, step_id: str, context: dict[str, Any] | None = None, ) -> SchedulerMode: """ Return whether to think (reason more) or act (execute step). Override via context["force_think"] or context["force_act"]. """ if context: if context.get("force_think"): return SchedulerMode.THINK if context.get("force_act"): return SchedulerMode.ACT return self._default_mode def should_retry(self, task_id: str, step_id: str) -> bool: """Return True if step should be retried (under max_retries).""" key = f"{task_id}:{step_id}" count = self._retry_counts.get(key, 0) return count < self._max_retries def record_retry(self, task_id: str, step_id: str) -> None: """Increment retry count for step.""" key = f"{task_id}:{step_id}" self._retry_counts[key] = self._retry_counts.get(key, 0) + 1 logger.debug("Scheduler recorded retry", extra={"task_id": task_id, "step_id": step_id}) def next_fallback(self, task_id: str) -> FallbackMode | None: """Return next fallback mode for task, or None if exhausted.""" idx = self._fallback_index.get(task_id, 0) if idx >= len(self._fallback_sequence): return None mode = self._fallback_sequence[idx] self._fallback_index[task_id] = idx + 1 logger.info("Scheduler fallback", extra={"task_id": task_id, "fallback": mode.value}) return mode def reset_fallback(self, task_id: str) -> None: """Reset fallback index for task (e.g. after success).""" self._fallback_index.pop(task_id, None)