"""Tests for tools runner and builtins."""

import os
import tempfile
import time

import pytest

from fusionagi.tools.registry import ToolDef, ToolRegistry
from fusionagi.tools.runner import run_tool, validate_args, ToolValidationError
from fusionagi.tools.builtins import (
    make_file_read_tool,
    make_file_write_tool,
    make_http_get_tool,
    _validate_url,
    SSRFProtectionError,
)


class TestToolRunner:
    """Test tool runner functionality."""

    def test_run_tool_success(self):
        """Test successful tool execution."""

        def add(a: int, b: int) -> int:
            return a + b

        tool = ToolDef(
            name="add",
            description="Add two numbers",
            fn=add,
            parameters_schema={
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                },
                "required": ["a", "b"],
            },
        )

        result, log = run_tool(tool, {"a": 2, "b": 3})
        assert result == 5
        assert log["result"] == 5
        assert log["error"] is None

    def test_run_tool_timeout(self):
        """Test tool timeout handling."""

        def slow_fn() -> str:
            # Sleeps far longer than timeout_seconds so the timeout must fire.
            time.sleep(2)
            return "done"

        tool = ToolDef(
            name="slow",
            description="Slow function",
            fn=slow_fn,
            timeout_seconds=0.1,
        )

        result, log = run_tool(tool, {})
        assert result is None
        # `or ""` keeps a missing error a clean assertion failure, not a TypeError.
        assert "timed out" in (log.get("error") or "")

    def test_run_tool_exception(self):
        """Test tool exception handling."""

        def failing_fn() -> None:
            raise ValueError("Something went wrong")

        tool = ToolDef(
            name="fail",
            description="Failing function",
            fn=failing_fn,
        )

        result, log = run_tool(tool, {})
        assert result is None
        assert "Something went wrong" in (log.get("error") or "")


class TestArgValidation:
    """Test argument validation."""

    def test_validate_required_fields(self):
        """Test validation of required fields."""
        tool = ToolDef(
            name="test",
            description="Test",
            fn=lambda: None,
            parameters_schema={
                "type": "object",
                "properties": {
                    "required_field": {"type": "string"},
                },
                "required": ["required_field"],
            },
        )

        # Missing required field
        is_valid, error = validate_args(tool, {})
        assert not is_valid
        assert "required_field" in error

        # With required field
        is_valid, _ = validate_args(tool, {"required_field": "value"})
        assert is_valid

    def test_validate_string_type(self):
        """Test string type validation."""
        tool = ToolDef(
            name="test",
            description="Test",
            fn=lambda: None,
            parameters_schema={
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                },
            },
        )

        is_valid, _ = validate_args(tool, {"name": "hello"})
        assert is_valid

        is_valid, error = validate_args(tool, {"name": 123})
        assert not is_valid
        assert "string" in error

    def test_validate_number_constraints(self):
        """Test number constraint validation."""
        tool = ToolDef(
            name="test",
            description="Test",
            fn=lambda: None,
            parameters_schema={
                "type": "object",
                "properties": {
                    "score": {
                        "type": "number",
                        "minimum": 0,
                        "maximum": 100,
                    },
                },
            },
        )

        is_valid, _ = validate_args(tool, {"score": 50})
        assert is_valid

        is_valid, error = validate_args(tool, {"score": -1})
        assert not is_valid
        assert ">=" in error

        is_valid, error = validate_args(tool, {"score": 101})
        assert not is_valid
        assert "<=" in error

    def test_validate_enum(self):
        """Test enum constraint validation."""
        tool = ToolDef(
            name="test",
            description="Test",
            fn=lambda: None,
            parameters_schema={
                "type": "object",
                "properties": {
                    "status": {
                        "type": "string",
                        "enum": ["pending", "active", "done"],
                    },
                },
            },
        )

        is_valid, _ = validate_args(tool, {"status": "active"})
        assert is_valid

        is_valid, error = validate_args(tool, {"status": "invalid"})
        assert not is_valid
        assert "one of" in error

    def test_validate_with_tool_runner(self):
        """Test validation integration with run_tool."""
        tool = ToolDef(
            name="test",
            description="Test",
            fn=lambda x: x,
            parameters_schema={
                "type": "object",
                "properties": {
                    "x": {"type": "integer"},
                },
                "required": ["x"],
            },
        )

        # Invalid args should fail validation
        result, log = run_tool(tool, {"x": "not an int"}, validate=True)
        assert result is None
        assert "Validation error" in log["error"]

        # Skip validation
        _, log = run_tool(tool, {"x": "not an int"}, validate=False)
        # Execution may fail, but not due to validation
        assert "Validation error" not in (log.get("error") or "")


class TestToolRegistry:
    """Test tool registry functionality."""

    def test_register_and_get(self):
        """Test registering and retrieving tools."""
        registry = ToolRegistry()
        tool = ToolDef(name="test", description="Test", fn=lambda: None)
        registry.register(tool)

        retrieved = registry.get("test")
        assert retrieved is not None
        assert retrieved.name == "test"

    def test_list_tools(self):
        """Test listing all tools."""
        registry = ToolRegistry()
        registry.register(ToolDef(name="t1", description="Tool 1", fn=lambda: None))
        registry.register(ToolDef(name="t2", description="Tool 2", fn=lambda: None))

        tools = registry.list_tools()
        assert len(tools) == 2
        names = {t["name"] for t in tools}
        assert names == {"t1", "t2"}

    def test_permission_check(self):
        """Test permission checking."""
        registry = ToolRegistry()
        tool = ToolDef(
            name="restricted",
            description="Restricted tool",
            fn=lambda: None,
            permission_scope=["admin", "write"],
        )
        registry.register(tool)

        # Has matching permission
        assert registry.allowed_for("restricted", ["admin"])
        assert registry.allowed_for("restricted", ["write"])

        # No matching permission
        assert not registry.allowed_for("restricted", ["read"])

        # Wildcard permissions
        assert registry.allowed_for("restricted", ["*"])


class TestSSRFProtection:
    """Test SSRF protection in URL validation."""

    def test_localhost_blocked(self):
        """Test that localhost URLs are blocked."""
        with pytest.raises(SSRFProtectionError, match="Localhost"):
            _validate_url("http://localhost/path")

        with pytest.raises(SSRFProtectionError, match="Localhost"):
            _validate_url("http://127.0.0.1/path")

    def test_private_ip_blocked(self):
        """Test that private IPs are blocked after DNS resolution."""
        # Note: This test may pass or fail depending on DNS resolution
        # Testing the concept with a known internal hostname pattern
        # NOTE(review): outcome depends on the host resolver; consider a
        # literal private IP (e.g. 10.0.0.1) if _validate_url checks literals.
        with pytest.raises(SSRFProtectionError):
            _validate_url("http://test.local/path")

    def test_non_http_scheme_blocked(self):
        """Test that non-HTTP schemes are blocked."""
        with pytest.raises(SSRFProtectionError, match="scheme"):
            _validate_url("file:///etc/passwd")

        with pytest.raises(SSRFProtectionError, match="scheme"):
            _validate_url("ftp://example.com/file")

    def test_valid_url_passes(self):
        """Test that valid public URLs pass."""
        # This should not raise
        # NOTE(review): presumably resolves example.com via DNS, so this needs
        # network access — confirm against _validate_url.
        url = _validate_url("https://example.com/path")
        assert url == "https://example.com/path"


class TestFileTools:
    """Test file read/write tools."""

    def test_file_read_in_scope(self):
        """Test reading a file within scope."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a test file
            test_file = os.path.join(tmpdir, "test.txt")
            with open(test_file, "w", encoding="utf-8") as f:
                f.write("Hello, World!")

            tool = make_file_read_tool(scope=tmpdir)
            result, log = run_tool(tool, {"path": test_file})

            assert result == "Hello, World!"
            assert log["error"] is None

    def test_file_read_outside_scope(self):
        """Test reading a file outside scope is blocked."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tool = make_file_read_tool(scope=tmpdir)

            # Try to read file outside scope
            result, log = run_tool(tool, {"path": "/etc/passwd"})

            assert result is None
            # Normalise a missing error to "" so a regression fails the
            # assertion instead of crashing with AttributeError on None.lower().
            error = (log.get("error") or "").lower()
            assert "not allowed" in error or "permission" in error

    def test_file_write_in_scope(self):
        """Test writing a file within scope."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tool = make_file_write_tool(scope=tmpdir)
            test_file = os.path.join(tmpdir, "output.txt")

            _, log = run_tool(tool, {"path": test_file, "content": "Test content"})

            assert log["error"] is None
            assert os.path.exists(test_file)
            with open(test_file, encoding="utf-8") as f:
                assert f.read() == "Test content"