Files
FusionAGI/tests/test_tools_runner.py

333 lines
10 KiB
Python
Raw Permalink Normal View History

"""Tests for tools runner and builtins."""
import pytest
import os
import tempfile
from fusionagi.tools.registry import ToolDef, ToolRegistry
from fusionagi.tools.runner import run_tool, validate_args, ToolValidationError
from fusionagi.tools.builtins import (
make_file_read_tool,
make_file_write_tool,
make_http_get_tool,
_validate_url,
SSRFProtectionError,
)
class TestToolRunner:
    """Exercise run_tool on its success, timeout, and exception paths."""

    def test_run_tool_success(self):
        """A valid call yields the function's value and a log with no error."""

        def add(a: int, b: int) -> int:
            return a + b

        schema = {
            "type": "object",
            "properties": {
                "a": {"type": "integer"},
                "b": {"type": "integer"},
            },
            "required": ["a", "b"],
        }
        adder = ToolDef(
            name="add",
            description="Add two numbers",
            fn=add,
            parameters_schema=schema,
        )

        value, record = run_tool(adder, {"a": 2, "b": 3})

        # The log entry mirrors the returned value and carries no error.
        assert value == 5
        assert record["result"] == 5
        assert record["error"] is None

    def test_run_tool_timeout(self):
        """A function that outlives timeout_seconds reports a timeout."""
        import time

        def slow_fn() -> str:
            # Sleeps far longer than the 0.1s budget configured below.
            time.sleep(2)
            return "done"

        sluggish = ToolDef(
            name="slow",
            description="Slow function",
            fn=slow_fn,
            timeout_seconds=0.1,
        )

        value, record = run_tool(sluggish, {})

        assert value is None
        assert "timed out" in record["error"]

    def test_run_tool_exception(self):
        """An exception raised by the tool surfaces as text in the log."""

        def failing_fn() -> None:
            raise ValueError("Something went wrong")

        broken = ToolDef(
            name="fail",
            description="Failing function",
            fn=failing_fn,
        )

        value, record = run_tool(broken, {})

        assert value is None
        assert "Something went wrong" in record["error"]
class TestArgValidation:
    """Check validate_args against required, type, range, and enum rules."""

    @staticmethod
    def _build_tool(properties, required=None, fn=None):
        """Return a ToolDef named "test" whose object schema has *properties*.

        The "required" key is added to the schema only when *required* is
        given; *fn* defaults to a no-argument lambda returning None.
        """
        schema = {"type": "object", "properties": properties}
        if required is not None:
            schema["required"] = required
        return ToolDef(
            name="test",
            description="Test",
            fn=fn if fn is not None else (lambda: None),
            parameters_schema=schema,
        )

    def test_validate_required_fields(self):
        """A missing required field is rejected and named in the error."""
        tool = self._build_tool(
            {"required_field": {"type": "string"}},
            required=["required_field"],
        )

        # Missing required field
        ok, message = validate_args(tool, {})
        assert not ok
        assert "required_field" in message

        # With required field
        ok, message = validate_args(tool, {"required_field": "value"})
        assert ok

    def test_validate_string_type(self):
        """A non-string value for a string property is rejected."""
        tool = self._build_tool({"name": {"type": "string"}})

        ok, _ = validate_args(tool, {"name": "hello"})
        assert ok

        ok, message = validate_args(tool, {"name": 123})
        assert not ok
        assert "string" in message

    def test_validate_number_constraints(self):
        """Values outside minimum/maximum are rejected with >= / <= hints."""
        tool = self._build_tool(
            {
                "score": {
                    "type": "number",
                    "minimum": 0,
                    "maximum": 100,
                },
            }
        )

        ok, _ = validate_args(tool, {"score": 50})
        assert ok

        # Below the minimum.
        ok, message = validate_args(tool, {"score": -1})
        assert not ok
        assert ">=" in message

        # Above the maximum.
        ok, message = validate_args(tool, {"score": 101})
        assert not ok
        assert "<=" in message

    def test_validate_enum(self):
        """A value outside the enum list is rejected."""
        tool = self._build_tool(
            {
                "status": {
                    "type": "string",
                    "enum": ["pending", "active", "done"],
                },
            }
        )

        ok, _ = validate_args(tool, {"status": "active"})
        assert ok

        ok, message = validate_args(tool, {"status": "invalid"})
        assert not ok
        assert "one of" in message

    def test_validate_with_tool_runner(self):
        """run_tool reports validation errors only when validate=True."""
        tool = self._build_tool(
            {"x": {"type": "integer"}},
            required=["x"],
            fn=lambda x: x,
        )
        bad_args = {"x": "not an int"}

        # Invalid args should fail validation
        value, record = run_tool(tool, bad_args, validate=True)
        assert value is None
        assert "Validation error" in record["error"]

        # Skip validation
        _, record = run_tool(tool, bad_args, validate=False)
        # Execution may fail, but not due to validation
        assert "Validation error" not in (record.get("error") or "")
class TestToolRegistry:
    """Cover ToolRegistry registration, listing, and permission checks."""

    def test_register_and_get(self):
        """A registered tool can be fetched back by its name."""
        registry = ToolRegistry()
        registry.register(ToolDef(name="test", description="Test", fn=lambda: None))

        found = registry.get("test")

        assert found is not None
        assert found.name == "test"

    def test_list_tools(self):
        """list_tools returns one entry per registered tool."""
        registry = ToolRegistry()
        for name, description in (("t1", "Tool 1"), ("t2", "Tool 2")):
            registry.register(
                ToolDef(name=name, description=description, fn=lambda: None)
            )

        listed = registry.list_tools()

        assert len(listed) == 2
        assert {entry["name"] for entry in listed} == {"t1", "t2"}

    def test_permission_check(self):
        """allowed_for honours the tool's permission_scope and the wildcard."""
        registry = ToolRegistry()
        registry.register(
            ToolDef(
                name="restricted",
                description="Restricted tool",
                fn=lambda: None,
                permission_scope=["admin", "write"],
            )
        )

        # Has matching permission
        for granted in (["admin"], ["write"]):
            assert registry.allowed_for("restricted", granted)

        # No matching permission
        assert not registry.allowed_for("restricted", ["read"])

        # Wildcard permissions
        assert registry.allowed_for("restricted", ["*"])
class TestSSRFProtection:
    """Check that _validate_url rejects unsafe targets and accepts safe ones."""

    def test_localhost_blocked(self):
        """Localhost by name and by loopback address are both rejected."""
        for url in ("http://localhost/path", "http://127.0.0.1/path"):
            with pytest.raises(SSRFProtectionError, match="Localhost"):
                _validate_url(url)

    def test_private_ip_blocked(self):
        """An internal-looking hostname is rejected."""
        # NOTE(review): the outcome may depend on how "test.local" resolves
        # in the environment running the tests; the hostname is used here as
        # a stand-in for an internal host pattern.
        internal_url = "http://test.local/path"
        with pytest.raises(SSRFProtectionError):
            _validate_url(internal_url)

    def test_non_http_scheme_blocked(self):
        """URLs whose scheme is not HTTP(S) are rejected."""
        for url in ("file:///etc/passwd", "ftp://example.com/file"):
            with pytest.raises(SSRFProtectionError, match="scheme"):
                _validate_url(url)

    def test_valid_url_passes(self):
        """A public HTTPS URL is returned unchanged."""
        target = "https://example.com/path"
        # This should not raise
        validated = _validate_url(target)
        assert validated == "https://example.com/path"
class TestFileTools:
    """Test the scoped file read/write tools through run_tool."""

    def test_file_read_in_scope(self):
        """Reading a file located inside the scope returns its contents."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a test file; the encoding is explicit so the fixture
            # does not depend on the platform's locale default.
            test_file = os.path.join(tmpdir, "test.txt")
            with open(test_file, "w", encoding="utf-8") as f:
                f.write("Hello, World!")

            tool = make_file_read_tool(scope=tmpdir)
            result, log = run_tool(tool, {"path": test_file})

            assert result == "Hello, World!"
            assert log["error"] is None

    def test_file_read_outside_scope(self):
        """Reading a path outside the scope yields no result and an error."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tool = make_file_read_tool(scope=tmpdir)

            # Try to read file outside scope
            result, log = run_tool(tool, {"path": "/etc/passwd"})

            assert result is None
            # Lowercase once; either wording is accepted for the rejection.
            message = log["error"].lower()
            assert "not allowed" in message or "permission" in message

    def test_file_write_in_scope(self):
        """Writing inside the scope creates the file with the given content."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tool = make_file_write_tool(scope=tmpdir)
            test_file = os.path.join(tmpdir, "output.txt")

            result, log = run_tool(
                tool, {"path": test_file, "content": "Test content"}
            )

            assert log["error"] is None
            assert os.path.exists(test_file)
            # Read back with an explicit encoding; the content is plain ASCII.
            with open(test_file, encoding="utf-8") as f:
                assert f.read() == "Test content"