Files
FusionAGI/fusionagi/tools/builtins.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

292 lines
9.2 KiB
Python

"""Built-in tools: file read (scoped), HTTP GET (with SSRF protection), query state."""
import ipaddress
import os
import socket
from typing import Any, Callable
from urllib.parse import urlparse
from fusionagi.tools.registry import ToolDef
from fusionagi._logger import logger
# Default allowed path prefix for file tools, computed once at import time from the
# process working directory. Deployers should pass an explicit scope (e.g. from
# config/env) instead of relying on cwd in production.
DEFAULT_FILE_SCOPE = os.path.abspath(os.curdir)
# Upper bound, in bytes, for file read/write operations (10 MiB).
MAX_FILE_SIZE = 10 << 20
class SSRFProtectionError(Exception):
    """Signal that a URL was rejected by the SSRF (server-side request forgery) checks."""
class FileSizeError(Exception):
    """Signal that a file or payload is larger than the configured byte limit."""
def _normalize_path(path: str, scope: str) -> str:
    """
    Normalize and validate a file path against scope.

    Both ``path`` and ``scope`` are made absolute and symlink-resolved before
    comparison, so ``..`` segments and symlinks that lead outside the scope are
    rejected. Relative values resolve against the process working directory,
    not against ``scope``.

    Args:
        path: Path supplied by the caller.
        scope: Directory the resolved path must equal or lie beneath.

    Returns:
        The resolved real path.

    Raises:
        PermissionError: If the resolved path is outside the resolved scope.
    """
    abs_path = os.path.abspath(path)
    # Resolve symlinks to get the real path
    try:
        real_path = os.path.realpath(abs_path)
    except OSError:
        real_path = abs_path
    real_scope = os.path.realpath(os.path.abspath(scope))

    # Compare whole path components rather than string prefixes: the old
    # ``startswith(real_scope + os.sep)`` test turned a root scope "/" into "//"
    # and rejected every path. normcase makes the comparison case-insensitive on
    # platforms whose paths are (Windows); it is a no-op on POSIX.
    norm_path = os.path.normcase(real_path)
    norm_scope = os.path.normcase(real_scope)
    try:
        inside = os.path.commonpath([norm_path, norm_scope]) == norm_scope
    except ValueError:
        # e.g. paths on different Windows drives have no common path
        inside = False
    if not inside:
        raise PermissionError(f"Path not allowed: {path} resolves outside {scope}")
    return real_path
def _file_read(path: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
    """
    Read file content; path must be under scope.

    The size limit is enforced twice. It is checked once via ``os.path.getsize``
    for cheap early rejection. It is checked again on the bytes actually read, so
    a file that grows after the stat call, or whose reported size is misleading
    (special files such as devices typically report 0), cannot push more than
    ``max_size`` bytes into memory.

    Args:
        path: File path to read.
        scope: Allowed directory scope.
        max_size: Maximum file size in bytes.

    Returns:
        File contents decoded as UTF-8 (undecodable bytes replaced), with the
        same universal-newline translation as text-mode ``open``.

    Raises:
        PermissionError: If path is outside scope or its size cannot be read.
        FileSizeError: If file exceeds max_size.
    """
    import io

    real_path = _normalize_path(path, scope)
    try:
        file_size = os.path.getsize(real_path)
    except OSError as e:
        raise PermissionError(f"Cannot access file: {e}") from e
    if file_size > max_size:
        raise FileSizeError(f"File too large: {file_size} bytes (max {max_size})")

    # Read one byte past the limit so oversize content is detected without
    # loading an unbounded amount of data.
    with open(real_path, "rb") as f:
        raw = f.read(max_size + 1)
    if len(raw) > max_size:
        raise FileSizeError(f"File too large: more than {max_size} bytes (max {max_size})")

    # Decode through TextIOWrapper so newline handling matches open(..., "r").
    with io.TextIOWrapper(io.BytesIO(raw), encoding="utf-8", errors="replace") as text:
        return text.read()
def _file_write(path: str, content: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
    """
    Write ``content`` to ``path`` as UTF-8, provided the path resolves inside ``scope``.

    The encoded size is checked before any filesystem access. Missing parent
    directories are created only after they, too, are confirmed to be in scope.

    Args:
        path: Destination file path.
        content: Text to write.
        scope: Directory the destination must resolve under.
        max_size: Largest permitted UTF-8 encoded size of ``content``, in bytes.

    Returns:
        A message reporting the encoded byte count and the resolved destination.

    Raises:
        PermissionError: If the destination (or a directory to be created) is outside scope.
        FileSizeError: If the encoded content is larger than ``max_size``.
    """
    encoded_len = len(content.encode("utf-8"))
    if encoded_len > max_size:
        raise FileSizeError(f"Content too large: {encoded_len} bytes (max {max_size})")

    target = _normalize_path(path, scope)

    directory = os.path.dirname(target)
    if directory and not os.path.exists(directory):
        # Re-check the directory against scope before creating it.
        _normalize_path(directory, scope)
        os.makedirs(directory, exist_ok=True)

    with open(target, "w", encoding="utf-8") as handle:
        handle.write(content)
    return f"Wrote {encoded_len} bytes to {target}"
def _is_private_ip(ip: str) -> bool:
    """
    Check if an IP address is private, loopback, or otherwise unsafe to fetch.

    In addition to the explicit category checks, any address the ``ipaddress``
    module does not classify as globally reachable (``is_global``) is rejected.
    This catches ranges such as the 100.64.0.0/10 shared address space, which
    ``is_private`` reports as False. IPv6 addresses that embed IPv4 are also
    rejected: all IPv4-mapped addresses, and 6to4 addresses wrapping a
    non-public IPv4 address.

    Args:
        ip: Textual IPv4/IPv6 address, e.g. ``sockaddr[0]`` from ``getaddrinfo``.

    Returns:
        True if the address is unsafe or cannot be parsed, False if it is public.
    """
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        return True  # Invalid IP is treated as unsafe
    if (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_multicast
        or addr.is_reserved
        or addr.is_unspecified
        or not addr.is_global
    ):
        return True
    if isinstance(addr, ipaddress.IPv6Address):
        # Block every IPv4-mapped IPv6 address (::ffff:a.b.c.d), public or not.
        if addr.ipv4_mapped is not None:
            return True
        # 6to4 (2002::/16) carries an IPv4 address; judge the embedded address.
        embedded = addr.sixtofour
        if embedded is not None and _is_private_ip(str(embedded)):
            return True
    return False
def _validate_url(url: str, allow_private: bool = False) -> str:
    """
    Validate a URL for SSRF protection.

    Checks run in this order:

    1. The URL parses, including its port.
    2. The scheme is http or https.
    3. It has a hostname.
    4. The hostname is not a localhost alias (including ``*.localhost``) and
       does not end in a common internal-only suffix. Trailing dots are ignored
       for this check, so ``localhost.`` is caught.
    5. Unless ``allow_private`` is True, every address the hostname resolves
       to is public according to ``_is_private_ip``.

    Args:
        url: URL to validate.
        allow_private: If True, skip the resolved-IP check (default False).
            The hostname blocklists still apply.

    Returns:
        The validated URL, unchanged.

    Raises:
        SSRFProtectionError: If the URL is blocked, malformed (e.g. bad port or
            IPv6 literal), or its hostname cannot be resolved.

    NOTE(review): this validates one DNS resolution. An HTTP client that
    resolves the name again can be targeted by DNS rebinding between the two
    lookups.
    """
    if not isinstance(url, str):
        raise SSRFProtectionError(f"Invalid URL: expected str, got {type(url).__name__}")
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
        # .port raises ValueError for non-numeric or out-of-range ports
        port = parsed.port
    except ValueError as e:
        raise SSRFProtectionError(f"Invalid URL: {e}") from e

    # Only allow HTTP and HTTPS
    if parsed.scheme not in ("http", "https"):
        raise SSRFProtectionError(f"URL scheme not allowed: {parsed.scheme}")
    if not hostname:
        raise SSRFProtectionError("URL must have a hostname")

    # "localhost." / "db.internal." name the same hosts as without the dot.
    host_key = hostname.lower().rstrip(".")
    localhost_patterns = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
    if host_key in localhost_patterns or host_key.endswith(".localhost"):
        raise SSRFProtectionError(f"Localhost URLs not allowed: {hostname}")
    internal_patterns = (".local", ".internal", ".corp", ".lan", ".home")
    if host_key.endswith(internal_patterns):
        raise SSRFProtectionError(f"Internal hostname not allowed: {hostname}")

    if not allow_private:
        default_port = 443 if parsed.scheme == "https" else 80
        try:
            infos = socket.getaddrinfo(hostname, port or default_port)
        except (OSError, UnicodeError) as e:
            # gaierror is an OSError; UnicodeError comes from IDNA encoding of
            # invalid labels (e.g. a label longer than 63 characters).
            logger.warning(f"DNS resolution failed for {hostname}: {e}")
            raise SSRFProtectionError(f"Cannot resolve hostname: {hostname}") from e
        for *_, sockaddr in infos:
            ip = sockaddr[0]
            if _is_private_ip(ip):
                raise SSRFProtectionError(f"URL resolves to private IP: {ip}")
    return url
def _http_get(url: str, allow_private: bool = False) -> str:
    """
    Simple HTTP GET with SSRF protection.

    Both the initial URL and every redirect target are checked with
    ``_validate_url``. urllib's default redirect handling would otherwise follow
    a 3xx from a public host to an internal address without re-validation.

    Args:
        url: URL to fetch.
        allow_private: If True, allow private/internal IPs (default False).

    Returns:
        Response body decoded as UTF-8 (undecodable bytes replaced). On failure
        returns a string starting with 'Error: '. This becomes
        'Error: SSRF protection: ' when the URL or a redirect target is blocked.

    NOTE(review): urllib resolves the hostname again after validation, so DNS
    rebinding between check and connect is not prevented. The body is also read
    without a size limit.
    """
    import urllib.request

    class _ValidatingRedirectHandler(urllib.request.HTTPRedirectHandler):
        """Redirect handler that re-runs SSRF validation on each redirect target."""

        def redirect_request(self, req, fp, code, msg, headers, newurl):
            try:
                _validate_url(newurl, allow_private=allow_private)
            except SSRFProtectionError:
                # The redirect response would otherwise be left open.
                fp.close()
                raise
            return super().redirect_request(req, fp, code, msg, headers, newurl)

    try:
        validated_url = _validate_url(url, allow_private=allow_private)
        opener = urllib.request.build_opener(_ValidatingRedirectHandler)
        with opener.open(validated_url, timeout=10) as resp:
            return resp.read().decode("utf-8", errors="replace")
    except SSRFProtectionError as e:
        return f"Error: SSRF protection: {e}"
    except Exception as e:
        # Tool boundary: report any failure as a string, per the contract above.
        return f"Error: {e}"
def make_file_read_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """
    Build the ``file_read`` tool, bound to a fixed directory scope.

    Args:
        scope: Directory reads are confined to. It is captured by the closure,
            so tool callers can only supply ``path``.

    Returns:
        A ToolDef with the "file" permission scope and timeout_seconds=5.0.
    """
    schema = {
        "type": "object",
        "properties": {"path": {"type": "string", "description": "File path"}},
        "required": ["path"],
    }

    def fn(path: str) -> str:
        return _file_read(path, scope=scope)

    return ToolDef(
        name="file_read",
        description="Read file content; path must be under allowed scope",
        fn=fn,
        parameters_schema=schema,
        permission_scope=["file"],
        timeout_seconds=5.0,
    )
def make_file_write_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """
    Build the ``file_write`` tool, bound to a fixed directory scope.

    Args:
        scope: Directory writes are confined to. It is captured by the closure,
            so tool callers can only supply ``path`` and ``content``.

    Returns:
        A ToolDef with the "file" permission scope and timeout_seconds=5.0.
    """
    schema = {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "File path"},
            "content": {"type": "string", "description": "Content to write"},
        },
        "required": ["path", "content"],
    }

    def fn(path: str, content: str) -> str:
        return _file_write(path, content, scope=scope)

    return ToolDef(
        name="file_write",
        description="Write content to file; path must be under allowed scope",
        fn=fn,
        parameters_schema=schema,
        permission_scope=["file"],
        timeout_seconds=5.0,
    )
def make_http_get_tool() -> ToolDef:
    """
    HTTP GET tool.

    The registered function accepts only ``url`` and always calls ``_http_get``
    with its default ``allow_private=False``. Registering ``_http_get`` directly
    would expose ``allow_private`` as a callable keyword. A caller that forwards
    arguments not listed in ``parameters_schema`` could then switch off the
    private-IP check. NOTE(review): confirm whether the tool registry enforces
    the schema; this wrapper is safe either way.

    Returns:
        A ToolDef with the "network" permission scope and timeout_seconds=15.0.
    """
    def fn(url: str) -> str:
        return _http_get(url)

    return ToolDef(
        name="http_get",
        description="Perform HTTP GET request and return response body",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {"url": {"type": "string", "description": "URL to fetch"}},
            "required": ["url"],
        },
        permission_scope=["network"],
        timeout_seconds=15.0,
    )
def make_query_state_tool(get_state_fn: Callable[[str], Any]) -> ToolDef:
    """
    Build the internal ``query_state`` tool around an injected lookup function.

    Args:
        get_state_fn: Called as ``get_state_fn(task_id)``. Its return value
            (task state/trace) is passed back unchanged.

    Returns:
        A ToolDef with the "internal" permission scope and timeout_seconds=2.0.
    """
    schema = {
        "type": "object",
        "properties": {"task_id": {"type": "string"}},
        "required": ["task_id"],
    }

    def fn(task_id: str) -> Any:
        return get_state_fn(task_id)

    return ToolDef(
        name="query_state",
        description="Query task state and trace (internal)",
        fn=fn,
        parameters_schema=schema,
        permission_scope=["internal"],
        timeout_seconds=2.0,
    )