"""Retrieve-by-reference: load context for reasoning without token overflow.""" from __future__ import annotations from typing import Any, Protocol, runtime_checkable from fusionagi.schemas.atomic import AtomicSemanticUnit from fusionagi.memory.sharding import Shard, shard_context @runtime_checkable class SemanticGraphLike(Protocol): """Protocol for semantic graph access.""" def get_unit(self, unit_id: str) -> AtomicSemanticUnit | None: ... def query_units(self, unit_ids: list[str] | None = None, limit: int = 100) -> list[AtomicSemanticUnit]: ... @runtime_checkable class SharderLike(Protocol): """Protocol for sharding.""" def __call__(self, units: list[AtomicSemanticUnit], max_cluster_size: int) -> list[Shard]: ... def load_context_for_reasoning( query_units: list[AtomicSemanticUnit], semantic_graph: SemanticGraphLike | None = None, sharder: SharderLike | None = None, max_cluster_size: int = 20, ) -> dict[str, Any]: """ Fetch relevant shards/units by reference for reasoning. Returns structured context (unit IDs + summaries) rather than raw text. """ shard_fn = sharder or shard_context shards = shard_fn(query_units, max_cluster_size=max_cluster_size) unit_refs: list[str] = [] unit_summaries: dict[str, str] = {} for u in query_units: unit_refs.append(u.unit_id) unit_summaries[u.unit_id] = u.content[:150] if semantic_graph: for s in shards: for uid in s.unit_ids: if uid not in unit_summaries: found = semantic_graph.get_unit(uid) if found: unit_summaries[uid] = found.content[:150] unit_refs.append(uid) return {"unit_refs": unit_refs, "shards": shards, "unit_summaries": unit_summaries} def build_compact_prompt( units: list[AtomicSemanticUnit], max_chars: int = 4000, ) -> str: """Materialize text for units that fit; rest stay as references.""" parts: list[str] = [] total = 0 refs: list[str] = [] for u in units: line = f"[{u.unit_id}] {u.content}\n" if total + len(line) <= max_chars: parts.append(line) total += len(line) else: refs.append(u.unit_id) if refs: parts.append(f"\n[References: {', '.join(refs[:20])}]") return "".join(parts) if parts else "[No units]"