Files

51 lines
2.0 KiB
Python
Raw Permalink Normal View History

"""Multi-path inference: parallel hypothesis generation and scoring."""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable
from fusionagi.schemas.atomic import AtomicSemanticUnit
from fusionagi.reasoning.tot import ThoughtNode
from fusionagi._logger import logger
def _score_coherence(node: ThoughtNode, _units: list[AtomicSemanticUnit]) -> float:
    """Scale a node's existing score by how developed its reasoning trace is.

    A trace of five or more steps keeps the full ``node.score``. Shorter
    traces are scaled linearly, down to 90% of the score for an empty trace.
    ``_units`` is accepted only so this matches the scorer signature; it is
    not used.
    """
    depth_ratio = len(node.trace) / 5
    depth_bonus = 0.1 * min(1, depth_ratio)
    return node.score * (0.9 + depth_bonus)
def _score_consistency(node: ThoughtNode, units: list[AtomicSemanticUnit]) -> float:
    """Estimate how well a node's thought is supported by the units' vocabulary.

    Returns 0.5 (neutral) when ``units`` is empty. Otherwise, takes the
    fraction of the thought's distinct lowercase, whitespace-separated words
    that also appear in some unit's content, doubles it, and caps it at 1.0.
    Punctuation is not stripped, so ``"word,"`` and ``"word"`` are distinct
    tokens.
    """
    if not units:
        return 0.5
    # Split each unit separately; this gives the same tokens as splitting the
    # space-joined content.
    unit_vocab: set[str] = set()
    for unit in units:
        unit_vocab.update(unit.content.lower().split())
    hypothesis_vocab = set(node.thought.lower().split())
    shared = hypothesis_vocab.intersection(unit_vocab)
    # max(..., 1) guards against an empty thought (0 / 1 -> 0.0).
    ratio = len(shared) / max(len(hypothesis_vocab), 1)
    return min(1.0, ratio * 2)
def generate_and_score_parallel(
    hypotheses: list[str],
    units: list[AtomicSemanticUnit],
    score_fn: Callable[[ThoughtNode, list[AtomicSemanticUnit]], float] | None = None,
    *,
    max_workers: int = 8,
) -> list[tuple[ThoughtNode, float]]:
    """Score multiple hypotheses concurrently on a thread pool.

    Each hypothesis is wrapped in a ``ThoughtNode`` whose trace is
    ``[hypothesis]`` and whose ``unit_refs`` are the ids of the first 10
    units. The node is scored with ``score_fn``, and the result is also
    stored on ``node.score``.

    Args:
        hypotheses: Candidate thoughts to score.
        units: Semantic units passed to ``score_fn`` as evidence.
        score_fn: Scorer ``(node, units) -> float``. Defaults to an equal
            blend of ``_score_coherence`` and ``_score_consistency``.
        max_workers: Upper bound on worker threads. Must be >= 1. Defaults
            to 8, the previously hard-coded limit.

    Returns:
        ``(node, score)`` pairs in the same order as ``hypotheses``.
        Hypotheses whose scoring raised are logged and omitted.
    """
    if not hypotheses:
        # ThreadPoolExecutor rejects max_workers=0, so short-circuit.
        return []

    # NOTE(review): the default blend calls _score_coherence, which reads
    # node.score before it has been assigned here (i.e. ThoughtNode's default
    # value) -- confirm that is intended.
    scorer = score_fn if score_fn is not None else (
        lambda n, u: _score_coherence(n, u) * 0.5 + _score_consistency(n, u) * 0.5
    )
    # Identical for every hypothesis, so compute once.
    unit_refs = [u.unit_id for u in units[:10]]

    def score_one(h: str) -> tuple[ThoughtNode, float]:
        # Copy so nodes do not share one mutable list.
        node = ThoughtNode(thought=h, trace=[h], unit_refs=list(unit_refs))
        s = scorer(node, units)
        node.score = s
        return node, s

    # One slot per hypothesis so output order matches input order regardless
    # of completion order; None marks a failed hypothesis.
    slots: list[tuple[ThoughtNode, float] | None] = [None] * len(hypotheses)
    with ThreadPoolExecutor(max_workers=min(len(hypotheses), max_workers)) as ex:
        futures = {ex.submit(score_one, h): i for i, h in enumerate(hypotheses)}
        for future in as_completed(futures):
            index = futures[future]
            try:
                slots[index] = future.result()
            except Exception as e:
                # Best-effort: one failing hypothesis must not sink the batch.
                logger.warning(
                    "Multi-path score failed",
                    extra={"error": str(e), "hypothesis_index": index},
                )
    return [r for r in slots if r is not None]