"""Episodic memory: append-only log of task/step outcomes; query by task_id or time range.

Episodic memory stores historical records of agent actions and outcomes:

- Task execution traces
- Step outcomes (success/failure)
- Tool invocation results
- Decision points and their outcomes
"""

import time
from typing import Any, Callable, Iterator

from fusionagi._logger import logger
from fusionagi._time import utc_now_iso

class EpisodicMemory:
|
|
"""
|
|
Append-only log of task and step outcomes.
|
|
|
|
Features:
|
|
- Time-stamped event logging
|
|
- Query by task ID
|
|
- Query by time range
|
|
- Query by event type
|
|
- Statistical summaries
|
|
- Memory size limits with optional archival
|
|
"""
|
|
|
|
def __init__(self, max_entries: int = 10000) -> None:
|
|
"""
|
|
Initialize episodic memory.
|
|
|
|
Args:
|
|
max_entries: Maximum entries before oldest are archived/removed.
|
|
"""
|
|
self._entries: list[dict[str, Any]] = []
|
|
self._by_task: dict[str, list[int]] = {} # task_id -> indices into _entries
|
|
self._by_type: dict[str, list[int]] = {} # event_type -> indices
|
|
self._max_entries = max_entries
|
|
self._archived_count = 0
|
|
|
|
def append(
|
|
self,
|
|
task_id: str,
|
|
event: dict[str, Any],
|
|
event_type: str | None = None,
|
|
) -> int:
|
|
"""
|
|
Append an episodic entry.
|
|
|
|
Args:
|
|
task_id: Task identifier this event belongs to.
|
|
event: Event data dictionary.
|
|
event_type: Optional event type for categorization (e.g., "step_done", "tool_call").
|
|
|
|
Returns:
|
|
Index of the appended entry.
|
|
"""
|
|
# Enforce size limits
|
|
if len(self._entries) >= self._max_entries:
|
|
self._archive_oldest(self._max_entries // 10)
|
|
|
|
# Add metadata
|
|
entry = {
|
|
**event,
|
|
"task_id": task_id,
|
|
"timestamp": event.get("timestamp", time.monotonic()),
|
|
"datetime": event.get("datetime", utc_now_iso()),
|
|
}
|
|
|
|
if event_type:
|
|
entry["event_type"] = event_type
|
|
|
|
idx = len(self._entries)
|
|
self._entries.append(entry)
|
|
|
|
# Index by task
|
|
self._by_task.setdefault(task_id, []).append(idx)
|
|
|
|
# Index by type if provided
|
|
etype = event_type or event.get("type") or event.get("event_type")
|
|
if etype:
|
|
self._by_type.setdefault(etype, []).append(idx)
|
|
|
|
return idx
|
|
|
|
def get_by_task(self, task_id: str, limit: int | None = None) -> list[dict[str, Any]]:
|
|
"""Return all entries for a task (copy), optionally limited."""
|
|
indices = self._by_task.get(task_id, [])
|
|
if limit:
|
|
indices = indices[-limit:]
|
|
return [self._entries[i].copy() for i in indices]
|
|
|
|
def get_by_type(self, event_type: str, limit: int | None = None) -> list[dict[str, Any]]:
|
|
"""Return entries of a specific type."""
|
|
indices = self._by_type.get(event_type, [])
|
|
if limit:
|
|
indices = indices[-limit:]
|
|
return [self._entries[i].copy() for i in indices]
|
|
|
|
def get_recent(self, limit: int = 100) -> list[dict[str, Any]]:
|
|
"""Return most recent entries (copy)."""
|
|
return [e.copy() for e in self._entries[-limit:]]
|
|
|
|
def get_by_time_range(
|
|
self,
|
|
start_timestamp: float | None = None,
|
|
end_timestamp: float | None = None,
|
|
limit: int | None = None,
|
|
) -> list[dict[str, Any]]:
|
|
"""
|
|
Return entries within a time range (using monotonic timestamps).
|
|
|
|
Args:
|
|
start_timestamp: Start of range (inclusive).
|
|
end_timestamp: End of range (inclusive).
|
|
limit: Maximum entries to return.
|
|
"""
|
|
results = []
|
|
for entry in self._entries:
|
|
ts = entry.get("timestamp", 0)
|
|
if start_timestamp and ts < start_timestamp:
|
|
continue
|
|
if end_timestamp and ts > end_timestamp:
|
|
continue
|
|
results.append(entry.copy())
|
|
if limit and len(results) >= limit:
|
|
break
|
|
return results
|
|
|
|
def query(
|
|
self,
|
|
filter_fn: Callable[[dict[str, Any]], bool],
|
|
limit: int | None = None,
|
|
) -> list[dict[str, Any]]:
|
|
"""
|
|
Query entries using a custom filter function.
|
|
|
|
Args:
|
|
filter_fn: Function that returns True for entries to include.
|
|
limit: Maximum entries to return.
|
|
"""
|
|
results = []
|
|
for entry in self._entries:
|
|
if filter_fn(entry):
|
|
results.append(entry.copy())
|
|
if limit and len(results) >= limit:
|
|
break
|
|
return results
|
|
|
|
def get_task_summary(self, task_id: str) -> dict[str, Any]:
|
|
"""
|
|
Get a summary of episodes for a task.
|
|
|
|
Returns statistics like count, first/last timestamps, event types.
|
|
"""
|
|
entries = self.get_by_task(task_id)
|
|
if not entries:
|
|
return {"task_id": task_id, "count": 0}
|
|
|
|
event_types: dict[str, int] = {}
|
|
success_count = 0
|
|
failure_count = 0
|
|
|
|
for entry in entries:
|
|
etype = entry.get("event_type") or entry.get("type") or "unknown"
|
|
event_types[etype] = event_types.get(etype, 0) + 1
|
|
|
|
if entry.get("success"):
|
|
success_count += 1
|
|
elif entry.get("error") or entry.get("success") is False:
|
|
failure_count += 1
|
|
|
|
return {
|
|
"task_id": task_id,
|
|
"count": len(entries),
|
|
"first_timestamp": entries[0].get("datetime"),
|
|
"last_timestamp": entries[-1].get("datetime"),
|
|
"event_types": event_types,
|
|
"success_count": success_count,
|
|
"failure_count": failure_count,
|
|
}
|
|
|
|
def get_statistics(self) -> dict[str, Any]:
|
|
"""Get overall memory statistics."""
|
|
return {
|
|
"total_entries": len(self._entries),
|
|
"archived_entries": self._archived_count,
|
|
"task_count": len(self._by_task),
|
|
"event_type_count": len(self._by_type),
|
|
"event_types": list(self._by_type.keys()),
|
|
}
|
|
|
|
def _archive_oldest(self, count: int) -> None:
|
|
"""Archive/remove oldest entries to enforce size limits."""
|
|
if count <= 0 or count >= len(self._entries):
|
|
return
|
|
|
|
logger.info(
|
|
"Archiving episodic memory entries",
|
|
extra={"count": count, "total": len(self._entries)},
|
|
)
|
|
|
|
# Remove oldest entries
|
|
self._entries = self._entries[count:]
|
|
self._archived_count += count
|
|
|
|
# Rebuild indices (entries shifted)
|
|
self._by_task = {}
|
|
self._by_type = {}
|
|
for idx, entry in enumerate(self._entries):
|
|
task_id = entry.get("task_id")
|
|
if task_id:
|
|
self._by_task.setdefault(task_id, []).append(idx)
|
|
|
|
etype = entry.get("event_type") or entry.get("type")
|
|
if etype:
|
|
self._by_type.setdefault(etype, []).append(idx)
|
|
|
|
def clear(self) -> None:
|
|
"""Clear all entries (for tests)."""
|
|
self._entries.clear()
|
|
self._by_task.clear()
|
|
self._by_type.clear()
|
|
self._archived_count = 0
|