Files
FusionAGI/fusionagi/core/head_orchestrator.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

340 lines
11 KiB
Python

"""Dvādaśa head orchestrator: parallel head dispatch, Witness coordination, second-pass."""
from __future__ import annotations
import math
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError
from typing import TYPE_CHECKING, Any
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
if TYPE_CHECKING:
from fusionagi.core.orchestrator import Orchestrator
from fusionagi.schemas.head import HeadId, HeadOutput
from fusionagi.schemas.witness import FinalResponse
from fusionagi.schemas.commands import ParsedCommand, UserIntent
from fusionagi._logger import logger
# MVP: 5 heads. Full: 11.
MVP_HEADS: list[HeadId] = [
    HeadId.LOGIC,
    HeadId.RESEARCH,
    HeadId.STRATEGY,
    HeadId.SECURITY,
    HeadId.SAFETY,
]
# Every head except Witness — Witness only aggregates head outputs (see run_witness).
ALL_CONTENT_HEADS: list[HeadId] = [h for h in HeadId if h != HeadId.WITNESS]
# Heads for second-pass when risk/conflict/security
SECOND_PASS_HEADS: list[HeadId] = [HeadId.SECURITY, HeadId.SAFETY, HeadId.LOGIC]
# Thresholds for automatic second-pass (consumed by _should_run_second_pass):
#   min_confidence     - agreement confidence below this triggers a re-run
#   max_disputed       - more disputed claims than this triggers a re-run
#   security_keywords  - substrings of the safety report that trigger a re-run
SECOND_PASS_CONFIG: dict[str, Any] = {
    "min_confidence": 0.5,
    "max_disputed": 3,
    "security_keywords": ("security", "risk", "threat", "vulnerability"),
}
def run_heads_parallel(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    head_ids: list[HeadId] | None = None,
    sender: str = "head_orchestrator",
    timeout_per_head: float = 60.0,
    min_heads_ratio: float | None = 0.6,
) -> list[HeadOutput]:
    """
    Dispatch head_request to multiple heads in parallel; collect HeadOutput.

    Args:
        orchestrator: Orchestrator with registered head agents.
        task_id: Task identifier.
        user_prompt: User's prompt/question.
        head_ids: Heads to run (default: MVP_HEADS).
        sender: Sender identity for messages.
        timeout_per_head: Max seconds per head; the overall wait budget is
            ``timeout_per_head * len(heads) + 5`` seconds.
        min_heads_ratio: Return early once this fraction of heads respond (0.6 = 60%).
            None = wait for all heads. Reduces latency when some heads are slow.

    Returns:
        List of HeadOutput (may be partial on timeout/failure).
    """
    heads = [h for h in (head_ids or MVP_HEADS) if h != HeadId.WITNESS]
    if not heads:
        return []
    envelopes = [
        AgentMessageEnvelope(
            message=AgentMessage(
                sender=sender,
                recipient=hid.value,
                intent="head_request",
                payload={"prompt": user_prompt},
            ),
            task_id=task_id,
        )
        for hid in heads
    ]
    min_required = (
        max(1, math.ceil(len(heads) * min_heads_ratio))
        if min_heads_ratio is not None
        else len(heads)
    )

    def run_one(env: AgentMessageEnvelope) -> HeadOutput | None:
        # Route a single head_request and validate the returned head_output payload.
        resp = orchestrator.route_message_return(env)
        if resp is None or resp.message.intent != "head_output":
            return None
        ho = (resp.message.payload or {}).get("head_output")
        if not isinstance(ho, dict):
            return None
        try:
            return HeadOutput.model_validate(ho)
        except Exception as e:
            logger.warning("HeadOutput parse failed", extra={"error": str(e)})
            return None

    results: list[HeadOutput] = []
    executor = ThreadPoolExecutor(max_workers=len(heads))
    try:
        future_to_env = {executor.submit(run_one, env): env for env in envelopes}
        try:
            # BUG FIX: previously the overall as_completed() timeout raised
            # from the iterator itself, OUTSIDE the per-future try/except, so
            # a global timeout propagated and discarded all partial results.
            # Catching it here preserves everything collected so far.
            for future in as_completed(
                future_to_env, timeout=timeout_per_head * len(heads) + 5
            ):
                try:
                    # as_completed only yields finished futures, so result()
                    # never blocks; the old result(timeout=1) was dead code.
                    out = future.result()
                except Exception as e:
                    logger.exception("Head execution failed", extra={"error": str(e)})
                    continue
                if out is not None:
                    results.append(out)
                    if len(results) >= min_required:
                        logger.info(
                            "Early exit: sufficient heads responded",
                            extra={"responded": len(results), "required": min_required},
                        )
                        break
        except FuturesTimeoutError:
            pending = [
                env.message.recipient
                for fut, env in future_to_env.items()
                if not fut.done()
            ]
            logger.warning("Head timeout", extra={"recipients": pending})
    finally:
        # BUG FIX: the old `with ThreadPoolExecutor(...)` joined all workers
        # on exit (shutdown(wait=True)), so the early-exit break still waited
        # for every slow head, defeating min_heads_ratio's purpose. Shut down
        # without waiting and cancel not-yet-started work (Python 3.9+).
        executor.shutdown(wait=False, cancel_futures=True)
    return results
def run_witness(
    orchestrator: Orchestrator,
    task_id: str,
    head_outputs: list[HeadOutput],
    user_prompt: str,
    sender: str = "head_orchestrator",
) -> FinalResponse | None:
    """
    Route head outputs to Witness; return FinalResponse.

    Returns None when the Witness does not reply, replies with an unexpected
    intent, or its final_response payload fails validation.
    """
    request = AgentMessageEnvelope(
        message=AgentMessage(
            sender=sender,
            recipient=HeadId.WITNESS.value,
            intent="witness_request",
            payload={
                "head_outputs": [h.model_dump() for h in head_outputs],
                "prompt": user_prompt,
            },
        ),
        task_id=task_id,
    )
    reply = orchestrator.route_message_return(request)
    if reply is None:
        return None
    if reply.message.intent != "witness_output":
        return None
    raw = (reply.message.payload or {}).get("final_response")
    if not isinstance(raw, dict):
        return None
    try:
        return FinalResponse.model_validate(raw)
    except Exception as e:
        logger.warning("FinalResponse parse failed", extra={"error": str(e)})
        return None
def run_second_pass(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    initial_outputs: list[HeadOutput],
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
) -> list[HeadOutput]:
    """
    Run second-pass heads (Security, Safety, Logic) and merge with initial outputs.

    Outputs from re-run heads replace the first-pass outputs for those heads;
    all other first-pass outputs are kept unchanged.
    """
    candidates = [h for h in (head_ids or SECOND_PASS_HEADS) if h != HeadId.WITNESS]
    if not candidates:
        return initial_outputs
    refreshed = run_heads_parallel(
        orchestrator,
        task_id,
        user_prompt,
        head_ids=candidates,
        timeout_per_head=timeout_per_head,
    )
    # Later assignments win: second-pass results overwrite first-pass ones.
    merged: dict[HeadId, HeadOutput] = {}
    for output in initial_outputs:
        merged[output.head_id] = output
    for output in refreshed:
        merged[output.head_id] = output
    return list(merged.values())
def _should_run_second_pass(
    final: FinalResponse,
    force: bool = False,
    second_pass_config: dict[str, Any] | None = None,
) -> bool:
    """
    Check if second-pass should run based on transparency report.

    Triggers on: explicit force, low agreement confidence, too many disputed
    claims, or security keywords appearing in the safety report.
    """
    if force:
        return True
    cfg = dict(SECOND_PASS_CONFIG)
    if second_pass_config:
        cfg.update(second_pass_config)
    agreement = final.transparency_report.agreement_map
    low_confidence = agreement.confidence_score < cfg.get("min_confidence", 0.5)
    too_disputed = len(agreement.disputed_claims) > cfg.get("max_disputed", 3)
    if low_confidence or too_disputed:
        return True
    report_text = (final.transparency_report.safety_report or "").lower()
    return any(kw in report_text for kw in cfg.get("security_keywords", ()))
def run_dvadasa(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    parsed: ParsedCommand | None = None,
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
    event_bus: Any | None = None,
    force_second_pass: bool = False,
    return_head_outputs: bool = False,
    second_pass_config: dict[str, Any] | None = None,
    min_heads_ratio: float | None = 0.6,
) -> FinalResponse | tuple[FinalResponse, list[HeadOutput]] | tuple[None, list[HeadOutput]] | None:
    """
    Full Dvādaśa flow: run heads in parallel, then Witness.

    Args:
        orchestrator: Orchestrator with heads and witness registered.
        task_id: Task identifier.
        user_prompt: User's prompt (or use parsed.cleaned_prompt when HEAD_STRATEGY).
        parsed: Optional ParsedCommand from parse_user_input.
        head_ids: Override heads to run (e.g. single head for HEAD_STRATEGY).
        timeout_per_head: Max seconds per head.
        event_bus: Optional EventBus to publish dvadasa_complete.
        force_second_pass: Always run the second pass, regardless of thresholds.
        return_head_outputs: Also return the collected HeadOutput list.
        second_pass_config: Override SECOND_PASS_CONFIG (min_confidence, max_disputed, etc).
        min_heads_ratio: Early exit once this fraction of heads respond; None = wait all.

    Returns:
        FinalResponse or None on failure; when return_head_outputs is True,
        a (FinalResponse | None, list[HeadOutput]) tuple instead.
    """
    prompt = user_prompt
    heads = head_ids
    # HEAD_STRATEGY routes to exactly one head; use the cleaned prompt when present.
    if parsed and parsed.intent == UserIntent.HEAD_STRATEGY and parsed.head_id:
        heads = [parsed.head_id]
        if parsed.cleaned_prompt:
            prompt = parsed.cleaned_prompt
    if heads is None:
        heads = select_heads_for_complexity(prompt)
    head_outputs = run_heads_parallel(
        orchestrator,
        task_id,
        prompt,
        head_ids=heads,
        timeout_per_head=timeout_per_head,
        # BUG FIX: min_heads_ratio was accepted by run_dvadasa but never
        # forwarded, so callers passing e.g. None were silently ignored.
        min_heads_ratio=min_heads_ratio,
    )
    if not head_outputs:
        logger.warning("No head outputs; cannot run Witness")
        return (None, []) if return_head_outputs else None
    final = run_witness(orchestrator, task_id, head_outputs, prompt)
    # _should_run_second_pass already honors force=force_second_pass, so the
    # old extra `force_second_pass or ...` check was redundant.
    if final and _should_run_second_pass(
        final, force=force_second_pass, second_pass_config=second_pass_config
    ):
        head_outputs = run_second_pass(
            orchestrator,
            task_id,
            prompt,
            head_outputs,
            timeout_per_head=timeout_per_head,
        )
        final = run_witness(orchestrator, task_id, head_outputs, prompt)
    if final and event_bus:
        try:
            event_bus.publish(
                "dvadasa_complete",
                {
                    "task_id": task_id,
                    "final_response": final.model_dump(),
                    "head_count": len(head_outputs),
                },
            )
        except Exception as e:
            # Event publication is best-effort; never fail the flow over it.
            logger.warning("Failed to publish dvadasa_complete", extra={"error": str(e)})
    if return_head_outputs:
        return (final, head_outputs)
    return final
def extract_sources_from_head_outputs(head_outputs: list[HeadOutput]) -> list[dict[str, Any]]:
    """
    Extract citations from head outputs for SOURCES command.

    Evidence without a source_id is skipped; duplicates are collapsed per
    (head, source_id) pair, keeping the first occurrence.
    """
    collected: list[dict[str, Any]] = []
    emitted: set[tuple[str, str]] = set()
    for output in head_outputs:
        head_name = output.head_id.value
        for claim in output.claims:
            for evidence in claim.evidence:
                if not evidence.source_id:
                    continue
                dedupe_key = (head_name, evidence.source_id or "")
                if dedupe_key in emitted:
                    continue
                emitted.add(dedupe_key)
                collected.append(
                    {
                        "head_id": head_name,
                        "source_id": evidence.source_id,
                        "excerpt": evidence.excerpt or "",
                        "confidence": evidence.confidence,
                    }
                )
    return collected
def select_heads_for_complexity(
    prompt: str,
    mvp_heads: list[HeadId] | None = None,
    all_heads: list[HeadId] | None = None,
) -> list[HeadId]:
    """
    Dynamic routing: simple prompts use fewer heads.

    Heuristic: a long prompt (> 50 words) or any complexity keyword escalates
    to the full head set.

    Args:
        prompt: User prompt text to classify.
        mvp_heads: Head set for simple prompts (default: MVP_HEADS).
        all_heads: Head set for complex prompts (default: ALL_CONTENT_HEADS).

    Returns:
        The selected head list.
    """
    # BUG FIX: the old signature used a mutable module-level list as a
    # default argument (shared across all calls); use None sentinels instead.
    mvp_heads = mvp_heads if mvp_heads is not None else MVP_HEADS
    all_heads = all_heads or ALL_CONTENT_HEADS
    prompt_lower = prompt.lower()
    complex_keywords = (
        "security", "risk", "architecture", "scalability", "compliance",
        "critical", "production", "audit", "privacy", "sensitive",
    )
    if len(prompt.split()) > 50 or any(kw in prompt_lower for kw in complex_keywords):
        return all_heads
    return mvp_heads