Files
FusionAGI/fusionagi/multi_agent/consensus_engine.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

195 lines
6.3 KiB
Python

"""Consensus engine: claim collection, deduplication, conflict detection, scoring."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim
from fusionagi.schemas.witness import AgreementMap
from fusionagi._logger import logger
@dataclass
class CollectedClaim:
    """Claim with source head and metadata for consensus."""
    # Free-text claim as produced by a head.
    claim_text: str
    # Head-reported confidence; later overwritten with the aggregated/weighted
    # score during consensus merging.
    confidence: float
    # Which head produced this claim (first head of the group after merging).
    head_id: HeadId
    # Number of evidence items attached to the claim (summed across a merged group).
    evidence_count: int
    # Serializable record of the claim; after merging carries
    # "aggregated_confidence" and "supporting_heads" keys.
    raw: dict[str, Any]
def _normalize_text(s: str) -> str:
"""Normalize for duplicate detection."""
return " ".join(s.lower().split())
def _are_similar(a: str, b: str, threshold: float = 0.9) -> bool:
    """Return True if two claim texts are near-duplicates.

    Exact equality after normalization always counts as similar. Otherwise
    similarity is the Jaccard index of the normalized word sets compared
    against *threshold*. Strings shorter than 10 characters (normalized)
    must match exactly, to avoid spurious word-overlap hits on tiny claims.

    Args:
        a: First claim text.
        b: Second claim text.
        threshold: Minimum Jaccard word-overlap ratio in [0, 1].

    Returns:
        True when the texts should be merged as duplicates.
    """
    na, nb = _normalize_text(a), _normalize_text(b)
    if na == nb:
        return True
    # Short strings only match exactly; na != nb was established above,
    # so they are dissimilar (the original `return na == nb` was always False here).
    if len(na) < 10 or len(nb) < 10:
        return False
    # Jaccard index over word sets.
    wa, wb = set(na.split()), set(nb.split())
    union = len(wa | wb)
    if union == 0:
        return False
    return (len(wa & wb) / union) >= threshold
def _looks_contradictory(a: str, b: str) -> bool:
    """Heuristic: same subject with opposite polarity indicators."""
    negations = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"}
    words_a = set(_normalize_text(a).split())
    words_b = set(_normalize_text(b).split())
    # Opposite polarity: exactly one side carries a negation word.
    has_neg_a = not negations.isdisjoint(words_a)
    has_neg_b = not negations.isdisjoint(words_b)
    if has_neg_a == has_neg_b:
        return False
    # Require meaningful subject overlap (relative to the first claim's words)
    # before flagging a contradiction.
    overlap = len(words_a & words_b) / max(len(words_a), 1)
    return overlap > 0.3
def collect_claims(outputs: list[HeadOutput]) -> list[CollectedClaim]:
    """Flatten all head claims with source metadata."""
    result: list[CollectedClaim] = []
    for output in outputs:
        head = output.head_id
        for claim in output.claims:
            n_evidence = len(claim.evidence)
            record = {
                "claim_text": claim.claim_text,
                "confidence": claim.confidence,
                "head_id": head.value,
                "evidence_count": n_evidence,
                "assumptions": claim.assumptions,
            }
            result.append(
                CollectedClaim(
                    claim_text=claim.claim_text,
                    confidence=claim.confidence,
                    head_id=head,
                    evidence_count=n_evidence,
                    raw=record,
                )
            )
    return result
def _merge_similar_claims(
    collected: list[CollectedClaim],
    weights: dict[HeadId, float],
) -> list[CollectedClaim]:
    """Greedily group near-duplicate claims and aggregate each group into one scored claim.

    Each claim seeds a group; later claims that are similar (and not
    contradictory) to the seed join it. Confidence is head-weighted, boosted
    by 1.1 when any group member has evidence, and clamped to 1.0.
    """
    merged: list[CollectedClaim] = []
    used: set[int] = set()
    for i, seed in enumerate(collected):
        if i in used:
            continue
        group = [seed]
        used.add(i)
        # All indices <= i are already used, so scanning forward is sufficient.
        for j in range(i + 1, len(collected)):
            if j in used:
                continue
            cand = collected[j]
            if _are_similar(seed.claim_text, cand.claim_text) and not _looks_contradictory(
                seed.claim_text, cand.claim_text
            ):
                group.append(cand)
                used.add(j)
        if len(group) == 1:
            c = group[0]
            score = c.confidence * weights.get(c.head_id, 1.0)
            if c.evidence_count > 0:
                score *= 1.1  # boost for citations
            # FIX: clamp to 1.0 like the multi-claim path below; previously a
            # singleton claim with evidence could report confidence > 1.0.
            score = min(1.0, score)
            merged.append(
                CollectedClaim(
                    claim_text=c.claim_text,
                    confidence=score,
                    head_id=c.head_id,
                    evidence_count=c.evidence_count,
                    raw={**c.raw, "aggregated_confidence": score, "supporting_heads": [c.head_id.value]},
                )
            )
        else:
            # Weighted average confidence across the group.
            total_conf = sum(g.confidence * weights.get(g.head_id, 1.0) for g in group)
            avg_conf = total_conf / len(group)
            evidence_boost = 1.1 if any(g.evidence_count > 0 for g in group) else 1.0
            score = min(1.0, avg_conf * evidence_boost)
            merged.append(
                CollectedClaim(
                    claim_text=group[0].claim_text,
                    confidence=score,
                    head_id=group[0].head_id,
                    evidence_count=sum(g.evidence_count for g in group),
                    raw={
                        "claim_text": group[0].claim_text,
                        "aggregated_confidence": score,
                        "supporting_heads": [g.head_id.value for g in group],
                    },
                )
            )
    return merged


def _partition_claims(
    merged: list[CollectedClaim],
    confidence_threshold: float,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split merged claims into (agreed, disputed) record lists.

    A claim is disputed when it contradicts any other merged claim or its
    aggregated confidence falls below *confidence_threshold*.
    """
    agreed: list[dict[str, Any]] = []
    disputed: list[dict[str, Any]] = []
    for c in merged:
        in_conflict = any(
            c is not d and _looks_contradictory(c.claim_text, d.claim_text) for d in merged
        )
        rec = {
            "claim_text": c.claim_text,
            "confidence": c.confidence,
            "supporting_heads": c.raw.get("supporting_heads", [c.head_id.value]),
        }
        if in_conflict or c.confidence < confidence_threshold:
            disputed.append(rec)
        else:
            agreed.append(rec)
    return agreed, disputed


def run_consensus(
    outputs: list[HeadOutput],
    head_weights: dict[HeadId, float] | None = None,
    confidence_threshold: float = 0.5,
) -> AgreementMap:
    """
    Run consensus: deduplicate, detect conflicts, score, produce AgreementMap.

    Args:
        outputs: HeadOutput from all content heads.
        head_weights: Optional per-head reliability weights (default 1.0).
        confidence_threshold: Minimum confidence for agreed claim.

    Returns:
        AgreementMap with agreed_claims, disputed_claims, confidence_score.
    """
    if not outputs:
        return AgreementMap(
            agreed_claims=[],
            disputed_claims=[],
            confidence_score=0.0,
        )
    # `is not None` (not truthiness) so an explicit empty dict is honored;
    # missing heads fall back to weight 1.0 via .get() either way.
    weights = head_weights if head_weights is not None else {h: 1.0 for h in HeadId}
    merged = _merge_similar_claims(collect_claims(outputs), weights)
    agreed, disputed = _partition_claims(merged, confidence_threshold)
    overall_conf = (
        sum(a["confidence"] for a in agreed) / len(agreed)
        if agreed
        else 0.0
    )
    logger.info(
        "Consensus complete",
        extra={"agreed": len(agreed), "disputed": len(disputed), "confidence": overall_conf},
    )
    return AgreementMap(
        agreed_claims=agreed,
        disputed_claims=disputed,
        confidence_score=min(1.0, overall_conf),
    )