"""Consensus engine: claim collection, deduplication, conflict detection, scoring.""" from __future__ import annotations from dataclasses import dataclass, field from typing import Any from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim from fusionagi.schemas.witness import AgreementMap from fusionagi._logger import logger @dataclass class CollectedClaim: """Claim with source head and metadata for consensus.""" claim_text: str confidence: float head_id: HeadId evidence_count: int raw: dict[str, Any] def _normalize_text(s: str) -> str: """Normalize for duplicate detection.""" return " ".join(s.lower().split()) def _are_similar(a: str, b: str, threshold: float = 0.9) -> bool: """Simple similarity: exact match or one contains the other (normalized).""" na, nb = _normalize_text(a), _normalize_text(b) if na == nb: return True if len(na) < 10 or len(nb) < 10: return na == nb # Jaccard-like: word overlap wa, wb = set(na.split()), set(nb.split()) inter = len(wa & wb) union = len(wa | wb) if union == 0: return False return (inter / union) >= threshold def _looks_contradictory(a: str, b: str) -> bool: """Heuristic: same subject with opposite polarity indicators.""" neg_words = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"} na, nb = _normalize_text(a), _normalize_text(b) wa, wb = set(na.split()), set(nb.split()) # If one has neg and the other doesn't, and they share significant overlap a_neg = bool(wa & neg_words) b_neg = bool(wb & neg_words) if a_neg != b_neg: overlap = len(wa & wb) / max(len(wa), 1) if overlap > 0.3: return True return False def collect_claims(outputs: list[HeadOutput]) -> list[CollectedClaim]: """Flatten all head claims with source metadata.""" collected: list[CollectedClaim] = [] for out in outputs: for c in out.claims: collected.append( CollectedClaim( claim_text=c.claim_text, confidence=c.confidence, head_id=out.head_id, evidence_count=len(c.evidence), raw={ "claim_text": c.claim_text, "confidence": c.confidence, "head_id": out.head_id.value, "evidence_count": len(c.evidence), "assumptions": c.assumptions, }, ) ) return collected def run_consensus( outputs: list[HeadOutput], head_weights: dict[HeadId, float] | None = None, confidence_threshold: float = 0.5, ) -> AgreementMap: """ Run consensus: deduplicate, detect conflicts, score, produce AgreementMap. Args: outputs: HeadOutput from all content heads. head_weights: Optional per-head reliability weights (default 1.0). confidence_threshold: Minimum confidence for agreed claim. Returns: AgreementMap with agreed_claims, disputed_claims, confidence_score. """ if not outputs: return AgreementMap( agreed_claims=[], disputed_claims=[], confidence_score=0.0, ) weights = head_weights or {h: 1.0 for h in HeadId} collected = collect_claims(outputs) # Group by similarity (merge near-duplicates) merged: list[CollectedClaim] = [] used: set[int] = set() for i, ca in enumerate(collected): if i in used: continue group = [ca] used.add(i) for j, cb in enumerate(collected): if j in used: continue if _are_similar(ca.claim_text, cb.claim_text) and not _looks_contradictory(ca.claim_text, cb.claim_text): group.append(cb) used.add(j) # Aggregate: weighted avg confidence, combine heads if len(group) == 1: c = group[0] score = c.confidence * weights.get(c.head_id, 1.0) if c.evidence_count > 0: score *= 1.1 # boost for citations merged.append( CollectedClaim( claim_text=c.claim_text, confidence=score, head_id=c.head_id, evidence_count=c.evidence_count, raw={**c.raw, "aggregated_confidence": score, "supporting_heads": [c.head_id.value]}, ) ) else: total_conf = sum(g.confidence * weights.get(g.head_id, 1.0) for g in group) avg_conf = total_conf / len(group) evidence_boost = 1.1 if any(g.evidence_count > 0 for g in group) else 1.0 score = min(1.0, avg_conf * evidence_boost) merged.append( CollectedClaim( claim_text=group[0].claim_text, confidence=score, head_id=group[0].head_id, evidence_count=sum(g.evidence_count for g in group), raw={ "claim_text": group[0].claim_text, "aggregated_confidence": score, "supporting_heads": [g.head_id.value for g in group], }, ) ) # Conflict detection agreed: list[dict[str, Any]] = [] disputed: list[dict[str, Any]] = [] for c in merged: in_conflict = False for d in merged: if c is d: continue if _looks_contradictory(c.claim_text, d.claim_text): in_conflict = True break rec = { "claim_text": c.claim_text, "confidence": c.confidence, "supporting_heads": c.raw.get("supporting_heads", [c.head_id.value]), } if in_conflict or c.confidence < confidence_threshold: disputed.append(rec) else: agreed.append(rec) overall_conf = ( sum(a["confidence"] for a in agreed) / len(agreed) if agreed else 0.0 ) logger.info( "Consensus complete", extra={"agreed": len(agreed), "disputed": len(disputed), "confidence": overall_conf}, ) return AgreementMap( agreed_claims=agreed, disputed_claims=disputed, confidence_score=min(1.0, overall_conf), )