86 lines
3.7 KiB
Python
86 lines
3.7 KiB
Python
"""MAA Gate: governance integration; MPC check and tool classification for manufacturing tools."""
|
|
|
|
from typing import Any
|
|
|
|
from fusionagi.maa.gap_detection import check_gaps, GapReport
|
|
from fusionagi.maa.layers.mpc_authority import MPCAuthority
|
|
from fusionagi.maa.layers.dlt_engine import DLTEngine
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
# Built-in allowlist of tool names that MAAGate treats as manufacturing tools
# (and therefore gates behind a verified MPC) when no allowlist is supplied.
DEFAULT_MANUFACTURING_TOOLS = frozenset(("cnc_emit", "am_slice", "machine_bind"))
|
|
|
|
|
|
class MAAGate:
    """
    Gate for manufacturing tools: (tool_name, args) -> (allowed, sanitized_args | error_message).

    Compatible with Guardrails.add_check. Non-manufacturing tools pass through
    unchanged. Manufacturing tools must supply an ``mpc_id`` that the MPC
    authority verifies, must produce no gaps from ``check_gaps``, and, when a
    ``dlt_contract_id`` is supplied, must pass DLT evaluation.
    """

    def __init__(
        self,
        mpc_authority: MPCAuthority,
        dlt_engine: DLTEngine | None = None,
        manufacturing_tools: set[str] | frozenset[str] | None = None,
    ) -> None:
        """Create a gate.

        Args:
            mpc_authority: Authority whose ``verify(mpc_id)`` returns a
                certificate object, or ``None`` for an invalid/unknown MPC.
            dlt_engine: Engine used for optional DLT contract evaluation. A
                default ``DLTEngine()`` is constructed when omitted.
            manufacturing_tools: Tool names always classified as manufacturing.
                ``None`` selects ``DEFAULT_MANUFACTURING_TOOLS``. An explicit
                collection, even an empty one, is used as given.
        """
        self._mpc = mpc_authority
        self._dlt = dlt_engine if dlt_engine is not None else DLTEngine()
        # Use ``is None`` rather than ``or``: an explicitly empty allowlist is a
        # deliberate configuration and must not silently fall back to defaults.
        self._manufacturing_tools = (
            DEFAULT_MANUFACTURING_TOOLS if manufacturing_tools is None else manufacturing_tools
        )

    def is_manufacturing(self, tool_name: str, tool_def: Any = None) -> bool:
        """Return True if the tool is classified as manufacturing.

        A tool is manufacturing when ``tool_def`` is given and has a truthy
        ``manufacturing`` attribute, or when ``tool_name`` is in this gate's
        allowlist.
        """
        if tool_def is not None and getattr(tool_def, "manufacturing", False):
            return True
        return tool_name in self._manufacturing_tools

    def check(
        self,
        tool_name: str,
        args: dict[str, Any],
        tool_def: Any = None,
    ) -> tuple[bool, dict[str, Any] | str]:
        """
        Pre-check for Guardrails: (tool_name, args) -> (allowed, sanitized_args or error_message).

        Non-manufacturing tools pass through. Manufacturing tools require mpc_id, valid MPC, no gaps,
        and a passing DLT evaluation when ``dlt_contract_id`` is present.

        Args:
            tool_name: Name of the tool being invoked.
            args: Tool arguments. The MPC id is read from ``mpc_id``, or from
                ``mpc_id_value`` when ``mpc_id`` is missing or falsy.
            tool_def: Optional tool definition used for ToolDef-scoped
                classification (see ``is_manufacturing``). Defaults to ``None``
                so the two-argument Guardrails call shape keeps working.

        Returns:
            ``(True, args)`` when allowed; ``args`` is returned unchanged.
            ``(False, message)`` when denied, with a human-readable reason.
        """
        if not self.is_manufacturing(tool_name, tool_def):
            logger.debug("MAA check pass-through (non-manufacturing)", extra={"tool_name": tool_name})
            return True, args

        # Either key is accepted; a truthy ``mpc_id`` takes precedence.
        mpc_id_value = args.get("mpc_id") or args.get("mpc_id_value")
        if not mpc_id_value:
            logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "missing mpc_id"})
            return False, "MAA: manufacturing tool requires mpc_id in args"

        cert = self._mpc.verify(mpc_id_value)
        if cert is None:
            logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "invalid or unknown MPC"})
            return False, f"MAA: invalid or unknown MPC: {mpc_id_value}"

        # Gap detection sees the caller's args plus the verified MPC identity.
        # These two keys override any caller-supplied mpc_id / mpc_version.
        context: dict[str, Any] = {
            **args,
            "mpc_id": mpc_id_value,
            "mpc_version": cert.mpc_id.version,
        }
        gaps = check_gaps(context)
        if gaps:
            root_cause = _format_root_cause(gaps)
            logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "gaps", "gap_count": len(gaps)})
            return False, root_cause

        # Optional DLT evaluation, triggered only by a truthy dlt_contract_id.
        # A truthy ``dlt_context`` in args is evaluated as given; otherwise the
        # gap-check context built above is reused.
        dlt_contract_id = args.get("dlt_contract_id")
        if dlt_contract_id:
            dlt_context = args.get("dlt_context") or context
            ok, cause = self._dlt.evaluate(dlt_contract_id, dlt_context)
            if not ok:
                logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "dlt_failed"})
                return False, f"MAA DLT: {cause}"

        logger.debug("MAA check allowed", extra={"tool_name": tool_name})
        return True, args
|
|
|
|
|
|
def _format_root_cause(gaps: list[GapReport]) -> str:
    """Collapse gap reports into one root-cause message.

    Each gap contributes a ``"MAA gap: <gap_class.value> — <description>"``
    segment. If any gap has a truthy ``required_resolution``, a final
    ``"Required resolution: ..."`` segment lists those resolutions joined by
    ``"; "``. Segments are joined with ``" | "``, and an empty list yields ``""``.
    """
    segments = []
    resolutions = []
    for gap in gaps:
        segments.append(f"MAA gap: {gap.gap_class.value} — {gap.description}")
        if gap.required_resolution:
            resolutions.append(gap.required_resolution)
    if resolutions:
        segments.append("Required resolution: " + "; ".join(resolutions))
    return " | ".join(segments)
|