112 lines
4.4 KiB
Python
112 lines
4.4 KiB
Python
"""In-memory store for task state and execution traces; replaceable with persistent backend."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import defaultdict
|
|
from typing import Any, TYPE_CHECKING
|
|
|
|
from fusionagi.schemas.task import Task, TaskState
|
|
from fusionagi._logger import logger
|
|
|
|
if TYPE_CHECKING:
|
|
from fusionagi.core.persistence import StateBackend
|
|
|
|
|
|
class StateManager:
    """
    Manages task state and execution traces.

    An in-memory cache is always maintained for fast access. When a
    persistence backend is injected, writes (tasks, task states, trace
    entries) are forwarded to it as well, and reads that miss the cache
    fall back to the backend and cache what they find.

    The backend's presence is tested with ``is not None`` rather than by
    truthiness, so a backend object that defines ``__len__``/``__bool__``
    (and is therefore falsy while empty) is still used.
    """

    def __init__(self, backend: StateBackend | None = None) -> None:
        """
        Initialize StateManager with optional persistence backend.

        Args:
            backend: Optional StateBackend for persistence. If None, uses in-memory only.
        """
        self._backend = backend
        self._tasks: dict[str, Task] = {}
        # defaultdict lets append_trace append to a trace whose list does not exist yet.
        self._traces: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def get_task(self, task_id: str) -> Task | None:
        """
        Return the task by id or None. Checks memory first, then backend.

        A task found only in the backend is cached in memory before it is returned.
        """
        if task_id in self._tasks:
            return self._tasks[task_id]
        if self._backend is None:
            return None
        task = self._backend.get_task(task_id)
        if task is not None:
            self._tasks[task_id] = task
        return task

    def set_task(self, task: Task) -> None:
        """Store or update a task in memory and, if configured, in the backend."""
        self._tasks[task.task_id] = task
        if self._backend is not None:
            self._backend.set_task(task)
        logger.debug("Task set", extra={"task_id": task.task_id})

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current task state or None if task unknown (memory first, then backend)."""
        task = self.get_task(task_id)
        return task.state if task is not None else None

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """
        Update the state of an existing task; creates no task if missing.

        The task is resolved like ``get_task`` (memory first, then backend,
        caching a backend hit). If it is unknown, the call does nothing.
        Otherwise the cached task's ``state`` is updated and, when a backend
        is configured, the new state is forwarded via ``set_task_state``.
        """
        task = self.get_task(task_id)
        if task is None:
            return
        task.state = state
        if self._backend is not None:
            self._backend.set_task_state(task_id, state)
        logger.debug("Task state set", extra={"task_id": task_id, "state": state.value})

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """
        Append an entry to the execution trace for a task.

        When a backend is configured and this task's trace is not cached yet,
        the persisted trace is loaded into the cache first. Otherwise the
        cache would hold only the entries appended through this manager, and
        ``get_trace`` (which prefers a non-empty cache) would hide the
        previously persisted entries.
        """
        if task_id not in self._traces and self._backend is not None:
            self._traces[task_id] = list(self._backend.get_trace(task_id) or [])
        self._traces[task_id].append(entry)
        if self._backend is not None:
            self._backend.append_trace(task_id, entry)
        tool = entry.get("tool") or entry.get("step") or "entry"
        logger.debug("Trace appended", extra={"task_id": task_id, "entry_key": tool})

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """
        Return a copy of the execution trace for a task.

        Memory is checked first. If nothing is cached for the task and a
        backend is configured, the backend's trace is fetched and, when
        non-empty, cached. A new list is always returned, so adding or
        removing items in the result affects neither the cache nor the
        backend's own list (entries themselves are not copied).
        """
        cached = self._traces.get(task_id)
        if cached:
            return list(cached)
        if self._backend is not None:
            trace = self._backend.get_trace(task_id)
            if trace:
                # Materialize once: the backend result is copied into the
                # cache, and the caller gets a separate copy of that.
                loaded = list(trace)
                self._traces[task_id] = loaded
                return list(loaded)
        return []

    def clear_task(self, task_id: str) -> None:
        """Remove task and its trace (for tests or cleanup). Does not clear backend."""
        self._tasks.pop(task_id, None)
        self._traces.pop(task_id, None)

    def list_tasks(self, state: TaskState | None = None) -> list[Task]:
        """Return all tasks, optionally filtered by state.

        When a persistence backend is configured, only tasks currently loaded
        in memory are returned; tasks that exist only in the backend are not included.
        """
        tasks = list(self._tasks.values())
        if state is None:
            return tasks
        return [t for t in tasks if t.state == state]

    def task_count(self) -> int:
        """Return total number of tasks in memory."""
        return len(self._tasks)
|