222 lines
7.9 KiB
Python
222 lines
7.9 KiB
Python
"""Safe runner: invoke tool with timeout, input validation, and failure handling; log for replay."""
|
|
|
|
import time
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from fusionagi.governance.audit_log import AuditLog
|
|
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
|
from typing import Any
|
|
|
|
from fusionagi.tools.registry import ToolDef
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
class ToolValidationError(Exception):
|
|
"""Raised when tool arguments fail validation."""
|
|
|
|
def __init__(self, tool_name: str, message: str, details: dict[str, Any] | None = None):
|
|
self.tool_name = tool_name
|
|
self.details = details or {}
|
|
super().__init__(f"Tool {tool_name}: {message}")
|
|
|
|
|
|
def validate_args(tool: "ToolDef", args: dict[str, Any]) -> tuple[bool, str]:
    """
    Validate arguments against the tool's JSON schema.

    A practical subset of JSON Schema is checked without external
    dependencies:

    - ``required``;
    - per-property ``type``, either a single name or a list of names;
    - ``minLength`` / ``maxLength`` / ``pattern`` for string values;
    - ``minimum`` / ``maximum`` / ``exclusiveMinimum`` / ``exclusiveMaximum``
      for numeric values, in both the numeric and the draft-4 boolean form;
    - ``enum``.

    Arguments not listed in ``properties`` are accepted. Schemas whose
    top-level ``type`` is not ``"object"`` are not validated.

    Args:
        tool: Tool definition; only ``tool.parameters_schema`` is read.
        args: Keyword arguments that will be passed to the tool.

    Returns:
        Tuple of (is_valid, error_message). error_message is empty if valid.
    """
    schema = tool.parameters_schema
    if not schema:
        return True, ""

    # Basic JSON schema validation (without external dependency)
    if schema.get("type", "object") != "object":
        return True, ""  # Only validate object schemas

    properties = schema.get("properties") or {}

    for field in schema.get("required") or []:
        if field not in args:
            return False, f"Missing required argument: {field}"

    for field, value in args.items():
        prop_schema = properties.get(field)
        if not isinstance(prop_schema, dict):
            # Allow extra fields by default (additionalProperties: true is common).
            # Non-dict property schemas (e.g. JSON Schema boolean schemas) are
            # not interpreted here.
            continue
        error = _check_property(field, value, prop_schema)
        if error:
            return False, error

    return True, ""


def _matches_type(value: Any, type_name: str) -> bool:
    """Return True if *value* is an instance of JSON Schema type *type_name*.

    Unknown type names match anything, so unsupported schema features never
    reject input.
    """
    if type_name == "string":
        return isinstance(value, str)
    if type_name == "integer":
        # bool is a subclass of int in Python but is not a JSON integer.
        return isinstance(value, int) and not isinstance(value, bool)
    if type_name == "number":
        return isinstance(value, (int, float)) and not isinstance(value, bool)
    if type_name == "boolean":
        return isinstance(value, bool)
    if type_name == "array":
        return isinstance(value, list)
    if type_name == "object":
        return isinstance(value, dict)
    if type_name == "null":
        return value is None
    return True


def _check_property(field: str, value: Any, prop_schema: dict[str, Any]) -> str:
    """Return an error message if *value* violates *prop_schema*, else ``""``."""
    # --- type ---
    prop_type = prop_schema.get("type")
    if isinstance(prop_type, str):
        allowed = [prop_type]
    elif isinstance(prop_type, (list, tuple)):
        allowed = list(prop_type)
    else:
        allowed = []  # absent or unrecognized "type": no type constraint
    if allowed and not any(_matches_type(value, name) for name in allowed):
        expected = " or ".join(str(name) for name in allowed)
        return f"Argument '{field}' must be of type {expected}, got {type(value).__name__}"

    # --- string constraints (applied by the value's type, as JSON Schema does) ---
    if isinstance(value, str):
        min_len = prop_schema.get("minLength")
        max_len = prop_schema.get("maxLength")
        pattern = prop_schema.get("pattern")

        if min_len is not None and len(value) < min_len:
            return f"Argument '{field}' must be at least {min_len} characters"
        if max_len is not None and len(value) > max_len:
            return f"Argument '{field}' must be at most {max_len} characters"
        if pattern:
            import re

            # JSON Schema patterns are not implicitly anchored, so search
            # anywhere in the string rather than only at its start.
            if not re.search(pattern, value):
                return f"Argument '{field}' does not match pattern: {pattern}"

    # --- numeric constraints (bools are excluded: they are not JSON numbers) ---
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        minimum = prop_schema.get("minimum")
        maximum = prop_schema.get("maximum")
        exclusive_min = prop_schema.get("exclusiveMinimum")
        exclusive_max = prop_schema.get("exclusiveMaximum")

        # Draft 4 spells exclusivity as a boolean modifier of minimum/maximum;
        # later drafts give a numeric bound. Normalize to the numeric form.
        if isinstance(exclusive_min, bool):
            exclusive_min = minimum if exclusive_min else None
        if isinstance(exclusive_max, bool):
            exclusive_max = maximum if exclusive_max else None

        if minimum is not None and value < minimum:
            return f"Argument '{field}' must be >= {minimum}"
        if maximum is not None and value > maximum:
            return f"Argument '{field}' must be <= {maximum}"
        if exclusive_min is not None and value <= exclusive_min:
            return f"Argument '{field}' must be > {exclusive_min}"
        if exclusive_max is not None and value >= exclusive_max:
            return f"Argument '{field}' must be < {exclusive_max}"

    # --- enum (applies regardless of declared type) ---
    enum = prop_schema.get("enum")
    if enum is not None and value not in enum:
        return f"Argument '{field}' must be one of: {enum}"

    return ""
|
|
|
|
|
|
def run_tool(
    tool: ToolDef,
    args: dict[str, Any],
    timeout_seconds: float | None = None,
    validate: bool = True,
) -> tuple[Any, dict[str, Any]]:
    """
    Invoke tool.fn(**args) with optional validation and timeout.

    The tool runs on a single-worker ThreadPoolExecutor. If the timeout
    elapses, this function returns immediately with an error instead of
    waiting for the tool. Python threads cannot be forcibly stopped, so the
    tool keeps running in the background until it finishes on its own.
    Per the ``concurrent.futures`` docs, such threads are still joined
    before the interpreter exits.

    Args:
        tool: The tool definition to execute.
        args: Arguments to pass to the tool function.
        timeout_seconds: Override timeout (uses tool.timeout_seconds if None).
            A missing, zero or negative timeout means "wait indefinitely".
        validate: Whether to validate args against tool's schema (default True).

    Returns:
        Tuple of (result, log_entry). log_entry has the keys "tool", "args",
        "result", "error", "duration_seconds" and "validated". On error,
        result is None and log_entry["error"] holds a non-empty message.
    """
    timeout = timeout_seconds if timeout_seconds is not None else tool.timeout_seconds
    start = time.monotonic()
    log_entry: dict[str, Any] = {
        "tool": tool.name,
        "args": args,
        "result": None,
        "error": None,
        "duration_seconds": None,
        "validated": validate,
    }

    def _elapsed() -> float:
        """Seconds since start, rounded to milliseconds."""
        return round(time.monotonic() - start, 3)

    # Validate arguments before execution
    if validate:
        is_valid, error_msg = validate_args(tool, args)
        if not is_valid:
            log_entry["error"] = f"Validation error: {error_msg}"
            log_entry["duration_seconds"] = _elapsed()
            logger.warning(
                "Tool validation failed",
                extra={"tool": tool.name, "error": error_msg},
            )
            return None, log_entry

    def _invoke() -> Any:
        return tool.fn(**args)

    wait = timeout if timeout and timeout > 0 else None
    # Deliberately not a `with` block: ThreadPoolExecutor.__exit__ calls
    # shutdown(wait=True), which would block until a timed-out tool finished
    # and so defeat the timeout.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(_invoke)
        try:
            result = future.result(timeout=wait)
        except FuturesTimeoutError:
            if not future.done():
                log_entry["error"] = f"Tool {tool.name} timed out after {timeout}s"
                log_entry["duration_seconds"] = _elapsed()
                logger.warning(
                    "Tool timed out",
                    extra={"tool": tool.name, "timeout": timeout},
                )
                return None, log_entry
            # The future already finished. Either the tool raised TimeoutError
            # itself (concurrent.futures.TimeoutError aliases the builtin since
            # Python 3.11), or it completed right at the deadline. Re-read the
            # tool's real outcome; an exception propagates to the handler below.
            result = future.result()
    except Exception as e:
        # Fall back to the exception type so the error field is never empty.
        error_text = str(e) or type(e).__name__
        log_entry["error"] = error_text
        log_entry["duration_seconds"] = _elapsed()
        logger.error(
            "Tool execution failed",
            extra={"tool": tool.name, "error": error_text, "error_type": type(e).__name__},
        )
        return None, log_entry
    finally:
        # Release the executor without joining a possibly still-running worker.
        executor.shutdown(wait=False)

    log_entry["result"] = result
    log_entry["duration_seconds"] = _elapsed()
    logger.debug(
        "Tool executed successfully",
        extra={"tool": tool.name, "duration": log_entry["duration_seconds"]},
    )
    return result, log_entry
|
|
|
|
|
|
def run_tool_with_audit(
    tool: ToolDef,
    args: dict[str, Any],
    audit_log: "AuditLog",
    actor: str = "system",
    task_id: str | None = None,
    timeout_seconds: float | None = None,
    validate: bool = True,
) -> tuple[Any, dict[str, Any]]:
    """
    Invoke a tool via run_tool and record a TOOL_CALL event in the audit log.

    String arguments longer than 200 characters are truncated (with a
    trailing "...") in the audited payload only; the tool itself receives
    the original, untruncated args. Values nested inside lists/dicts are not
    truncated.

    The audited outcome is "failure" whenever run_tool reported an error and
    "success" otherwise, including tools that legitimately return None.

    Args:
        tool: The tool definition to execute.
        args: Arguments to pass to the tool function.
        audit_log: Audit log that receives the TOOL_CALL event.
        actor: Identity recorded as the actor of the event.
        task_id: Optional task identifier attached to the event.
        timeout_seconds: Forwarded to run_tool.
        validate: Forwarded to run_tool.

    Returns:
        The (result, log_entry) tuple from run_tool, unchanged.
    """
    from fusionagi.schemas.audit import AuditEventType

    sanitized = {
        key: (value[:200] + "...") if isinstance(value, str) and len(value) > 200 else value
        for key, value in args.items()
    }
    result, log_entry = run_tool(tool, args, timeout_seconds=timeout_seconds, validate=validate)
    error = log_entry.get("error")
    audit_log.append(
        AuditEventType.TOOL_CALL,
        actor,
        action=f"tool:{tool.name}",
        task_id=task_id,
        payload={"tool": tool.name, "args": sanitized, "error": error},
        # Judge success by the error field: a tool that returns None still succeeded.
        outcome="success" if error is None else "failure",
    )
    return result, log_entry
|