"""Optional response cache for LLM adapter."""
import hashlib
import json
from collections import OrderedDict
from typing import Any

from fusionagi.adapters.base import LLMAdapter


class CachedAdapter(LLMAdapter):
    """
    Wraps an adapter and caches responses by messages hash.

    Features:
    - Caches both complete() and complete_structured() responses
    - LRU eviction when at capacity (most recently used retained)
    - Separate caches for text and structured responses
    - Cache statistics for monitoring
    """

    def __init__(self, adapter: LLMAdapter, max_entries: int = 100) -> None:
        """
        Initialize the cached adapter.

        Args:
            adapter: The underlying LLM adapter to wrap.
            max_entries: Maximum cache entries before eviction. A value
                of zero or less disables caching entirely.
        """
        self._adapter = adapter
        # OrderedDict gives O(1) LRU bookkeeping via move_to_end/popitem.
        self._cache: OrderedDict[str, str] = OrderedDict()
        self._structured_cache: OrderedDict[str, Any] = OrderedDict()
        self._max_entries = max_entries
        self._hits = 0
        self._misses = 0

    def _key(self, messages: list[dict[str, str]], kwargs: dict[str, Any], prefix: str = "") -> str:
        """Generate a deterministic cache key from messages and kwargs.

        sort_keys makes the key independent of dict insertion order.
        default=str is a deliberate fallback so non-JSON-serializable
        kwargs still produce a key instead of raising; distinct objects
        with identical str() forms will collide, which is acceptable here.
        """
        payload = json.dumps(
            {"prefix": prefix, "messages": messages, "kwargs": kwargs},
            sort_keys=True,
            default=str,
        )
        return hashlib.sha256(payload.encode()).hexdigest()

    def _evict_if_needed(self, cache: OrderedDict[str, Any]) -> None:
        """Evict least recently used entries until an insert would fit.

        Evicting while ``len >= max`` (before the caller inserts) keeps the
        post-insert size at ``max_entries`` or below.
        """
        while len(cache) >= self._max_entries and cache:
            cache.popitem(last=False)  # last=False pops the LRU (oldest) entry

    def _get_and_touch(self, cache: OrderedDict[str, Any], key: str) -> Any:
        """Get value and move to end (LRU touch)."""
        val = cache[key]
        cache.move_to_end(key)
        return val

    def complete(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        """Complete with caching; delegates to the wrapped adapter on a miss."""
        key = self._key(messages, kwargs, prefix="complete")
        if key in self._cache:
            self._hits += 1
            return self._get_and_touch(self._cache, key)

        self._misses += 1
        response = self._adapter.complete(messages, **kwargs)
        # BUGFIX: only store when caching is enabled. Previously a
        # max_entries of 0 still stored one entry after evicting everything,
        # so the cache could exceed its declared capacity.
        if self._max_entries > 0:
            self._evict_if_needed(self._cache)
            self._cache[key] = response
        return response

    def complete_structured(
        self,
        messages: list[dict[str, str]],
        schema: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """
        Complete structured with caching.

        Caches structured responses separately from text responses. The
        schema participates in the cache key (as the reserved "_schema"
        kwarg) so different schemas never share an entry.
        """
        cache_kwargs = {**kwargs, "_schema": schema}
        key = self._key(messages, cache_kwargs, prefix="structured")

        if key in self._structured_cache:
            self._hits += 1
            return self._get_and_touch(self._structured_cache, key)

        self._misses += 1
        response = self._adapter.complete_structured(messages, schema=schema, **kwargs)

        # None responses are treated as failures and are not cached, so a
        # later call can retry. Skip storing entirely when caching is
        # disabled (max_entries <= 0) — see complete() for rationale.
        if response is not None and self._max_entries > 0:
            self._evict_if_needed(self._structured_cache)
            self._structured_cache[key] = response

        return response

    def get_stats(self) -> dict[str, Any]:
        """Return cache statistics.

        hit_rate is 0.0 before any lookups to avoid division by zero.
        Hits and misses are aggregated across both caches.
        """
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": hit_rate,
            "text_cache_size": len(self._cache),
            "structured_cache_size": len(self._structured_cache),
            "max_entries": self._max_entries,
        }

    def clear_cache(self) -> None:
        """Clear all cached responses and reset hit/miss counters."""
        self._cache.clear()
        self._structured_cache.clear()
        self._hits = 0
        self._misses = 0