# File-viewer metadata from the original page: 247 lines, 8.4 KiB, Python.
"""Tests for enhanced core module functionality."""
|
|
|
|
import pytest
|
|
|
|
from fusionagi.core import (
|
|
EventBus,
|
|
StateManager,
|
|
Orchestrator,
|
|
InvalidStateTransitionError,
|
|
VALID_STATE_TRANSITIONS,
|
|
JsonFileBackend,
|
|
)
|
|
from fusionagi.schemas.task import Task, TaskState, TaskPriority
|
|
|
|
|
|
class TestStateManagerWithBackend:
    """Test StateManager task storage, task-state, and trace operations.

    NOTE(review): despite the class name, no persistence backend is passed
    here -- every test builds ``StateManager()`` with no arguments.
    Persistence through ``JsonFileBackend`` is exercised by
    ``TestJsonFileBackend``.
    """

    def test_state_manager_basic_operations(self):
        """Test that a stored task can be fetched back by its id."""
        sm = StateManager()
        task = Task(task_id="test-1", goal="Test goal")

        sm.set_task(task)
        retrieved = sm.get_task("test-1")

        assert retrieved is not None
        assert retrieved.task_id == "test-1"
        assert retrieved.goal == "Test goal"

    def test_state_manager_task_state(self):
        """Test reading the initial task state and then updating it."""
        sm = StateManager()
        task = Task(task_id="test-2", goal="Test")
        sm.set_task(task)

        # A Task built without an explicit state reports PENDING.
        assert sm.get_task_state("test-2") == TaskState.PENDING

        sm.set_task_state("test-2", TaskState.ACTIVE)
        assert sm.get_task_state("test-2") == TaskState.ACTIVE

    def test_state_manager_trace(self):
        """Test that trace entries are returned in the order they were appended."""
        sm = StateManager()
        task = Task(task_id="test-3", goal="Test")
        sm.set_task(task)

        sm.append_trace("test-3", {"step": "step1", "result": "ok"})
        sm.append_trace("test-3", {"step": "step2", "result": "ok"})

        trace = sm.get_trace("test-3")
        assert len(trace) == 2
        assert trace[0]["step"] == "step1"
        assert trace[1]["step"] == "step2"

    def test_state_manager_list_tasks(self):
        """Test listing all tasks and filtering them by state."""
        sm = StateManager()

        sm.set_task(Task(task_id="t1", goal="Goal 1", state=TaskState.PENDING))
        sm.set_task(Task(task_id="t2", goal="Goal 2", state=TaskState.ACTIVE))
        sm.set_task(Task(task_id="t3", goal="Goal 3", state=TaskState.ACTIVE))

        all_tasks = sm.list_tasks()
        assert len(all_tasks) == 3

        active_tasks = sm.list_tasks(state=TaskState.ACTIVE)
        assert len(active_tasks) == 2

        # Check a second filter value so the test also catches a filter that
        # ignores its argument (e.g. one that always returns the ACTIVE subset).
        pending_tasks = sm.list_tasks(state=TaskState.PENDING)
        assert len(pending_tasks) == 1

    def test_state_manager_task_count(self):
        """Test that task_count starts at zero and tracks stored tasks."""
        sm = StateManager()
        assert sm.task_count() == 0

        sm.set_task(Task(task_id="t1", goal="Goal 1"))
        sm.set_task(Task(task_id="t2", goal="Goal 2"))

        assert sm.task_count() == 2


class TestJsonFileBackend:
    """Test JsonFileBackend persistence to a JSON file on disk."""

    def test_json_file_backend_roundtrip(self, tmp_path):
        """Test that a task and its trace persist to the JSON file and reload."""
        path = tmp_path / "state.json"
        backend = JsonFileBackend(path)
        task = Task(task_id="tb1", goal="Backend goal", state=TaskState.ACTIVE)
        backend.set_task(task)
        backend.append_trace("tb1", {"step": "s1", "result": "ok"})
        assert path.exists()

        # Read through a second backend instance on the same path to exercise
        # the file round trip rather than the first instance's own view.
        backend2 = JsonFileBackend(path)
        loaded = backend2.get_task("tb1")
        assert loaded is not None
        assert loaded.task_id == "tb1"
        assert loaded.goal == "Backend goal"
        assert loaded.state == TaskState.ACTIVE

        trace = backend2.get_trace("tb1")
        assert len(trace) == 1
        assert trace[0]["step"] == "s1"
        assert trace[0]["result"] == "ok"

    def test_json_file_backend_set_task_state(self, tmp_path):
        """Test that set_task_state updates the persisted task and keeps its other fields."""
        path = tmp_path / "state.json"
        backend = JsonFileBackend(path)
        task = Task(task_id="tb2", goal="Goal", state=TaskState.PENDING)
        backend.set_task(task)
        backend.set_task_state("tb2", TaskState.COMPLETED)

        backend2 = JsonFileBackend(path)
        assert backend2.get_task_state("tb2") == TaskState.COMPLETED

        # Updating the state must not clobber the rest of the stored task.
        reloaded = backend2.get_task("tb2")
        assert reloaded is not None
        assert reloaded.goal == "Goal"


class TestOrchestratorStateTransitions:
    """Test Orchestrator state transition validation."""

    @staticmethod
    def _make_orchestrator():
        """Build an Orchestrator wired to a fresh EventBus and StateManager."""
        return Orchestrator(event_bus=EventBus(), state_manager=StateManager())

    def test_valid_transitions(self):
        """Test the PENDING -> ACTIVE -> COMPLETED happy path."""
        orch = self._make_orchestrator()
        task_id = orch.submit_task(goal="Test task")

        # Make the starting state explicit rather than assumed.
        assert orch.get_task_state(task_id) == TaskState.PENDING

        # PENDING -> ACTIVE is valid
        orch.set_task_state(task_id, TaskState.ACTIVE)
        assert orch.get_task_state(task_id) == TaskState.ACTIVE

        # ACTIVE -> COMPLETED is valid
        orch.set_task_state(task_id, TaskState.COMPLETED)
        assert orch.get_task_state(task_id) == TaskState.COMPLETED

    def test_invalid_transition_raises(self):
        """Test that invalid transitions raise an error carrying the transition details."""
        orch = self._make_orchestrator()
        task_id = orch.submit_task(goal="Test task")
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.COMPLETED)

        # COMPLETED -> ACTIVE is invalid (terminal state)
        with pytest.raises(InvalidStateTransitionError) as exc_info:
            orch.set_task_state(task_id, TaskState.ACTIVE)

        assert exc_info.value.task_id == task_id
        assert exc_info.value.from_state == TaskState.COMPLETED
        assert exc_info.value.to_state == TaskState.ACTIVE

    def test_can_transition(self):
        """Test the can_transition helper, and that querying it leaves state unchanged."""
        orch = self._make_orchestrator()
        task_id = orch.submit_task(goal="Test task")

        assert orch.can_transition(task_id, TaskState.ACTIVE) is True
        assert orch.can_transition(task_id, TaskState.CANCELLED) is True
        assert orch.can_transition(task_id, TaskState.COMPLETED) is False  # Can't skip ACTIVE

        # can_transition is a query; it must not move the task.
        assert orch.get_task_state(task_id) == TaskState.PENDING

    def test_force_transition(self):
        """Test that force=True bypasses validation that would otherwise reject the move."""
        orch = self._make_orchestrator()
        task_id = orch.submit_task(goal="Test task")
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.COMPLETED)

        # Without force, leaving the terminal COMPLETED state is rejected;
        # this proves the forced call below really bypasses a check.
        with pytest.raises(InvalidStateTransitionError):
            orch.set_task_state(task_id, TaskState.PENDING)

        # Force allows invalid transition
        orch.set_task_state(task_id, TaskState.PENDING, force=True)
        assert orch.get_task_state(task_id) == TaskState.PENDING

    def test_failed_to_pending_retry(self):
        """Test that FAILED can transition to PENDING for retry."""
        orch = self._make_orchestrator()
        task_id = orch.submit_task(goal="Test task")
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.FAILED)

        # FAILED -> PENDING is valid (retry)
        orch.set_task_state(task_id, TaskState.PENDING)
        assert orch.get_task_state(task_id) == TaskState.PENDING


class TestEventBus:
    """Test EventBus functionality."""

    def test_publish_subscribe(self):
        """Test basic pub/sub: the handler gets the event type and the payload."""
        bus = EventBus()
        received = []

        def handler(event_type, payload):
            received.append({"type": event_type, "payload": payload})

        bus.subscribe("test_event", handler)
        bus.publish("test_event", {"data": "value"})

        assert len(received) == 1
        # The event type was recorded but never checked before.
        assert received[0]["type"] == "test_event"
        assert received[0]["payload"]["data"] == "value"

    def test_multiple_subscribers(self):
        """Test that every subscriber receives the published payload."""
        bus = EventBus()
        received1 = []
        received2 = []

        bus.subscribe("test", lambda t, p: received1.append(p))
        bus.subscribe("test", lambda t, p: received2.append(p))

        bus.publish("test", {"n": 1})

        assert len(received1) == 1
        assert len(received2) == 1
        # Check content, not just delivery count.
        assert received1[0]["n"] == 1
        assert received2[0]["n"] == 1

    def test_unsubscribe(self):
        """Test that unsubscribe stops delivery to that handler."""
        bus = EventBus()
        received = []

        def handler(t, p):
            received.append(p)

        bus.subscribe("test", handler)
        bus.publish("test", {})
        assert len(received) == 1

        bus.unsubscribe("test", handler)
        bus.publish("test", {})
        assert len(received) == 1  # No new messages

    def test_clear(self):
        """Test that clear removes all subscribers."""
        bus = EventBus()
        received = []

        bus.subscribe("test", lambda t, p: received.append(p))
        bus.clear()
        bus.publish("test", {})

        assert len(received) == 0