195 lines
6.3 KiB
Python
195 lines
6.3 KiB
Python
"""Consensus engine: claim collection, deduplication, conflict detection, scoring."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim
|
|
from fusionagi.schemas.witness import AgreementMap
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
@dataclass
class CollectedClaim:
    """Claim with source head and metadata for consensus."""

    # Verbatim claim text as produced by the head.
    claim_text: str
    # Head-reported confidence in [0, 1]; after merging, consensus stores the
    # aggregated (weighted/boosted) score here instead.
    confidence: float
    # Identifier of the head that emitted this claim (first head of the group
    # when several claims are merged).
    head_id: HeadId
    # Number of evidence items attached to the claim (summed across a merged group).
    evidence_count: int
    # Serializable snapshot of the claim; consensus adds "aggregated_confidence"
    # and "supporting_heads" keys during merging.
    raw: dict[str, Any]
|
|
|
|
|
|
def _normalize_text(s: str) -> str:
|
|
"""Normalize for duplicate detection."""
|
|
return " ".join(s.lower().split())
|
|
|
|
|
|
def _are_similar(a: str, b: str, threshold: float = 0.9) -> bool:
|
|
"""Simple similarity: exact match or one contains the other (normalized)."""
|
|
na, nb = _normalize_text(a), _normalize_text(b)
|
|
if na == nb:
|
|
return True
|
|
if len(na) < 10 or len(nb) < 10:
|
|
return na == nb
|
|
# Jaccard-like: word overlap
|
|
wa, wb = set(na.split()), set(nb.split())
|
|
inter = len(wa & wb)
|
|
union = len(wa | wb)
|
|
if union == 0:
|
|
return False
|
|
return (inter / union) >= threshold
|
|
|
|
|
|
def _looks_contradictory(a: str, b: str) -> bool:
|
|
"""Heuristic: same subject with opposite polarity indicators."""
|
|
neg_words = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"}
|
|
na, nb = _normalize_text(a), _normalize_text(b)
|
|
wa, wb = set(na.split()), set(nb.split())
|
|
# If one has neg and the other doesn't, and they share significant overlap
|
|
a_neg = bool(wa & neg_words)
|
|
b_neg = bool(wb & neg_words)
|
|
if a_neg != b_neg:
|
|
overlap = len(wa & wb) / max(len(wa), 1)
|
|
if overlap > 0.3:
|
|
return True
|
|
return False
|
|
|
|
|
|
def collect_claims(outputs: list[HeadOutput]) -> list[CollectedClaim]:
    """Flatten all head claims with source metadata."""
    return [
        CollectedClaim(
            claim_text=claim.claim_text,
            confidence=claim.confidence,
            head_id=output.head_id,
            evidence_count=len(claim.evidence),
            # Serializable snapshot kept alongside the typed fields.
            raw={
                "claim_text": claim.claim_text,
                "confidence": claim.confidence,
                "head_id": output.head_id.value,
                "evidence_count": len(claim.evidence),
                "assumptions": claim.assumptions,
            },
        )
        for output in outputs
        for claim in output.claims
    ]
|
|
|
|
|
|
def _merge_similar_claims(
    collected: list[CollectedClaim],
    weights: dict[HeadId, float],
) -> list[CollectedClaim]:
    """Greedily group near-duplicate claims and aggregate each group.

    Each unused claim anchors a group and absorbs every later claim that is
    similar and not contradictory. The group's confidence is the weighted
    mean of member confidences, boosted by 10% when any member carries
    evidence, and clamped to 1.0.
    """
    merged: list[CollectedClaim] = []
    used: set[int] = set()
    for i, anchor in enumerate(collected):
        if i in used:
            continue
        used.add(i)
        group = [anchor]
        # Only later indices can still be unused: every earlier index was
        # either an anchor itself or absorbed into a previous group.
        for j in range(i + 1, len(collected)):
            if j in used:
                continue
            candidate = collected[j]
            if _are_similar(anchor.claim_text, candidate.claim_text) and not _looks_contradictory(
                anchor.claim_text, candidate.claim_text
            ):
                group.append(candidate)
                used.add(j)

        weighted = [g.confidence * weights.get(g.head_id, 1.0) for g in group]
        evidence_boost = 1.1 if any(g.evidence_count > 0 for g in group) else 1.0
        # Clamp so weights or the evidence boost can never push the stored
        # confidence above 1.0 (previously singletons escaped the clamp).
        score = min(1.0, (sum(weighted) / len(group)) * evidence_boost)

        # A singleton keeps its original raw payload; a merged group starts fresh.
        raw = dict(anchor.raw) if len(group) == 1 else {"claim_text": anchor.claim_text}
        raw["aggregated_confidence"] = score
        raw["supporting_heads"] = [g.head_id.value for g in group]

        merged.append(
            CollectedClaim(
                claim_text=anchor.claim_text,
                confidence=score,
                head_id=anchor.head_id,
                evidence_count=sum(g.evidence_count for g in group),
                raw=raw,
            )
        )
    return merged


def _partition_by_conflict(
    merged: list[CollectedClaim],
    confidence_threshold: float,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split merged claims into (agreed, disputed) record lists.

    A claim is disputed when it contradicts any other merged claim or when
    its aggregated confidence falls below *confidence_threshold*.
    """
    agreed: list[dict[str, Any]] = []
    disputed: list[dict[str, Any]] = []
    for claim in merged:
        in_conflict = any(
            _looks_contradictory(claim.claim_text, other.claim_text)
            for other in merged
            if other is not claim
        )
        record = {
            "claim_text": claim.claim_text,
            "confidence": claim.confidence,
            "supporting_heads": claim.raw.get("supporting_heads", [claim.head_id.value]),
        }
        bucket = disputed if in_conflict or claim.confidence < confidence_threshold else agreed
        bucket.append(record)
    return agreed, disputed


def run_consensus(
    outputs: list[HeadOutput],
    head_weights: dict[HeadId, float] | None = None,
    confidence_threshold: float = 0.5,
) -> AgreementMap:
    """
    Run consensus: deduplicate, detect conflicts, score, produce AgreementMap.

    Args:
        outputs: HeadOutput from all content heads.
        head_weights: Optional per-head reliability weights (default 1.0).
        confidence_threshold: Minimum confidence for agreed claim.

    Returns:
        AgreementMap with agreed_claims, disputed_claims, confidence_score.
    """
    if not outputs:
        return AgreementMap(
            agreed_claims=[],
            disputed_claims=[],
            confidence_score=0.0,
        )

    # NOTE: an explicitly passed empty dict also falls back to uniform weights.
    weights = head_weights or {h: 1.0 for h in HeadId}

    merged = _merge_similar_claims(collect_claims(outputs), weights)
    agreed, disputed = _partition_by_conflict(merged, confidence_threshold)

    # Overall confidence is the mean over agreed claims; 0.0 when nothing agreed.
    overall_conf = sum(a["confidence"] for a in agreed) / len(agreed) if agreed else 0.0

    logger.info(
        "Consensus complete",
        extra={"agreed": len(agreed), "disputed": len(disputed), "confidence": overall_conf},
    )

    return AgreementMap(
        agreed_claims=agreed,
        disputed_claims=disputed,
        confidence_score=min(1.0, overall_conf),
    )
|