237 lines
9.4 KiB
Python
237 lines
9.4 KiB
Python
|
|
"""Executor agent: receives execute_step, invokes tool via safe runner, returns step_done/step_failed."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Any, TYPE_CHECKING
|
||
|
|
|
||
|
|
from fusionagi.agents.base_agent import BaseAgent
|
||
|
|
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||
|
|
from fusionagi.schemas.plan import Plan
|
||
|
|
from fusionagi.planning import get_step
|
||
|
|
from fusionagi.tools.registry import ToolRegistry
|
||
|
|
from fusionagi.tools.runner import run_tool
|
||
|
|
from fusionagi._logger import logger
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from fusionagi.core.state_manager import StateManager
|
||
|
|
from fusionagi.governance.guardrails import Guardrails
|
||
|
|
from fusionagi.governance.rate_limiter import RateLimiter
|
||
|
|
from fusionagi.governance.access_control import AccessControl
|
||
|
|
from fusionagi.governance.override import OverrideHooks
|
||
|
|
from fusionagi.memory.episodic import EpisodicMemory
|
||
|
|
|
||
|
|
|
||
|
|
class ExecutorAgent(BaseAgent):
    """
    Executes steps: maps step to tool call, runs via safe runner, emits step_done/step_failed.

    Supports full governance integration:
    - Guardrails: Pre/post checks for tool invocations
    - RateLimiter: Limits tool invocation rate per agent/tool
    - AccessControl: Policy-based tool access control
    - OverrideHooks: Human-in-the-loop for high-risk operations
    - EpisodicMemory: Records step outcomes for learning

    Every governance collaborator is optional; a collaborator that is ``None``
    is simply skipped.
    """

    def __init__(
        self,
        identity: str = "executor",
        registry: ToolRegistry | None = None,
        state_manager: StateManager | None = None,
        guardrails: Guardrails | None = None,
        rate_limiter: RateLimiter | None = None,
        access_control: AccessControl | None = None,
        override_hooks: OverrideHooks | None = None,
        episodic_memory: EpisodicMemory | None = None,
    ) -> None:
        """
        Initialize the executor agent.

        Args:
            identity: Agent identifier.
            registry: Tool registry for tool lookup. A fresh ``ToolRegistry``
                is created only when ``None`` is passed.
            state_manager: State manager for trace storage.
            guardrails: Guardrails for pre/post checks.
            rate_limiter: Rate limiter for tool invocation throttling.
            access_control: Access control for policy-based tool access.
            override_hooks: Override hooks for human-in-the-loop.
            episodic_memory: Episodic memory for recording step outcomes.
        """
        super().__init__(
            identity=identity,
            role="Executor",
            objective="Execute plan steps via tools",
            memory_access=False,
            tool_permissions=["*"],
        )
        # Explicit None check: a caller-supplied registry must be kept even if
        # it happens to evaluate as falsy (e.g. while it is still empty).
        self._registry = registry if registry is not None else ToolRegistry()
        self._state = state_manager
        self._guardrails = guardrails
        self._rate_limiter = rate_limiter
        self._access_control = access_control
        self._override_hooks = override_hooks
        self._episodic_memory = episodic_memory

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        On execute_step, run tool and return step_done or step_failed.

        The checks run in this order, and the first one that fails produces a
        ``step_failed`` reply without running the tool: payload validation,
        step lookup, tool lookup, registry permissions, access control, rate
        limiter, guardrails pre-check, override hooks (only for tools whose
        ``manufacturing`` attribute is truthy).

        After the tool runs, the guardrails post-check may turn a successful
        run into a failure. The log entry is then recorded in the state
        manager and episodic memory (when configured) before replying.

        Args:
            envelope: Incoming message envelope. Its payload is read for
                ``step_id``, ``plan`` and the optional ``tool_name`` /
                ``tool_args`` fallbacks.

        Returns:
            ``None`` when the intent is not ``execute_step``; otherwise a
            ``step_done`` or ``step_failed`` envelope addressed to the sender
            and carrying the request's ``task_id`` and ``correlation_id``.
        """
        message = envelope.message
        if message.intent != "execute_step":
            return None
        logger.info(
            "Executor handle_message",
            extra={"recipient": self.identity, "intent": message.intent},
        )
        payload = message.payload
        task_id = envelope.task_id
        sender = message.sender
        correlation_id = envelope.correlation_id
        step_id = payload.get("step_id")
        plan_dict = payload.get("plan")

        def fail(
            error: str,
            log_entry: dict[str, Any] | None = None,
        ) -> AgentMessageEnvelope:
            # Single place that builds failure replies so that every one of
            # them carries the request's correlation_id, like step_done does.
            return self._fail(
                task_id,
                sender,
                step_id or "?",
                error,
                log_entry=log_entry,
                correlation_id=correlation_id,
            )

        if not step_id or not plan_dict:
            return fail("missing step_id or plan")

        plan = Plan.from_dict(plan_dict)
        step = get_step(plan, step_id)
        if not step:
            return fail("step not found")

        # Values on the step take precedence; the payload is only a fallback.
        tool_name = step.tool_name or payload.get("tool_name")
        tool_args = step.tool_args or payload.get("tool_args", {})
        if not tool_name:
            return fail("no tool_name")

        tool = self._registry.get(tool_name)
        if not tool:
            return fail(f"tool not found: {tool_name}")

        # Check tool registry permissions
        if not self._registry.allowed_for(tool_name, self.tool_permissions):
            return fail("permission denied")

        # Check access control policy
        if self._access_control is not None:
            if not self._access_control.allowed(self.identity, tool_name, task_id):
                logger.info(
                    "Executor access_control denied",
                    extra={"tool_name": tool_name, "agent_id": self.identity, "task_id": task_id},
                )
                return fail("access control denied")

        # Check rate limiter (keyed per agent identity and tool name)
        if self._rate_limiter is not None:
            rate_key = f"{self.identity}:{tool_name}"
            allowed, reason = self._rate_limiter.allow(rate_key)
            if not allowed:
                logger.info(
                    "Executor rate_limiter denied",
                    extra={"tool_name": tool_name, "key": rate_key, "reason": reason},
                )
                return fail(reason)

        # Check guardrails pre-check
        if self._guardrails is not None:
            pre_result = self._guardrails.pre_check(tool_name, tool_args)
            logger.info(
                "Executor guardrail pre_check",
                extra={"tool_name": tool_name, "allowed": pre_result.allowed},
            )
            if not pre_result.allowed:
                return fail(pre_result.error_message or "Guardrails pre-check failed")
            # The pre-check may hand back rewritten arguments; use them for
            # the override hook and the actual tool run.
            if pre_result.sanitized_args is not None:
                tool_args = pre_result.sanitized_args

        # Check override hooks for high-risk operations
        if self._override_hooks is not None and tool.manufacturing:
            proceed = self._override_hooks.fire(
                "tool_execution",
                {"tool_name": tool_name, "args": tool_args, "task_id": task_id, "step_id": step_id},
            )
            if not proceed:
                logger.info(
                    "Executor override_hooks blocked",
                    extra={"tool_name": tool_name, "step_id": step_id},
                )
                return fail("Override hook blocked execution")

        # Execute the tool
        result, log_entry = run_tool(tool, tool_args)
        logger.info(
            "Executor tool run",
            extra={"tool_name": tool_name, "step_id": step_id, "error": log_entry.get("error")},
        )

        # Check guardrails post-check (skipped when the run already failed)
        if self._guardrails is not None and not log_entry.get("error"):
            post_ok, post_reason = self._guardrails.post_check(tool_name, result)
            if not post_ok:
                log_entry["error"] = f"Post-check failed: {post_reason}"
                log_entry["post_check_failed"] = True
                logger.info(
                    "Executor guardrail post_check failed",
                    extra={"tool_name": tool_name, "reason": post_reason},
                )

        self._record_outcome(task_id, step_id, tool_name, log_entry)

        if log_entry.get("error"):
            return fail(log_entry["error"], log_entry=log_entry)

        logger.info(
            "Executor response",
            extra={"recipient": self.identity, "response_intent": "step_done"},
        )
        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=sender,
                intent="step_done",
                payload={
                    "step_id": step_id,
                    "result": result,
                    "log_entry": log_entry,
                },
            ),
            task_id=task_id,
            correlation_id=correlation_id,
        )

    def _record_outcome(
        self,
        task_id: str | None,
        step_id: str,
        tool_name: str,
        log_entry: dict[str, Any],
    ) -> None:
        """Store the run's log entry in the state manager and episodic memory, when configured."""
        # ``is not None`` (not truthiness) so that a configured collaborator
        # that evaluates as falsy is still written to, matching the other
        # governance checks in this class.
        if self._state is not None:
            self._state.append_trace(task_id or "", log_entry)

        if self._episodic_memory is not None:
            self._episodic_memory.append(
                task_id=task_id or "",
                event={
                    "type": "step_execution",
                    "step_id": step_id,
                    "tool_name": tool_name,
                    "success": not log_entry.get("error"),
                    "duration_seconds": log_entry.get("duration_seconds"),
                },
            )

    def _fail(
        self,
        task_id: str | None,
        recipient: str,
        step_id: str,
        error: str,
        log_entry: dict[str, Any] | None = None,
        correlation_id: Any = None,
    ) -> AgentMessageEnvelope:
        """
        Build a ``step_failed`` reply envelope.

        Args:
            task_id: Task identifier copied onto the reply envelope.
            recipient: Identity the reply is addressed to.
            step_id: Identifier of the step that failed.
            error: Error description placed in the payload.
            log_entry: Optional tool-run log entry; an empty dict is sent when
                omitted.
            correlation_id: Correlation id of the request being answered. When
                ``None`` the argument is not passed to the envelope at all, so
                the envelope behaves exactly as it did before this parameter
                existed.

        Returns:
            The ``step_failed`` envelope.
        """
        envelope_kwargs: dict[str, Any] = {"task_id": task_id}
        if correlation_id is not None:
            envelope_kwargs["correlation_id"] = correlation_id
        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=recipient,
                intent="step_failed",
                payload={
                    "step_id": step_id,
                    "error": error,
                    "log_entry": log_entry or {},
                },
            ),
            **envelope_kwargs,
        )
|