"""Working memory: in-memory key-value / list storage per task/session.

Working memory provides short-term storage for active tasks:

- Key-value storage per session/task
- List append operations for accumulating results
- Context retrieval for reasoning
- Session lifecycle management
"""

from collections import defaultdict
from datetime import datetime
from typing import Any, Iterator

from fusionagi._logger import logger
from fusionagi._time import utc_now


class WorkingMemory:
    """
    Short-term working memory per task/session.

    Features:
    - Key-value get/set operations
    - List append with automatic coercion of an existing non-list value
    - Context summary for LLM prompts
    - Session management and cleanup
    - Soft per-session size limit (a warning is logged when it is exceeded;
      no entries are evicted)

    Only ``set`` and ``append`` create a session entry. Read-only methods
    never do, so looking up unknown sessions does not grow the store.
    """

    def __init__(self, max_entries_per_session: int = 1000) -> None:
        """
        Initialize working memory.

        Args:
            max_entries_per_session: Soft limit on the number of items held
                per session. Every element of a list or dict value counts as
                one item; any other value counts as one. Exceeding the limit
                logs a warning; no data is removed.
        """
        # Written through by set()/append(). Readers use .get() so that an
        # unknown session_id is not auto-inserted by the defaultdict.
        self._store: dict[str, dict[str, Any]] = defaultdict(dict)
        # Time of the most recent set()/append() per session, from utc_now().
        self._timestamps: dict[str, datetime] = {}
        self._max_entries = max_entries_per_session

    def get(self, session_id: str, key: str, default: Any = None) -> Any:
        """Get value for session and key; returns default if not found."""
        # .get() on the outer dict avoids creating an empty session on read.
        return self._store.get(session_id, {}).get(key, default)

    def set(self, session_id: str, key: str, value: Any) -> None:
        """Set value for session and key (creates the session if needed)."""
        self._store[session_id][key] = value
        self._timestamps[session_id] = utc_now()
        self._enforce_limits(session_id)

    def append(self, session_id: str, key: str, value: Any) -> None:
        """
        Append ``value`` to the list stored under ``key``.

        A missing key starts a new list. An existing non-list value
        (including ``None``) is first wrapped into a one-element list, so it
        is kept as the first element.
        """
        session = self._store[session_id]
        if key not in session:
            session[key] = []
        elif not isinstance(session[key], list):
            session[key] = [session[key]]
        session[key].append(value)
        self._timestamps[session_id] = utc_now()
        self._enforce_limits(session_id)

    def get_list(self, session_id: str, key: str) -> list[Any]:
        """
        Return a copy of the list stored under ``key``.

        A non-list value is returned wrapped in a one-element list; a missing
        key or a ``None`` value yields an empty list.
        """
        val = self._store.get(session_id, {}).get(key)
        if isinstance(val, list):
            return list(val)
        return [val] if val is not None else []

    def has(self, session_id: str, key: str) -> bool:
        """Check if a key exists in session."""
        return key in self._store.get(session_id, {})

    def keys(self, session_id: str) -> list[str]:
        """Return all keys for a session."""
        return list(self._store.get(session_id, {}).keys())

    def delete(self, session_id: str, key: str) -> bool:
        """Delete a key from session. Returns True if it existed."""
        session = self._store.get(session_id)
        if session is not None and key in session:
            del session[key]
            return True
        return False

    def clear_session(self, session_id: str) -> None:
        """Clear all data (and the last-write timestamp) for a session."""
        self._store.pop(session_id, None)
        self._timestamps.pop(session_id, None)

    def get_context_summary(self, session_id: str, max_items: int = 10) -> dict[str, Any]:
        """
        Get a summary of working memory for context injection.

        Useful for including relevant context in LLM prompts. Only the first
        ``max_items`` keys (in insertion order) are summarized:

        - list values -> ``{"type": "list", "count": n, "recent": last 3 items}``
        - dict values -> ``{"type": "dict", "keys": first 10 keys}``
        - strings longer than 200 chars -> truncated with ``"..."``
        - any other value -> included as-is
        """
        session_data = self._store.get(session_id, {})
        summary: dict[str, Any] = {}
        for key, value in list(session_data.items())[:max_items]:
            if isinstance(value, list):
                summary[key] = {
                    "type": "list",
                    "count": len(value),
                    # Slicing always copies, so callers mutating the summary
                    # cannot alter the stored list (previously short lists
                    # were returned by reference).
                    "recent": value[-3:],
                }
            elif isinstance(value, dict):
                summary[key] = {
                    "type": "dict",
                    "keys": list(value.keys())[:10],
                }
            elif isinstance(value, str) and len(value) > 200:
                summary[key] = value[:200] + "..."
            else:
                summary[key] = value
        return summary

    def get_all(self, session_id: str) -> dict[str, Any]:
        """Return all data for a session (shallow copy: values are shared)."""
        return dict(self._store.get(session_id, {}))

    def session_exists(self, session_id: str) -> bool:
        """Check if a session has any data."""
        return bool(self._store.get(session_id))

    def active_sessions(self) -> list[str]:
        """Return list of sessions with data."""
        return [sid for sid, data in self._store.items() if data]

    def session_count(self) -> int:
        """Return number of active sessions (sessions with data)."""
        return len(self.active_sessions())

    def _enforce_limits(self, session_id: str) -> None:
        """
        Check the soft size limit for a session.

        Counts every element of list/dict values plus one per other value,
        and logs a warning when the total exceeds ``max_entries_per_session``.
        Nothing is evicted.
        """
        session_data = self._store.get(session_id, {})
        total_items = sum(
            len(v) if isinstance(v, (list, dict)) else 1
            for v in session_data.values()
        )
        if total_items > self._max_entries:
            # NOTE(review): advisory only; this fires on every write once
            # the session is over the limit.
            logger.warning(
                "Working memory size limit exceeded",
                extra={"session_id": session_id, "items": total_items},
            )