Files
FusionAGI/fusionagi/memory/episodic.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

227 lines
7.3 KiB
Python

"""Episodic memory: bounded, append-only log of task/step outcomes.

Entries are never modified once written, but the oldest ones are evicted once a
size limit is reached. Entries can be queried by task_id, event type, time range,
or an arbitrary filter function.

Episodic memory stores historical records of agent actions and outcomes:
- Task execution traces
- Step outcomes (success/failure)
- Tool invocation results
- Decision points and their outcomes
"""
import time
from typing import Any, Callable, Iterator
from fusionagi._logger import logger
from fusionagi._time import utc_now_iso
class EpisodicMemory:
"""
Append-only log of task and step outcomes.
Features:
- Time-stamped event logging
- Query by task ID
- Query by time range
- Query by event type
- Statistical summaries
- Memory size limits with optional archival
"""
def __init__(self, max_entries: int = 10000) -> None:
"""
Initialize episodic memory.
Args:
max_entries: Maximum entries before oldest are archived/removed.
"""
self._entries: list[dict[str, Any]] = []
self._by_task: dict[str, list[int]] = {} # task_id -> indices into _entries
self._by_type: dict[str, list[int]] = {} # event_type -> indices
self._max_entries = max_entries
self._archived_count = 0
def append(
self,
task_id: str,
event: dict[str, Any],
event_type: str | None = None,
) -> int:
"""
Append an episodic entry.
Args:
task_id: Task identifier this event belongs to.
event: Event data dictionary.
event_type: Optional event type for categorization (e.g., "step_done", "tool_call").
Returns:
Index of the appended entry.
"""
# Enforce size limits
if len(self._entries) >= self._max_entries:
self._archive_oldest(self._max_entries // 10)
# Add metadata
entry = {
**event,
"task_id": task_id,
"timestamp": event.get("timestamp", time.monotonic()),
"datetime": event.get("datetime", utc_now_iso()),
}
if event_type:
entry["event_type"] = event_type
idx = len(self._entries)
self._entries.append(entry)
# Index by task
self._by_task.setdefault(task_id, []).append(idx)
# Index by type if provided
etype = event_type or event.get("type") or event.get("event_type")
if etype:
self._by_type.setdefault(etype, []).append(idx)
return idx
def get_by_task(self, task_id: str, limit: int | None = None) -> list[dict[str, Any]]:
"""Return all entries for a task (copy), optionally limited."""
indices = self._by_task.get(task_id, [])
if limit:
indices = indices[-limit:]
return [self._entries[i].copy() for i in indices]
def get_by_type(self, event_type: str, limit: int | None = None) -> list[dict[str, Any]]:
"""Return entries of a specific type."""
indices = self._by_type.get(event_type, [])
if limit:
indices = indices[-limit:]
return [self._entries[i].copy() for i in indices]
def get_recent(self, limit: int = 100) -> list[dict[str, Any]]:
"""Return most recent entries (copy)."""
return [e.copy() for e in self._entries[-limit:]]
def get_by_time_range(
self,
start_timestamp: float | None = None,
end_timestamp: float | None = None,
limit: int | None = None,
) -> list[dict[str, Any]]:
"""
Return entries within a time range (using monotonic timestamps).
Args:
start_timestamp: Start of range (inclusive).
end_timestamp: End of range (inclusive).
limit: Maximum entries to return.
"""
results = []
for entry in self._entries:
ts = entry.get("timestamp", 0)
if start_timestamp and ts < start_timestamp:
continue
if end_timestamp and ts > end_timestamp:
continue
results.append(entry.copy())
if limit and len(results) >= limit:
break
return results
def query(
self,
filter_fn: Callable[[dict[str, Any]], bool],
limit: int | None = None,
) -> list[dict[str, Any]]:
"""
Query entries using a custom filter function.
Args:
filter_fn: Function that returns True for entries to include.
limit: Maximum entries to return.
"""
results = []
for entry in self._entries:
if filter_fn(entry):
results.append(entry.copy())
if limit and len(results) >= limit:
break
return results
def get_task_summary(self, task_id: str) -> dict[str, Any]:
"""
Get a summary of episodes for a task.
Returns statistics like count, first/last timestamps, event types.
"""
entries = self.get_by_task(task_id)
if not entries:
return {"task_id": task_id, "count": 0}
event_types: dict[str, int] = {}
success_count = 0
failure_count = 0
for entry in entries:
etype = entry.get("event_type") or entry.get("type") or "unknown"
event_types[etype] = event_types.get(etype, 0) + 1
if entry.get("success"):
success_count += 1
elif entry.get("error") or entry.get("success") is False:
failure_count += 1
return {
"task_id": task_id,
"count": len(entries),
"first_timestamp": entries[0].get("datetime"),
"last_timestamp": entries[-1].get("datetime"),
"event_types": event_types,
"success_count": success_count,
"failure_count": failure_count,
}
def get_statistics(self) -> dict[str, Any]:
"""Get overall memory statistics."""
return {
"total_entries": len(self._entries),
"archived_entries": self._archived_count,
"task_count": len(self._by_task),
"event_type_count": len(self._by_type),
"event_types": list(self._by_type.keys()),
}
def _archive_oldest(self, count: int) -> None:
"""Archive/remove oldest entries to enforce size limits."""
if count <= 0 or count >= len(self._entries):
return
logger.info(
"Archiving episodic memory entries",
extra={"count": count, "total": len(self._entries)},
)
# Remove oldest entries
self._entries = self._entries[count:]
self._archived_count += count
# Rebuild indices (entries shifted)
self._by_task = {}
self._by_type = {}
for idx, entry in enumerate(self._entries):
task_id = entry.get("task_id")
if task_id:
self._by_task.setdefault(task_id, []).append(idx)
etype = entry.get("event_type") or entry.get("type")
if etype:
self._by_type.setdefault(etype, []).append(idx)
def clear(self) -> None:
"""Clear all entries (for tests)."""
self._entries.clear()
self._by_task.clear()
self._by_type.clear()
self._archived_count = 0