"""Tests for enhanced core module functionality."""

import pytest

from fusionagi.core import (
    EventBus,
    StateManager,
    Orchestrator,
    InvalidStateTransitionError,
    VALID_STATE_TRANSITIONS,
    JsonFileBackend,
)
from fusionagi.schemas.task import Task, TaskState, TaskPriority


class TestStateManagerWithBackend:
    """Test StateManager get/set, state, trace, listing and counting.

    Every test here builds ``StateManager()`` with no arguments; no
    persistence backend is passed explicitly. File-backed persistence is
    exercised by ``TestJsonFileBackend``.
    """

    def test_state_manager_basic_operations(self):
        """Test basic get/set operations."""
        sm = StateManager()
        task = Task(task_id="test-1", goal="Test goal")
        sm.set_task(task)

        retrieved = sm.get_task("test-1")
        assert retrieved is not None
        assert retrieved.task_id == "test-1"
        assert retrieved.goal == "Test goal"

    def test_state_manager_task_state(self):
        """Test task state operations."""
        sm = StateManager()
        task = Task(task_id="test-2", goal="Test")
        sm.set_task(task)
        # A Task built without an explicit state is expected to be PENDING.
        assert sm.get_task_state("test-2") == TaskState.PENDING

        sm.set_task_state("test-2", TaskState.ACTIVE)
        assert sm.get_task_state("test-2") == TaskState.ACTIVE

    def test_state_manager_trace(self):
        """Test trace append and retrieval."""
        sm = StateManager()
        task = Task(task_id="test-3", goal="Test")
        sm.set_task(task)

        sm.append_trace("test-3", {"step": "step1", "result": "ok"})
        sm.append_trace("test-3", {"step": "step2", "result": "ok"})

        # Entries are expected back in the order they were appended.
        trace = sm.get_trace("test-3")
        assert len(trace) == 2
        assert trace[0]["step"] == "step1"
        assert trace[1]["step"] == "step2"

    def test_state_manager_list_tasks(self):
        """Test listing tasks with filter."""
        sm = StateManager()
        sm.set_task(Task(task_id="t1", goal="Goal 1", state=TaskState.PENDING))
        sm.set_task(Task(task_id="t2", goal="Goal 2", state=TaskState.ACTIVE))
        sm.set_task(Task(task_id="t3", goal="Goal 3", state=TaskState.ACTIVE))

        all_tasks = sm.list_tasks()
        assert len(all_tasks) == 3

        active_tasks = sm.list_tasks(state=TaskState.ACTIVE)
        assert len(active_tasks) == 2

    def test_state_manager_task_count(self):
        """Test task counting."""
        sm = StateManager()
        assert sm.task_count() == 0

        sm.set_task(Task(task_id="t1", goal="Goal 1"))
        sm.set_task(Task(task_id="t2", goal="Goal 2"))
        assert sm.task_count() == 2


class TestJsonFileBackend:
    """Test JsonFileBackend persistence."""

    def test_json_file_backend_roundtrip(self, tmp_path):
        """Test task and trace persist to JSON file."""
        path = tmp_path / "state.json"
        backend = JsonFileBackend(path)
        task = Task(task_id="tb1", goal="Backend goal", state=TaskState.ACTIVE)
        backend.set_task(task)
        backend.append_trace("tb1", {"step": "s1", "result": "ok"})
        assert path.exists()

        # Reopen the same path with a new backend instance so the checks
        # below go through the file on disk, not the first instance.
        backend2 = JsonFileBackend(path)
        loaded = backend2.get_task("tb1")
        assert loaded is not None
        assert loaded.goal == "Backend goal"
        assert loaded.state == TaskState.ACTIVE

        trace = backend2.get_trace("tb1")
        assert len(trace) == 1
        assert trace[0]["step"] == "s1"

    def test_json_file_backend_set_task_state(self, tmp_path):
        """Test set_task_state updates persisted task."""
        path = tmp_path / "state.json"
        backend = JsonFileBackend(path)
        task = Task(task_id="tb2", goal="Goal", state=TaskState.PENDING)
        backend.set_task(task)
        backend.set_task_state("tb2", TaskState.COMPLETED)

        # A fresh instance on the same path must observe the updated state.
        backend2 = JsonFileBackend(path)
        assert backend2.get_task_state("tb2") == TaskState.COMPLETED


class TestOrchestratorStateTransitions:
    """Test Orchestrator state transition validation."""

    @pytest.fixture
    def orch(self):
        """Return an Orchestrator wired to a fresh EventBus and StateManager."""
        return Orchestrator(event_bus=EventBus(), state_manager=StateManager())

    def test_valid_transitions(self, orch):
        """Test valid state transitions."""
        task_id = orch.submit_task(goal="Test task")

        # PENDING -> ACTIVE is valid
        orch.set_task_state(task_id, TaskState.ACTIVE)
        assert orch.get_task_state(task_id) == TaskState.ACTIVE

        # ACTIVE -> COMPLETED is valid
        orch.set_task_state(task_id, TaskState.COMPLETED)
        assert orch.get_task_state(task_id) == TaskState.COMPLETED

    def test_invalid_transition_raises(self, orch):
        """Test that invalid transitions raise an error and change nothing."""
        task_id = orch.submit_task(goal="Test task")
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.COMPLETED)

        # COMPLETED -> ACTIVE is invalid (terminal state)
        with pytest.raises(InvalidStateTransitionError) as exc_info:
            orch.set_task_state(task_id, TaskState.ACTIVE)

        # The error describes the rejected transition.
        assert exc_info.value.task_id == task_id
        assert exc_info.value.from_state == TaskState.COMPLETED
        assert exc_info.value.to_state == TaskState.ACTIVE
        # A rejected transition must not have mutated the stored state.
        assert orch.get_task_state(task_id) == TaskState.COMPLETED

    def test_can_transition(self, orch):
        """Test can_transition helper method."""
        task_id = orch.submit_task(goal="Test task")

        assert orch.can_transition(task_id, TaskState.ACTIVE) is True
        assert orch.can_transition(task_id, TaskState.CANCELLED) is True
        assert orch.can_transition(task_id, TaskState.COMPLETED) is False  # Can't skip ACTIVE

    def test_force_transition(self, orch):
        """Test force=True bypasses validation."""
        task_id = orch.submit_task(goal="Test task")
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.COMPLETED)

        # Force allows invalid transition
        orch.set_task_state(task_id, TaskState.PENDING, force=True)
        assert orch.get_task_state(task_id) == TaskState.PENDING

    def test_failed_to_pending_retry(self, orch):
        """Test that FAILED can transition to PENDING for retry."""
        task_id = orch.submit_task(goal="Test task")
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.FAILED)

        # FAILED -> PENDING is valid (retry)
        orch.set_task_state(task_id, TaskState.PENDING)
        assert orch.get_task_state(task_id) == TaskState.PENDING


class TestEventBus:
    """Test EventBus functionality."""

    def test_publish_subscribe(self):
        """Test basic pub/sub."""
        bus = EventBus()
        received = []

        def handler(event_type, payload):
            received.append({"type": event_type, "payload": payload})

        bus.subscribe("test_event", handler)
        bus.publish("test_event", {"data": "value"})

        assert len(received) == 1
        # Handlers receive the published event type as their first argument.
        assert received[0]["type"] == "test_event"
        assert received[0]["payload"]["data"] == "value"

    def test_multiple_subscribers(self):
        """Test multiple subscribers receive events."""
        bus = EventBus()
        received1 = []
        received2 = []

        bus.subscribe("test", lambda t, p: received1.append(p))
        bus.subscribe("test", lambda t, p: received2.append(p))
        bus.publish("test", {"n": 1})

        assert len(received1) == 1
        assert len(received2) == 1

    def test_unsubscribe(self):
        """Test unsubscribe stops delivery."""
        bus = EventBus()
        received = []

        def handler(t, p):
            received.append(p)

        bus.subscribe("test", handler)
        bus.publish("test", {})
        assert len(received) == 1

        bus.unsubscribe("test", handler)
        bus.publish("test", {})
        assert len(received) == 1  # No new messages

    def test_clear(self):
        """Test clear removes all subscribers."""
        bus = EventBus()
        received = []

        bus.subscribe("test", lambda t, p: received.append(p))
        bus.clear()
        bus.publish("test", {})

        assert len(received) == 0