151 lines
5.5 KiB
Python
151 lines
5.5 KiB
Python
"""Working memory: in-memory key-value and list storage, scoped per task/session.
|
|
|
|
Working memory provides short-term storage for active tasks:
|
|
- Key-value storage per session/task
|
|
- List append operations for accumulating results
|
|
- Context retrieval for reasoning
|
|
- Session lifecycle management
|
|
"""
|
|
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from typing import Any, Iterator
|
|
|
|
from fusionagi._logger import logger
|
|
from fusionagi._time import utc_now
|
|
|
|
|
|
class WorkingMemory:
    """
    Short-term working memory per task/session.

    Features:
    - Key-value get/set operations
    - List append with automatic coercion (an existing non-list value is
      wrapped into a one-element list before appending)
    - Context summary for LLM prompts
    - Session management and cleanup
    - Soft per-session size limit: exceeding it logs a warning; no data is
      evicted

    Only ``set`` and ``append`` create session entries. Read-only accessors
    never insert an empty session into the underlying defaultdict.
    """

    def __init__(self, max_entries_per_session: int = 1000) -> None:
        """
        Initialize working memory.

        Args:
            max_entries_per_session: Soft cap on stored items per session. A
                list or dict value counts as its length; any other value
                counts as one item. When a ``set``/``append`` pushes a session
                past this cap, a warning is logged. Nothing is removed.
        """
        self._store: dict[str, dict[str, Any]] = defaultdict(dict)
        # Time of the most recent set/append per session (from utc_now()).
        self._timestamps: dict[str, datetime] = {}
        self._max_entries = max_entries_per_session

    def get(self, session_id: str, key: str, default: Any = None) -> Any:
        """Return the value for ``session_id``/``key``, or ``default`` if absent.

        Uses a non-inserting lookup so that querying an unknown session does
        not leave an empty entry behind in the defaultdict.
        """
        return self._store.get(session_id, {}).get(key, default)

    def set(self, session_id: str, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` for the session (creating it if needed)."""
        self._store[session_id][key] = value
        self._timestamps[session_id] = utc_now()
        self._enforce_limits(session_id)

    def append(self, session_id: str, key: str, value: Any) -> None:
        """Append ``value`` to the list stored under ``key``.

        Creates an empty list if the key is missing. If the key currently
        holds a non-list value, that value becomes the first element of a new
        list.
        """
        session = self._store[session_id]
        items = session.setdefault(key, [])
        if not isinstance(items, list):
            # Coerce an existing scalar into a list so appends accumulate.
            items = [items]
            session[key] = items
        items.append(value)
        self._timestamps[session_id] = utc_now()
        self._enforce_limits(session_id)

    def get_list(self, session_id: str, key: str) -> list[Any]:
        """Return a copy of the list for ``session_id``/``key``.

        A non-list value is returned wrapped in a one-element list. A missing
        key (or a stored ``None``) yields ``[]``. Does not create the session.
        """
        val = self._store.get(session_id, {}).get(key)
        if isinstance(val, list):
            return list(val)
        return [val] if val is not None else []

    def has(self, session_id: str, key: str) -> bool:
        """Check if a key exists in session."""
        return key in self._store.get(session_id, {})

    def keys(self, session_id: str) -> list[str]:
        """Return all keys for a session (empty list for unknown sessions)."""
        return list(self._store.get(session_id, {}).keys())

    def delete(self, session_id: str, key: str) -> bool:
        """Delete a key from session. Returns True if it existed."""
        if session_id in self._store and key in self._store[session_id]:
            del self._store[session_id][key]
            return True
        return False

    def clear_session(self, session_id: str) -> None:
        """Remove all data and the last-write timestamp for a session."""
        self._store.pop(session_id, None)
        self._timestamps.pop(session_id, None)

    def get_context_summary(self, session_id: str, max_items: int = 10) -> dict[str, Any]:
        """
        Get a summary of working memory for context injection.

        Useful for including relevant context in LLM prompts. Only the first
        ``max_items`` keys (in insertion order) are summarized:

        - lists -> ``{"type": "list", "count": n, "recent": <last 3 items>}``
        - dicts -> ``{"type": "dict", "keys": <first 10 keys>}``
        - strings longer than 200 chars -> first 200 chars + "..."
        - anything else -> the value itself

        The ``"recent"`` list is always a fresh copy, so mutating the summary
        cannot alter the stored data.
        """
        session_data = self._store.get(session_id, {})
        summary: dict[str, Any] = {}

        for key, value in list(session_data.items())[:max_items]:
            if isinstance(value, list):
                summary[key] = {
                    "type": "list",
                    "count": len(value),
                    # Slicing always copies (and handles len <= 3), so the
                    # caller never receives the internal list object.
                    "recent": value[-3:],
                }
            elif isinstance(value, dict):
                summary[key] = {
                    "type": "dict",
                    "keys": list(value.keys())[:10],
                }
            elif isinstance(value, str) and len(value) > 200:
                summary[key] = value[:200] + "..."
            else:
                summary[key] = value

        return summary

    def get_all(self, session_id: str) -> dict[str, Any]:
        """Return a shallow copy of all data for a session.

        Nested lists/dicts are shared with working memory, not copied.
        """
        return dict(self._store.get(session_id, {}))

    def session_exists(self, session_id: str) -> bool:
        """Check if a session has any data."""
        return bool(self._store.get(session_id))

    def active_sessions(self) -> list[str]:
        """Return the ids of sessions that currently hold data."""
        return [sid for sid, data in self._store.items() if data]

    def session_count(self) -> int:
        """Return number of sessions that currently hold data."""
        return sum(1 for data in self._store.values() if data)

    def _enforce_limits(self, session_id: str) -> None:
        """Log a warning if the session exceeds ``max_entries_per_session``.

        The limit is advisory only: nothing is evicted. Lists and dicts count
        as their length; any other value counts as one item.
        """
        session_data = self._store.get(session_id, {})
        total_items = sum(
            len(v) if isinstance(v, (list, dict)) else 1
            for v in session_data.values()
        )

        if total_items > self._max_entries:
            logger.warning(
                "Working memory size limit exceeded",
                extra={"session_id": session_id, "items": total_items},
            )
|