"""Built-in tools: file read/write (scoped), HTTP GET (with SSRF protection), query state."""

import ipaddress
import os
import socket
import urllib.request
from typing import Any, Callable
from urllib.parse import urlparse

from fusionagi.tools.registry import ToolDef
from fusionagi._logger import logger

# Default allowed path prefix for file tools. Deployers should pass an explicit scope (e.g. from config/env)
# and not rely on cwd in production. Note: evaluated once, at import time.
DEFAULT_FILE_SCOPE = os.path.abspath(os.getcwd())

# Maximum file size for read/write operations (10MB)
MAX_FILE_SIZE = 10 * 1024 * 1024

# Hostnames rejected before any DNS lookup (compared lowercased, trailing dots stripped).
_LOCALHOST_NAMES = frozenset({"localhost", "127.0.0.1", "::1", "0.0.0.0"})

# Hostname suffixes that conventionally denote internal / LAN names.
_INTERNAL_SUFFIXES = (".local", ".internal", ".corp", ".lan", ".home")

# Timeout in seconds passed to the HTTP request in _http_get.
_HTTP_TIMEOUT_SECONDS = 10


class SSRFProtectionError(Exception):
    """Raised when a URL is blocked for SSRF protection."""

    pass


class FileSizeError(Exception):
    """Raised when file size exceeds limit."""

    pass


def _is_within(path: str, root: str) -> bool:
    """Return True if absolute ``path`` equals ``root`` or lies beneath it."""
    try:
        # commonpath compares whole path components, so "/a/bc" is not treated as inside "/a/b",
        # and a root of "/" works (a plain startswith(root + os.sep) check rejects everything there).
        return os.path.commonpath([root, path]) == root
    except ValueError:
        # Raised e.g. on Windows when the two paths are on different drives.
        return False


def _normalize_path(path: str, scope: str) -> str:
    """
    Normalize and validate a file path against scope.

    Both ``path`` and ``scope`` are made absolute and symlink-resolved (``os.path.realpath``)
    before comparison, so ``..`` segments and links that point outside the scope are rejected.
    Relative ``path`` values are resolved against the current working directory, not ``scope``.

    Args:
        path: Path supplied by the caller.
        scope: Directory the resolved path must equal or lie beneath.

    Returns:
        The resolved real path.

    Raises:
        PermissionError: If the resolved path is outside the resolved scope.
    """
    abs_path = os.path.abspath(path)

    # Resolve symlinks to get the real path
    try:
        real_path = os.path.realpath(abs_path)
    except OSError:
        real_path = abs_path

    real_scope = os.path.realpath(os.path.abspath(scope))

    if not _is_within(real_path, real_scope):
        raise PermissionError(f"Path not allowed: {path} resolves outside {scope}")

    return real_path


def _file_read(path: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
    """
    Read file content; path must be under scope.

    Args:
        path: File path to read.
        scope: Allowed directory scope.
        max_size: Maximum file size in bytes.

    Returns:
        File contents as a string, decoded as UTF-8 with undecodable bytes replaced by U+FFFD.

    Raises:
        PermissionError: If path is outside scope, or the file cannot be opened or stat'ed.
        FileSizeError: If file exceeds max_size.
    """
    real_path = _normalize_path(path, scope)

    try:
        f = open(real_path, "r", encoding="utf-8", errors="replace")
    except OSError as e:
        raise PermissionError(f"Cannot access file: {e}") from e

    with f:
        # Size the already-open descriptor instead of re-querying the path, so the check
        # applies to the very file that is read.
        try:
            file_size = os.fstat(f.fileno()).st_size
        except OSError as e:
            raise PermissionError(f"Cannot access file: {e}") from e
        if file_size > max_size:
            raise FileSizeError(f"File too large: {file_size} bytes (max {max_size})")

        # Bounded read: protects against files that grow after the size check or that report
        # a size of 0 (e.g. character devices). Every decoded character consumes at least one
        # byte, so more than max_size characters implies more than max_size bytes.
        content = f.read(max_size + 1)

    if len(content) > max_size:
        raise FileSizeError(f"File too large: more than {max_size} bytes (max {max_size})")
    return content


def _file_write(path: str, content: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
    """
    Write content to file; path must be under scope.

    Missing parent directories are created (they must also resolve under scope).
    An existing file is overwritten.

    Args:
        path: File path to write.
        content: Content to write.
        scope: Allowed directory scope.
        max_size: Maximum content size in bytes (measured as UTF-8).

    Returns:
        Success message with byte count.

    Raises:
        PermissionError: If path is outside scope.
        FileSizeError: If content exceeds max_size.
    """
    # Check content size before touching the filesystem
    content_bytes = len(content.encode("utf-8"))
    if content_bytes > max_size:
        raise FileSizeError(f"Content too large: {content_bytes} bytes (max {max_size})")

    real_path = _normalize_path(path, scope)

    # Ensure parent directory exists
    parent_dir = os.path.dirname(real_path)
    if parent_dir and not os.path.exists(parent_dir):
        # Re-check the parent against scope before creating directories
        _normalize_path(parent_dir, scope)
        os.makedirs(parent_dir, exist_ok=True)

    with open(real_path, "w", encoding="utf-8") as f:
        f.write(content)

    return f"Wrote {content_bytes} bytes to {real_path}"


def _is_private_ip(ip: str) -> bool:
    """
    Check if an IP address is private, loopback, or otherwise unsafe to fetch.

    Any address that is not globally routable is treated as unsafe (this also covers ranges
    such as 100.64.0.0/10 that are neither ``is_private`` nor ``is_global``). IPv4-mapped IPv6
    addresses are blocked unconditionally. Strings that do not parse as an IP are unsafe.
    """
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        return True  # Invalid IP is treated as unsafe
    return (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_multicast
        or addr.is_reserved
        or addr.is_unspecified
        or not addr.is_global
        # Block IPv6 mapped IPv4 addresses
        or (isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None)
    )


def _validate_url(url: str, allow_private: bool = False) -> str:
    """
    Validate a URL for SSRF protection.

    Checks, in order: the URL parses with a valid port; the scheme is http/https; a hostname
    is present; the hostname is not a localhost name or an internal-looking suffix (these two
    name checks apply even when ``allow_private`` is True); and, unless ``allow_private``, every
    address the hostname resolves to passes ``_is_private_ip``.

    Args:
        url: URL to validate.
        allow_private: If True, skip the DNS-based private-IP check (default False).

    Returns:
        The validated URL (unchanged).

    Raises:
        SSRFProtectionError: If URL is blocked for security reasons.
    """
    try:
        parsed = urlparse(url)
        # Accessing .port validates it: a non-numeric or out-of-range port raises ValueError.
        port = parsed.port
    except Exception as e:
        raise SSRFProtectionError(f"Invalid URL: {e}") from e

    # Only allow HTTP and HTTPS
    if parsed.scheme not in ("http", "https"):
        raise SSRFProtectionError(f"URL scheme not allowed: {parsed.scheme}")

    # Must have a hostname
    hostname = parsed.hostname
    if not hostname:
        raise SSRFProtectionError("URL must have a hostname")

    # Strip trailing dots so FQDN spellings like "localhost." cannot sidestep the name checks.
    bare_host = hostname.lower().rstrip(".")

    # Block localhost variants
    if bare_host in _LOCALHOST_NAMES:
        raise SSRFProtectionError(f"Localhost URLs not allowed: {hostname}")

    # Block common internal hostnames
    if bare_host.endswith(_INTERNAL_SUFFIXES):
        raise SSRFProtectionError(f"Internal hostname not allowed: {hostname}")

    if not allow_private:
        default_port = 443 if parsed.scheme == "https" else 80
        try:
            # Get all IP addresses for the hostname
            infos = socket.getaddrinfo(hostname, port or default_port)
        except (OSError, ValueError) as e:
            # socket.gaierror is an OSError; invalid IDNA labels raise UnicodeError (a ValueError).
            logger.warning(f"DNS resolution failed for {hostname}: {e}")
            raise SSRFProtectionError(f"Cannot resolve hostname: {hostname}") from e

        for _family, _socktype, _proto, _canonname, sockaddr in infos:
            ip = sockaddr[0]
            if _is_private_ip(ip):
                raise SSRFProtectionError(f"URL resolves to private IP: {ip}")

    return url


class _SSRFSafeRedirectHandler(urllib.request.HTTPRedirectHandler):
    """
    Redirect handler that re-validates every redirect target with ``_validate_url``.

    urllib follows 3xx redirects automatically. Without this handler a public URL could
    redirect the request to a private/internal address (or a non-HTTP scheme such as ftp)
    and bypass the initial validation.
    """

    def __init__(self, allow_private: bool = False) -> None:
        super().__init__()
        self._allow_private = allow_private

    def redirect_request(self, req: Any, fp: Any, code: int, msg: str, headers: Any, newurl: str) -> Any:
        try:
            _validate_url(newurl, allow_private=self._allow_private)
        except SSRFProtectionError:
            # The base class normally drains and closes the redirect response after this call;
            # close it here because the redirect is being aborted.
            fp.close()
            raise
        return super().redirect_request(req, fp, code, msg, headers, newurl)


def _http_get(url: str, allow_private: bool = False) -> str:
    """
    Simple HTTP GET with SSRF protection.

    The initial URL and every redirect target are validated with ``_validate_url``.
    NOTE: validation resolves DNS separately from the connection urllib makes, so a hostname
    whose DNS answer changes between the two lookups (DNS rebinding) is not caught here.

    Args:
        url: URL to fetch.
        allow_private: If True, allow private/internal IPs (default False).

    Returns:
        Response text (decoded as UTF-8, undecodable bytes replaced).
        On failure returns a string starting with 'Error: '.
    """
    try:
        validated_url = _validate_url(url, allow_private=allow_private)
    except SSRFProtectionError as e:
        return f"Error: SSRF protection: {e}"

    try:
        opener = urllib.request.build_opener(_SSRFSafeRedirectHandler(allow_private=allow_private))
        with opener.open(validated_url, timeout=_HTTP_TIMEOUT_SECONDS) as resp:
            return resp.read().decode("utf-8", errors="replace")
    except SSRFProtectionError as e:
        # Raised by the redirect handler when a redirect target fails validation.
        return f"Error: SSRF protection: {e}"
    except Exception as e:
        return f"Error: {e}"


def make_file_read_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """File read tool restricted to paths under ``scope``."""

    def fn(path: str) -> str:
        return _file_read(path, scope=scope)

    return ToolDef(
        name="file_read",
        description="Read file content; path must be under allowed scope",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {"path": {"type": "string", "description": "File path"}},
            "required": ["path"],
        },
        permission_scope=["file"],
        timeout_seconds=5.0,
    )


def make_file_write_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """File write tool restricted to paths under ``scope``."""

    def fn(path: str, content: str) -> str:
        return _file_write(path, content, scope=scope)

    return ToolDef(
        name="file_write",
        description="Write content to file; path must be under allowed scope",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path"},
                "content": {"type": "string", "description": "Content to write"},
            },
            "required": ["path", "content"],
        },
        permission_scope=["file"],
        timeout_seconds=5.0,
    )


def make_http_get_tool() -> ToolDef:
    """HTTP GET tool backed by ``_http_get`` (private addresses blocked by default)."""
    return ToolDef(
        name="http_get",
        description="Perform HTTP GET request and return response body",
        fn=_http_get,
        parameters_schema={
            "type": "object",
            "properties": {"url": {"type": "string", "description": "URL to fetch"}},
            "required": ["url"],
        },
        permission_scope=["network"],
        timeout_seconds=15.0,
    )


def make_query_state_tool(get_state_fn: Callable[[str], Any]) -> ToolDef:
    """Internal tool: query task state (injected get_state_fn(task_id) -> state/trace)."""

    def fn(task_id: str) -> Any:
        return get_state_fn(task_id)

    return ToolDef(
        name="query_state",
        description="Query task state and trace (internal)",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {"task_id": {"type": "string"}},
            "required": ["task_id"],
        },
        permission_scope=["internal"],
        timeout_seconds=2.0,
    )