"""MAA Gate: governance integration; MPC check and tool classification for manufacturing tools.""" from typing import Any from fusionagi.maa.gap_detection import check_gaps, GapReport from fusionagi.maa.layers.mpc_authority import MPCAuthority from fusionagi.maa.layers.dlt_engine import DLTEngine from fusionagi._logger import logger # Default manufacturing tool names that require MPC DEFAULT_MANUFACTURING_TOOLS = frozenset({"cnc_emit", "am_slice", "machine_bind"}) class MAAGate: """ Gate for manufacturing tools: (tool_name, args) -> (allowed, sanitized_args | error_message). Compatible with Guardrails.add_check. Manufacturing tools require valid MPC and no gaps. """ def __init__( self, mpc_authority: MPCAuthority, dlt_engine: DLTEngine | None = None, manufacturing_tools: set[str] | frozenset[str] | None = None, ) -> None: self._mpc = mpc_authority self._dlt = dlt_engine or DLTEngine() self._manufacturing_tools = manufacturing_tools or DEFAULT_MANUFACTURING_TOOLS def is_manufacturing(self, tool_name: str, tool_def: Any = None) -> bool: """Return True if tool is classified as manufacturing (allowlist or ToolDef scope).""" if tool_def is not None and getattr(tool_def, "manufacturing", False): return True return tool_name in self._manufacturing_tools def check(self, tool_name: str, args: dict[str, Any]) -> tuple[bool, dict[str, Any] | str]: """ Pre-check for Guardrails: (tool_name, args) -> (allowed, sanitized_args or error_message). Non-manufacturing tools pass through. Manufacturing tools require mpc_id, valid MPC, no gaps. """ if not self.is_manufacturing(tool_name, None): logger.debug("MAA check pass-through (non-manufacturing)", extra={"tool_name": tool_name}) return True, args mpc_id_value = args.get("mpc_id") or args.get("mpc_id_value") if not mpc_id_value: logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "missing mpc_id"}) return False, "MAA: manufacturing tool requires mpc_id in args" cert = self._mpc.verify(mpc_id_value) if cert is None: logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "invalid or unknown MPC"}) return False, f"MAA: invalid or unknown MPC: {mpc_id_value}" context: dict[str, Any] = { **args, "mpc_id": mpc_id_value, "mpc_version": cert.mpc_id.version, } gaps = check_gaps(context) if gaps: root_cause = _format_root_cause(gaps) logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "gaps", "gap_count": len(gaps)}) return False, root_cause # Optional DLT evaluation when dlt_contract_id and dlt_context are in args dlt_contract_id = args.get("dlt_contract_id") if dlt_contract_id: dlt_context = args.get("dlt_context") or context ok, cause = self._dlt.evaluate(dlt_contract_id, dlt_context) if not ok: logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "dlt_failed"}) return False, f"MAA DLT: {cause}" logger.debug("MAA check allowed", extra={"tool_name": tool_name}) return True, args def _format_root_cause(gaps: list[GapReport]) -> str: """Format gap reports as single root-cause message.""" parts = [f"MAA gap: {g.gap_class.value} — {g.description}" for g in gaps] if any(g.required_resolution for g in gaps): parts.append("Required resolution: " + "; ".join(g.required_resolution for g in gaps if g.required_resolution)) return " | ".join(parts)