"""SQLite-backed state backend for task persistence.

The backend itself is synchronous (plain ``sqlite3``); async callers should
run its methods in a thread pool for async compatibility.

For production use with Postgres, swap this backend for one built on asyncpg
or SQLAlchemy's async engine.
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import sqlite3
|
||
|
|
import threading
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from fusionagi._logger import logger
|
||
|
|
from fusionagi.core.persistence import StateBackend
|
||
|
|
from fusionagi.schemas.task import Task, TaskState
|
||
|
|
|
||
|
|
|
||
|
|
class SQLiteStateBackend(StateBackend):
    """SQLite-backed implementation of StateBackend.

    Stores tasks, task states, and traces in a local SQLite database.

    Every method opens a short-lived connection and closes it before
    returning, so no ``sqlite3.Connection`` is ever shared between threads.
    Write operations are additionally serialized with an in-process
    ``threading.Lock``; read operations take no lock. The lock only
    coordinates threads within this process, not other processes that open
    the same database file.
    """

    def __init__(self, db_path: str = "fusionagi_state.db") -> None:
        """Create the backend and ensure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self._db_path = db_path
        self._lock = threading.Lock()
        self._init_schema()

    def _get_conn(self) -> sqlite3.Connection:
        """Get a new connection (sqlite3 connections are not thread-safe).

        Rows are returned as ``sqlite3.Row`` so columns can be read by name.
        The caller is responsible for closing the connection.
        """
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _init_schema(self) -> None:
        """Create the ``tasks`` and ``traces`` tables and trace index if missing."""
        conn = self._get_conn()
        try:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    data TEXT NOT NULL,
                    state TEXT NOT NULL DEFAULT 'pending',
                    created_at TEXT,
                    updated_at TEXT
                );
                CREATE TABLE IF NOT EXISTS traces (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT NOT NULL,
                    entry TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (task_id) REFERENCES tasks(task_id)
                );
                CREATE INDEX IF NOT EXISTS idx_traces_task ON traces(task_id);
            """)
            conn.commit()
        finally:
            conn.close()
        logger.info("SQLiteStateBackend initialized", extra={"db_path": self._db_path})

    def get_task(self, task_id: str) -> Task | None:
        """Load a task by id.

        Returns:
            The task deserialized from its stored JSON blob, or ``None`` if no
            row exists for ``task_id``.
        """
        conn = self._get_conn()
        try:
            row = conn.execute("SELECT data FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
            if row is None:
                return None
            return Task.model_validate_json(row["data"])
        finally:
            conn.close()

    def set_task(self, task: Task) -> None:
        """Save or update a task.

        The full task is stored as JSON in ``data``; ``state`` and the
        timestamps are also copied into their own columns for querying.
        ``INSERT OR REPLACE`` rewrites the whole row on re-save.
        """
        with self._lock:
            conn = self._get_conn()
            try:
                data = task.model_dump_json()
                conn.execute(
                    "INSERT OR REPLACE INTO tasks (task_id, data, state, created_at, updated_at) "
                    "VALUES (?, ?, ?, ?, ?)",
                    (
                        task.task_id,
                        data,
                        task.state.value,
                        task.created_at.isoformat() if task.created_at else None,
                        task.updated_at.isoformat() if task.updated_at else None,
                    ),
                )
                conn.commit()
            finally:
                conn.close()

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return the current task state, or ``None`` if the task is unknown.

        Reads the dedicated ``state`` column rather than parsing the JSON blob.
        """
        conn = self._get_conn()
        try:
            row = conn.execute("SELECT state FROM tasks WHERE task_id = ?", (task_id,)).fetchone()
            if row is None:
                return None
            return TaskState(row["state"])
        finally:
            conn.close()

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update a task's state; does nothing (creates no task) if it is missing.

        The stored task is read and rewritten on the same connection while the
        write lock is held. The ``state`` column and the JSON ``data`` blob are
        updated in a single statement so they always agree.
        """
        with self._lock:
            conn = self._get_conn()
            try:
                row = conn.execute(
                    "SELECT data FROM tasks WHERE task_id = ?", (task_id,)
                ).fetchone()
                if row is None:
                    return
                updated = Task.model_validate_json(row["data"]).model_copy(
                    update={"state": state}
                )
                # NOTE(review): CURRENT_TIMESTAMP stores "YYYY-MM-DD HH:MM:SS",
                # which differs from the isoformat() strings set_task writes to
                # this column, and the blob's own updated_at is left unchanged.
                # Confirm whether both should share one timestamp/format.
                conn.execute(
                    "UPDATE tasks SET state = ?, data = ?, updated_at = CURRENT_TIMESTAMP "
                    "WHERE task_id = ?",
                    (state.value, updated.model_dump_json(), task_id),
                )
                conn.commit()
            finally:
                conn.close()

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append a JSON-serialized trace entry for ``task_id``.

        The task does not need to exist: foreign-key enforcement is not
        enabled on these connections.
        """
        with self._lock:
            conn = self._get_conn()
            try:
                conn.execute(
                    "INSERT INTO traces (task_id, entry) VALUES (?, ?)",
                    (task_id, json.dumps(entry)),
                )
                conn.commit()
            finally:
                conn.close()

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Load all trace entries for a task, in insertion order (by ``id``)."""
        conn = self._get_conn()
        try:
            rows = conn.execute(
                "SELECT entry FROM traces WHERE task_id = ? ORDER BY id",
                (task_id,),
            ).fetchall()
            return [json.loads(row["entry"]) for row in rows]
        finally:
            conn.close()

    def list_tasks(self, state: TaskState | None = None, limit: int = 100) -> list[Task]:
        """List up to ``limit`` tasks, optionally filtered by state.

        Results are ordered by ``rowid`` descending. Because ``set_task`` uses
        ``INSERT OR REPLACE``, a re-saved task gets a fresh rowid, so this is
        roughly most-recently-saved first rather than creation order.
        """
        conn = self._get_conn()
        try:
            if state is not None:
                rows = conn.execute(
                    "SELECT data FROM tasks WHERE state = ? ORDER BY rowid DESC LIMIT ?",
                    (state.value, limit),
                ).fetchall()
            else:
                rows = conn.execute(
                    "SELECT data FROM tasks ORDER BY rowid DESC LIMIT ?",
                    (limit,),
                ).fetchall()
            return [Task.model_validate_json(row["data"]) for row in rows]
        finally:
            conn.close()

    def delete_task(self, task_id: str) -> bool:
        """Delete a task and its traces.

        Both deletes are committed together. Traces are removed even if no
        task row exists.

        Returns:
            ``True`` if a task row was deleted, ``False`` otherwise.
        """
        with self._lock:
            conn = self._get_conn()
            try:
                conn.execute("DELETE FROM traces WHERE task_id = ?", (task_id,))
                cursor = conn.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,))
                conn.commit()
                return cursor.rowcount > 0
            finally:
                conn.close()

    def count_tasks(self) -> int:
        """Return the total number of stored tasks."""
        conn = self._get_conn()
        try:
            row = conn.execute("SELECT COUNT(*) as cnt FROM tasks").fetchone()
            return row["cnt"] if row else 0
        finally:
            conn.close()
|