first commit
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
283
tests/test_architecture_tracker.py
Normal file
283
tests/test_architecture_tracker.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""Tests for ArchitectureTracker."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory.core.architecture_tracker import ArchitectureTracker
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def tmp_data_dir(tmp_path):
    """Path (as str) of a per-test data directory; not created here."""
    data_dir = tmp_path / "data"
    return str(data_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def tracker(tmp_data_dir):
    """A bare ArchitectureTracker (no Claude SDK client attached)."""
    return ArchitectureTracker(data_dir=tmp_data_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def tracker_with_client(tmp_data_dir):
    """(tracker, mocked async SDK client) pair for AI-extraction tests."""
    client = AsyncMock()
    tracker = ArchitectureTracker(data_dir=tmp_data_dir)
    tracker._client = client
    return tracker, client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitialization:
    """Default state of a freshly constructed ArchitectureTracker."""

    def test_creates_default_architecture(self, tracker):
        arch = tracker._architecture
        # Every top-level section must exist before any data is added.
        for key in ("modules", "utilities", "design_patterns",
                    "naming_conventions", "tech_stack"):
            assert key in arch
        assert arch["version"] == 1
        assert "last_updated" in arch

    def test_creates_data_directory(self, tmp_data_dir):
        # Construction creates the directory as a side effect; the instance
        # itself is not needed for the assertion (was an unused local).
        ArchitectureTracker(data_dir=tmp_data_dir)
        assert os.path.isdir(tmp_data_dir)

    def test_default_architecture_has_empty_lists(self, tracker):
        assert tracker._architecture["modules"] == []
        assert tracker._architecture["utilities"] == []
        assert tracker._architecture["design_patterns"] == []

    def test_default_naming_conventions(self, tracker):
        conventions = tracker._architecture["naming_conventions"]
        assert conventions["variables"] == "snake_case"
        assert conventions["classes"] == "PascalCase"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load / Save Persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPersistence:
    """Round-tripping architecture state through the on-disk JSON file."""

    def test_save_and_load(self, tmp_data_dir):
        writer = ArchitectureTracker(data_dir=tmp_data_dir)
        writer.add_module("TestModule", "A test module", "test.py")

        # A second tracker pointed at the same directory must see the data.
        reader = ArchitectureTracker(data_dir=tmp_data_dir)
        modules = reader._architecture["modules"]
        assert len(modules) == 1
        assert modules[0]["name"] == "TestModule"

    def test_save_updates_timestamp(self, tracker):
        # (The original captured the pre-save timestamp but never used it.)
        tracker.save_architecture(tracker._architecture)
        # Timestamp should be present and non-empty after a save.
        new_ts = tracker._architecture["last_updated"]
        assert new_ts is not None
        assert len(new_ts) > 0

    def test_load_corrupt_file_returns_default(self, tmp_data_dir):
        # Pre-seed the data file with garbage so JSON parsing fails on load.
        os.makedirs(tmp_data_dir, exist_ok=True)
        corrupt_path = os.path.join(tmp_data_dir, "global_architecture.json")
        with open(corrupt_path, "w") as f:
            f.write("NOT VALID JSON {{{")

        tracker = ArchitectureTracker(data_dir=tmp_data_dir)
        assert tracker._architecture["version"] == 1
        assert tracker._architecture["modules"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_module
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAddModule:
    """Behavior of ArchitectureTracker.add_module()."""

    def test_add_module(self, tracker):
        tracker.add_module("MyModule", "Does something", "src/my_module.py")
        modules = tracker._architecture["modules"]
        assert len(modules) == 1
        entry = modules[0]
        assert entry["name"] == "MyModule"
        assert entry["purpose"] == "Does something"
        assert entry["file_path"] == "src/my_module.py"

    def test_add_module_persists(self, tmp_data_dir):
        first = ArchitectureTracker(data_dir=tmp_data_dir)
        first.add_module("Persisted", "persists", "p.py")

        # A fresh tracker over the same directory must load the module.
        second = ArchitectureTracker(data_dir=tmp_data_dir)
        names = [m["name"] for m in second._architecture["modules"]]
        assert "Persisted" in names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_utility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAddUtility:
    """Behavior of ArchitectureTracker.add_utility()."""

    def test_add_utility(self, tracker):
        tracker.add_utility("helper_func", "Helps with things", "utils.py")
        utilities = tracker._architecture["utilities"]
        assert len(utilities) == 1
        entry = utilities[0]
        assert entry["name"] == "helper_func"
        assert entry["description"] == "Helps with things"
        assert entry["file_path"] == "utils.py"

    def test_add_utility_persists(self, tmp_data_dir):
        first = ArchitectureTracker(data_dir=tmp_data_dir)
        first.add_utility("persisted_func", "persists", "p.py")

        # A fresh tracker over the same directory must load the utility.
        second = ArchitectureTracker(data_dir=tmp_data_dir)
        names = [u["name"] for u in second._architecture["utilities"]]
        assert "persisted_func" in names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_architecture_summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetArchitectureSummary:
    """Formatting and truncation of get_architecture_summary()."""

    def test_returns_formatted_string(self, tracker):
        tracker.add_module("GraphEngine", "Runs the graph", "graph.py")
        tracker.add_utility("parse_json", "Parses JSON input", "utils.py")
        summary = tracker.get_architecture_summary()
        assert isinstance(summary, str)
        for expected in ("Project Architecture Summary", "GraphEngine", "parse_json"):
            assert expected in summary

    def test_includes_tech_stack(self, tracker):
        summary = tracker.get_architecture_summary()
        assert "Tech Stack" in summary
        assert "Python" in summary

    def test_includes_naming_conventions(self, tracker):
        summary = tracker.get_architecture_summary()
        assert "Naming Conventions" in summary
        assert "snake_case" in summary

    def test_respects_max_tokens_limit(self, tracker):
        # Add many modules to produce a summary far larger than the cap.
        modules = tracker._architecture["modules"]
        for i in range(200):
            modules.append({
                "name": f"Module_{i}",
                "purpose": f"Purpose of module {i} with extra text for padding",
                "file_path": f"src/module_{i}.py",
            })
        # Very small token limit: 50 tokens * 4 chars/token = 200 chars max.
        summary = tracker.get_architecture_summary(max_tokens=50)
        assert len(summary) <= 200

    def test_empty_architecture_still_returns_summary(self, tracker):
        summary = tracker.get_architecture_summary()
        assert "Project Architecture Summary" in summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_architecture (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateArchitecture:
    """Async update_architecture(): basic AST extraction, AI extraction,
    fallback behavior, and de-duplication."""

    def test_basic_extraction_no_client(self, tracker, tmp_path):
        # Sample Python file with a public class, a public function, and a
        # private function that should be ignored.
        sample = tmp_path / "sample.py"
        sample.write_text(
            '"""Sample module."""\n\n'
            'class SampleClass:\n'
            '    """A sample class for testing."""\n'
            '    pass\n\n'
            'def sample_function():\n'
            '    """A sample utility function."""\n'
            '    pass\n\n'
            'def _private_function():\n'
            '    """Should be skipped."""\n'
            '    pass\n'
        )

        task = {"title": "Test task", "description": "Testing extraction"}
        asyncio.run(tracker.update_architecture(task, [str(sample)]))

        module_names = [m["name"] for m in tracker._architecture["modules"]]
        utility_names = [u["name"] for u in tracker._architecture["utilities"]]
        assert "SampleClass" in module_names
        assert "sample_function" in utility_names
        # Private functions should be skipped.
        assert "_private_function" not in utility_names

    def test_skips_non_python_files(self, tracker, tmp_path):
        readme = tmp_path / "readme.txt"
        readme.write_text("Not a Python file")

        asyncio.run(tracker.update_architecture({"title": "Test task"}, [str(readme)]))
        assert tracker._architecture["modules"] == []
        assert tracker._architecture["utilities"] == []

    def test_skips_nonexistent_files(self, tracker):
        asyncio.run(
            tracker.update_architecture({"title": "Test task"}, ["/nonexistent/file.py"])
        )
        assert tracker._architecture["modules"] == []

    def test_ai_extraction_with_mock_client(self, tracker_with_client, tmp_path):
        tracker, client = tracker_with_client

        source = tmp_path / "ai_sample.py"
        source.write_text(
            'class AIExtracted:\n'
            '    """Extracted by AI."""\n'
            '    pass\n'
        )

        # The mocked SDK returns a JSON payload describing the code.
        response = MagicMock()
        response.text = json.dumps({
            "classes": [{"name": "AIExtracted", "purpose": "AI detected class"}],
            "functions": [{"name": "ai_helper", "description": "AI detected function"}],
        })
        client.complete.return_value = response

        asyncio.run(tracker.update_architecture({"title": "AI test task"}, [str(source)]))

        module_names = [m["name"] for m in tracker._architecture["modules"]]
        utility_names = [u["name"] for u in tracker._architecture["utilities"]]
        assert "AIExtracted" in module_names
        assert "ai_helper" in utility_names

    def test_ai_extraction_fallback_on_failure(self, tracker_with_client, tmp_path):
        tracker, client = tracker_with_client

        source = tmp_path / "fallback_sample.py"
        source.write_text(
            'class FallbackClass:\n'
            '    """Falls back to basic extraction."""\n'
            '    pass\n'
        )

        # When the SDK call blows up, basic extraction should still run.
        client.complete.side_effect = Exception("API error")

        asyncio.run(tracker.update_architecture({"title": "Fallback test"}, [str(source)]))

        module_names = [m["name"] for m in tracker._architecture["modules"]]
        assert "FallbackClass" in module_names

    def test_no_duplicate_modules(self, tracker, tmp_path):
        source = tmp_path / "dup.py"
        source.write_text('class DupClass:\n    pass\n')

        # Processing the same file twice must not duplicate entries.
        task = {"title": "Dup test"}
        asyncio.run(tracker.update_architecture(task, [str(source)]))
        asyncio.run(tracker.update_architecture(task, [str(source)]))

        names = [m["name"] for m in tracker._architecture["modules"]]
        assert names.count("DupClass") == 1
|
||||
345
tests/test_claude_client.py
Normal file
345
tests/test_claude_client.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Tests for Claude Agent SDK client adapter."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory.core.claude_client import ClaudeSDKClient
|
||||
|
||||
|
||||
class _FakeOptions:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class _FakeProcessError(RuntimeError):
|
||||
def __init__(self, message: str, stderr: str):
|
||||
super().__init__(message)
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
class TestClaudeSDKClient:
    """Behavioral tests for ClaudeSDKClient.complete().

    Each test swaps the SDK query coroutine (client._query) for an async
    generator stub and the options class for _FakeOptions, so no real
    subprocess or network call happens.
    """

    @pytest.mark.asyncio
    async def test_complete_uses_result_text_and_usage_tokens(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            # Intermediate assistant message followed by the final result.
            yield SimpleNamespace(
                content=[SimpleNamespace(text="intermediate answer")],
                model="claude-sonnet-4-6",
                usage=None,
                is_error=False,
            )
            yield SimpleNamespace(
                result="final answer",
                usage={"input_tokens": 12, "output_tokens": 34},
                is_error=False,
            )

        client._query = _fake_query
        result = await client.complete(
            "hello",
            model="claude-sonnet-4-6",
            system_prompt="You are a tester.",
            max_turns=1,
        )

        # Final result message wins over intermediate content blocks.
        assert result.text == "final answer"
        assert result.input_tokens == 12
        assert result.output_tokens == 34

    @pytest.mark.asyncio
    async def test_complete_raises_on_error_result(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                result="something went wrong",
                usage={"input_tokens": 1, "output_tokens": 1},
                is_error=True,
            )

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="something went wrong"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_falls_back_to_content_blocks(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            # No result message at all: content blocks are the only output.
            yield SimpleNamespace(
                content=[SimpleNamespace(text="first"), SimpleNamespace(text="second")],
                model="claude-sonnet-4-6",
                usage={"cache_creation_input_tokens": 7, "output_tokens": 2},
                is_error=False,
            )

        client._query = _fake_query
        result = await client.complete("hello")

        # Blocks are joined with newlines; cache-creation tokens count as input.
        assert result.text == "first\nsecond"
        assert result.input_tokens == 7
        assert result.output_tokens == 2

    @pytest.mark.asyncio
    async def test_complete_rejects_error_subtype(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            # is_error is False but the subtype still signals failure.
            yield SimpleNamespace(
                subtype="error_during_execution",
                is_error=False,
                result=None,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="error_during_execution"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_ignores_user_content_blocks(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            # Simulates a user/system message content block that should not be treated as model output.
            yield SimpleNamespace(
                content=[SimpleNamespace(text="[Request interrupted by user]")],
                is_error=False,
                usage=None,
            )
            yield SimpleNamespace(
                result="actual assistant result",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 1, "output_tokens": 2},
            )

        client._query = _fake_query
        result = await client.complete("hello")

        assert result.text == "actual assistant result"

    @pytest.mark.asyncio
    async def test_complete_wraps_query_exception_with_hint(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            raise RuntimeError("Command failed with exit code 1")
            yield  # pragma: no cover

        client._query = _fake_query

        # Process failures get an auth troubleshooting hint appended.
        with pytest.raises(RuntimeError, match="Hint: verify Claude auth is available"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_passes_debug_to_stderr_flag(self):
        client = ClaudeSDKClient(enable_debug=True)
        client._options_cls = _FakeOptions
        captured_options = None

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query
        await client.complete("hello")

        assert captured_options is not None
        assert captured_options.kwargs["extra_args"] == {"debug-to-stderr": None}
        # debug_stderr must be a real file-like object with a fileno.
        assert hasattr(captured_options.kwargs["debug_stderr"], "fileno")

    @pytest.mark.asyncio
    async def test_complete_omits_debug_stderr_without_debug_flag(self):
        client = ClaudeSDKClient(enable_debug=False)
        client._options_cls = _FakeOptions
        captured_options = None

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query
        await client.complete("hello")

        assert captured_options is not None
        assert "extra_args" not in captured_options.kwargs
        assert "debug_stderr" not in captured_options.kwargs

    @pytest.mark.asyncio
    async def test_complete_includes_exception_stderr_detail(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            raise _FakeProcessError(
                "Command failed with exit code 1",
                "permission denied writing ~/.claude.json",
            )
            yield  # pragma: no cover

        client._query = _fake_query

        # The exception's stderr payload must surface in the raised message.
        with pytest.raises(RuntimeError, match="permission denied writing ~/.claude.json"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_uses_fallback_home_when_default_home_not_writable(self, tmp_path):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        captured_options = None
        cwd_dir = tmp_path / "workspace"
        cwd_dir.mkdir(parents=True, exist_ok=True)
        expected_home = cwd_dir / ".app_factory" / "claude_home"
        expected_home.mkdir(parents=True, exist_ok=True)

        # Force the writability probe to fail and pin the fallback home path.
        client._claude_home_is_writable = lambda home: False
        client._prepare_fallback_claude_home = (
            lambda source_home, fallback_home: expected_home
        )

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query
        await client.complete("hello", cwd=str(cwd_dir))

        assert captured_options is not None
        assert captured_options.kwargs["env"]["HOME"] == str(expected_home)

    @pytest.mark.asyncio
    async def test_complete_retries_rate_limit_event_then_succeeds(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        attempts = 0

        async def _fake_query(*, prompt, options):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise RuntimeError("Unknown message type: rate_limit_event")
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 2, "output_tokens": 3},
            )

        client._query = _fake_query

        with patch("app_factory.core.claude_client.asyncio.sleep", new=AsyncMock()) as mock_sleep:
            result = await client.complete("hello")

        assert result.text == "ok"
        assert attempts == 2
        # One retry -> one backoff sleep at the initial delay.
        mock_sleep.assert_awaited_once_with(0.2)

    @pytest.mark.asyncio
    async def test_complete_retries_rate_limit_event_exhausts_then_fails(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        attempts = 0

        async def _fake_query(*, prompt, options):
            nonlocal attempts
            attempts += 1
            raise RuntimeError("Unknown message type: rate_limit_event")
            yield  # pragma: no cover

        client._query = _fake_query

        with patch("app_factory.core.claude_client.asyncio.sleep", new=AsyncMock()) as mock_sleep:
            with pytest.raises(RuntimeError, match="rate_limit_event"):
                await client.complete("hello")

        # 1 initial attempt + 3 retries, with exponentially growing backoff.
        assert attempts == 4
        assert mock_sleep.await_args_list == [call(0.2), call(0.8), call(4.0)]

    @pytest.mark.asyncio
    async def test_complete_emits_observability_events(self):
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        mock_observability = MagicMock()

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                content=[
                    SimpleNamespace(
                        id="toolu_1",
                        name="Bash",
                        input={"command": "echo hi", "api_key": "secret"},
                    ),
                    SimpleNamespace(tool_use_id="toolu_1", content="done", is_error=False),
                    SimpleNamespace(text="Tool finished"),
                ],
                model="claude-sonnet-4-6",
                session_id="session-123",
                usage=None,
                is_error=False,
            )
            yield SimpleNamespace(
                subtype="success",
                duration_ms=250,
                duration_api_ms=200,
                is_error=False,
                num_turns=1,
                session_id="session-123",
                usage={"input_tokens": 3, "output_tokens": 4},
                result="ok",
            )

        client._query = _fake_query
        result = await client.complete(
            "hello",
            observability=mock_observability,
            agent_name="pm_agent",
            task_id="expand_prd",
        )

        assert result.text == "ok"
        # NOTE: loop variable renamed from `call` to `c` — the original
        # shadowed unittest.mock.call imported at the top of this file.
        event_types = [
            c.kwargs["event_type"]
            for c in mock_observability.log_claude_event.call_args_list
        ]
        assert "request_start" in event_types
        assert "tool_use" in event_types
        assert "tool_result" in event_types
        assert "result_message" in event_types
        assert "request_complete" in event_types

        tool_use_calls = [
            c.kwargs for c in mock_observability.log_claude_event.call_args_list
            if c.kwargs.get("event_type") == "tool_use"
        ]
        assert tool_use_calls
        # Sensitive keys in tool input must be redacted before logging.
        assert tool_use_calls[0]["payload"]["tool_input"]["api_key"] == "[REDACTED]"
|
||||
279
tests/test_dev_agent.py
Normal file
279
tests/test_dev_agent.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Tests for DevAgentManager."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch, mock_open
|
||||
|
||||
import pexpect
|
||||
import pytest
|
||||
|
||||
from app_factory.agents.dev_agent import DevAgentManager, PROMPT_TEMPLATE_PATH
|
||||
|
||||
|
||||
@pytest.fixture
def mock_docker_client():
    """A throwaway mock standing in for docker.DockerClient."""
    return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
def manager(mock_docker_client):
    """DevAgentManager with a short timeout so tests stay fast."""
    return DevAgentManager(docker_client=mock_docker_client, max_retries=3, timeout=60)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_task():
    """A representative task dict in the planner's output shape."""
    return {
        "task_id": "42",
        "title": "Implement login endpoint",
        "description": "Create a POST /login endpoint with JWT",
        "details": "Use bcrypt for password hashing, return JWT token",
        "testStrategy": "Unit test auth flow, integration test endpoint",
    }
|
||||
|
||||
|
||||
class TestInit:
    """Constructor wiring for DevAgentManager."""

    def test_init_with_client(self, mock_docker_client):
        mgr = DevAgentManager(docker_client=mock_docker_client, max_retries=5, timeout=900)
        assert mgr.docker_client is mock_docker_client
        assert mgr.max_retries == 5
        assert mgr.timeout == 900
        assert mgr._retry_counts == {}

    def test_init_defaults(self, mock_docker_client):
        mgr = DevAgentManager(docker_client=mock_docker_client)
        assert mgr.max_retries == 3
        assert mgr.timeout == 1800

    def test_init_creates_docker_client_from_env(self):
        docker_client = MagicMock()
        docker_module = MagicMock()
        docker_module.from_env.return_value = docker_client
        # With no explicit client the manager imports docker and calls from_env().
        with patch.dict("sys.modules", {"docker": docker_module}):
            mgr = DevAgentManager()
            assert mgr.docker_client is docker_client
|
||||
|
||||
|
||||
class TestPrepareTaskPrompt:
    """Rendering of the task prompt template."""

    def test_includes_all_fields(self, manager, sample_task):
        prompt = manager.prepare_task_prompt(sample_task, global_arch="Microservice arch")
        expected_fragments = (
            "42",
            "Implement login endpoint",
            "Create a POST /login endpoint with JWT",
            "Use bcrypt for password hashing",
            "Unit test auth flow",
            "Microservice arch",
        )
        for fragment in expected_fragments:
            assert fragment in prompt

    def test_without_global_arch(self, manager, sample_task):
        prompt = manager.prepare_task_prompt(sample_task)
        assert "No architecture context provided." in prompt
        assert "42" in prompt
        assert "Implement login endpoint" in prompt

    def test_with_empty_global_arch(self, manager, sample_task):
        # Empty string counts the same as no architecture at all.
        prompt = manager.prepare_task_prompt(sample_task, global_arch="")
        assert "No architecture context provided." in prompt

    def test_uses_id_fallback(self, manager):
        # Tasks may carry "id" instead of "task_id".
        task = {"id": "99", "title": "Fallback task", "description": "desc"}
        prompt = manager.prepare_task_prompt(task)
        assert "99" in prompt

    def test_template_file_exists(self):
        assert PROMPT_TEMPLATE_PATH.exists(), f"Template not found at {PROMPT_TEMPLATE_PATH}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestExecuteTask:
    """execute_task(): success/failure/timeout paths, prompt-file handling,
    and the exact docker-exec command spawned.

    pexpect.spawn is always patched; a MagicMock child supplies `before`
    (captured output) and `exitstatus`.
    """

    async def test_success_path(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)
        mock_child = MagicMock()
        mock_child.before = "Created src/auth.py\nModified tests/test_auth.py\n2 passed"
        mock_child.exitstatus = 0

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child):
            result = await manager.execute_task(sample_task, "container123", worktree)

        assert result["status"] == "success"
        assert result["exit_code"] == 0
        assert "src/auth.py" in result["files_changed"]
        assert isinstance(result["output"], str)

    async def test_failure_path(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)
        mock_child = MagicMock()
        mock_child.before = "Error: compilation failed\n1 failed"
        mock_child.exitstatus = 1

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child):
            result = await manager.execute_task(sample_task, "container123", worktree)

        assert result["status"] == "failed"
        assert result["exit_code"] == 1

    async def test_timeout(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)
        mock_child = MagicMock()
        # expect() raising TIMEOUT simulates the agent hanging.
        mock_child.expect.side_effect = pexpect.TIMEOUT("timed out")
        mock_child.close.return_value = None

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child):
            result = await manager.execute_task(sample_task, "container123", worktree)

        assert result["status"] == "failed"
        assert result["output"] == "timeout"
        assert result["exit_code"] == -1
        assert result["files_changed"] == []

    async def test_writes_prompt_file(self, manager, sample_task, tmp_path):
        # (The original defined an unused `capturing_open` wrapper and a
        # `written_content` nonlocal here; both were dead code and are gone.)
        worktree = str(tmp_path)
        mock_child = MagicMock()
        mock_child.before = "done"
        mock_child.exitstatus = 0

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child):
            await manager.execute_task(sample_task, "cid", worktree)

        # Prompt file should be cleaned up after execution
        prompt_path = os.path.join(worktree, ".task_prompt.txt")
        assert not os.path.exists(prompt_path), "Prompt file should be cleaned up"

    async def test_spawns_correct_command(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)
        mock_child = MagicMock()
        mock_child.before = ""
        mock_child.exitstatus = 0

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child) as mock_spawn:
            await manager.execute_task(sample_task, "abc123", worktree)

        # The prompt file is mounted at /workspace inside the container.
        mock_spawn.assert_called_once_with(
            "docker exec abc123 claude --print --prompt-file /workspace/.task_prompt.txt",
            timeout=60,
            encoding="utf-8",
        )

    async def test_none_exitstatus_treated_as_error(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)
        mock_child = MagicMock()
        mock_child.before = "output"
        # pexpect reports None when the child was killed by a signal.
        mock_child.exitstatus = None

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child):
            result = await manager.execute_task(sample_task, "cid", worktree)

        assert result["status"] == "failed"
        assert result["exit_code"] == -1
|
||||
|
||||
|
||||
class TestParseClaudeOutput:
    """Parsing of raw Claude CLI output into a structured result dict."""

    def test_extracts_files_changed(self, manager):
        output = "Created src/auth.py\nModified tests/test_auth.py\nUpdated config.json"
        parsed = manager.parse_claude_output(output)
        for path in ("src/auth.py", "tests/test_auth.py", "config.json"):
            assert path in parsed["files_changed"]

    def test_extracts_test_results(self, manager):
        parsed = manager.parse_claude_output("Results: 5 passed, 2 failed")
        assert parsed["test_results"]["passed"] == 5
        assert parsed["test_results"]["failed"] == 2

    def test_extracts_errors(self, manager):
        parsed = manager.parse_claude_output(
            "Error: could not import module\nFAILED: assertion mismatch"
        )
        assert len(parsed["errors"]) >= 1

    def test_empty_output(self, manager):
        parsed = manager.parse_claude_output("")
        assert parsed["files_changed"] == []
        assert parsed["test_results"] == {}
        assert parsed["errors"] == []

    def test_no_test_results(self, manager):
        parsed = manager.parse_claude_output("Created app.py\nDone.")
        assert parsed["test_results"] == {}

    def test_deduplicates_files(self, manager):
        # The same file mentioned twice must appear once.
        parsed = manager.parse_claude_output("Created app.py\nEditing app.py")
        assert parsed["files_changed"].count("app.py") == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestExecuteWithRetry:
    """Retry loop around execute_task: success, eventual success, exhaustion."""

    async def test_succeeds_on_first_try(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)
        mock_child = MagicMock()
        mock_child.before = "All good\n3 passed"
        mock_child.exitstatus = 0

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child):
            result = await manager.execute_with_retry(sample_task, "cid", worktree)

        assert result["status"] == "success"
        # One attempt recorded even on immediate success.
        assert manager.get_retry_count("42") == 1

    async def test_succeeds_after_failures(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)

        # First two calls fail, third succeeds
        fail_child = MagicMock()
        fail_child.before = "Error: build failed"
        fail_child.exitstatus = 1

        success_child = MagicMock()
        success_child.before = "All good"
        success_child.exitstatus = 0

        spawn_target = "app_factory.agents.dev_agent.pexpect.spawn"
        with patch(spawn_target, side_effect=[fail_child, fail_child, success_child]):
            result = await manager.execute_with_retry(sample_task, "cid", worktree)

        assert result["status"] == "success"
        assert manager.get_retry_count("42") == 3

    async def test_max_retries_exceeded(self, manager, sample_task, tmp_path):
        worktree = str(tmp_path)
        mock_child = MagicMock()
        mock_child.before = "Error: persistent failure"
        mock_child.exitstatus = 1

        with patch("app_factory.agents.dev_agent.pexpect.spawn", return_value=mock_child):
            result = await manager.execute_with_retry(sample_task, "cid", worktree)

        # After exhausting retries the task is escalated, not silently failed.
        assert result["status"] == "needs_clarification"
        assert manager.get_retry_count("42") == 3
|
||||
|
||||
|
||||
class TestRetryCounters:
    """Retry-count bookkeeping helpers."""

    def test_get_retry_count_default(self, manager):
        # Unknown task ids report zero rather than raising.
        assert manager.get_retry_count("unknown") == 0

    def test_get_retry_count_after_set(self, manager):
        manager._retry_counts["task-1"] = 2
        assert manager.get_retry_count("task-1") == 2

    def test_reset_retry_count(self, manager):
        manager._retry_counts["task-1"] = 3
        manager.reset_retry_count("task-1")
        assert manager.get_retry_count("task-1") == 0

    def test_reset_nonexistent_task(self, manager):
        # Resetting an unknown id must be a silent no-op.
        manager.reset_retry_count("nonexistent")
        assert manager.get_retry_count("nonexistent") == 0
|
||||
575
tests/test_graph.py
Normal file
575
tests/test_graph.py
Normal file
@@ -0,0 +1,575 @@
|
||||
"""Tests for AppFactoryOrchestrator (LangGraph state machine)."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory.core.graph import AppFactoryOrchestrator, AppFactoryState
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def mock_pm_agent():
    """PM agent double with a canned PRD and clarification answer."""
    pm = AsyncMock()
    pm.expand_prompt_to_prd = AsyncMock(return_value="# Generated PRD\n\nObjective: Build an app")
    pm.handle_clarification_request = AsyncMock(return_value="Clarification: use REST API")
    return pm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_task_agent():
    """Task agent double exposing two always-unblocked pending tasks."""
    ta = AsyncMock()
    ta.parse_prd = AsyncMock(return_value={"tasks": []})
    ta.get_unblocked_tasks = AsyncMock(return_value=[
        {"id": 1, "title": "Task 1", "status": "pending", "dependencies": []},
        {"id": 2, "title": "Task 2", "status": "pending", "dependencies": []},
    ])
    ta.update_task_status = AsyncMock()
    ta.get_task_details = AsyncMock(return_value={"id": 1, "title": "Task 1"})
    return ta
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dev_manager():
    """Dev manager double whose retry wrapper always succeeds."""
    dm = AsyncMock()
    dm.execute_with_retry = AsyncMock(return_value={
        "status": "success",
        "output": "Task completed",
        "files_changed": ["app.py"],
        "exit_code": 0,
    })
    return dm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_qa_agent():
    """QA agent double that approves and merges every review."""
    qa = AsyncMock()
    qa.max_retries = 3
    qa.review_and_merge = AsyncMock(return_value={
        "status": "merged",
        "commit_sha": "abc123",
        "review_summary": "LGTM",
    })
    return qa
|
||||
|
||||
|
||||
@pytest.fixture
def mock_workspace_manager():
    """Workspace manager double with canned worktree/container handles."""
    ws = AsyncMock()
    ws.create_worktree = AsyncMock(return_value="/tmp/worktree/task-1")
    ws.spin_up_clean_room = AsyncMock(return_value=MagicMock(id="container-123"))
    ws.cleanup_workspace = AsyncMock()
    return ws
|
||||
|
||||
|
||||
@pytest.fixture
def mock_observability():
    """Synchronous observability double recording transitions and errors."""
    observer = MagicMock()
    observer.log_state_transition = MagicMock()
    observer.log_error = MagicMock()
    return observer
|
||||
|
||||
|
||||
@pytest.fixture
def orchestrator(mock_pm_agent, mock_task_agent, mock_dev_manager,
                 mock_qa_agent, mock_workspace_manager, mock_observability):
    """Orchestrator fully wired with the mock agents above."""
    return AppFactoryOrchestrator(
        pm_agent=mock_pm_agent,
        task_agent=mock_task_agent,
        dev_manager=mock_dev_manager,
        qa_agent=mock_qa_agent,
        workspace_manager=mock_workspace_manager,
        observability=mock_observability,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State Schema Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStateSchema:
    """Shape of the AppFactoryState TypedDict."""

    def test_state_has_all_required_fields(self):
        expected = {
            "user_input", "prd", "tasks", "active_tasks", "completed_tasks",
            "blocked_tasks", "clarification_requests", "global_architecture",
            "iteration_count", "max_iterations", "errors",
        }
        # Exact match: no missing fields, no stray extras.
        assert set(AppFactoryState.__annotations__.keys()) == expected

    def test_state_can_be_instantiated(self):
        initial: AppFactoryState = {
            "user_input": "build a todo app",
            "prd": "",
            "tasks": [],
            "active_tasks": {},
            "completed_tasks": [],
            "blocked_tasks": {},
            "clarification_requests": [],
            "global_architecture": "",
            "iteration_count": 0,
            "max_iterations": 50,
            "errors": [],
        }
        assert initial["user_input"] == "build a todo app"
        assert initial["max_iterations"] == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph Construction Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGraphBuild:
    """Graph compilation under various agent wirings."""

    def test_graph_builds_without_errors(self):
        assert AppFactoryOrchestrator().build_graph() is not None

    def test_graph_builds_with_all_agents(self, orchestrator):
        assert orchestrator.build_graph() is not None

    def test_graph_builds_with_none_agents(self):
        # Explicit None for every dependency must still compile.
        orch = AppFactoryOrchestrator(
            pm_agent=None, task_agent=None, dev_manager=None,
            qa_agent=None, workspace_manager=None, observability=None,
        )
        assert orch.build_graph() is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPMNode:
    """PM node: PRD generation, empty input, agent failure, agentless fallback."""

    @pytest.mark.asyncio
    async def test_pm_node_sets_prd(self, orchestrator, mock_pm_agent):
        out = await orchestrator._pm_node({"user_input": "build a todo app", "errors": []})
        assert "prd" in out
        assert out["prd"] == "# Generated PRD\n\nObjective: Build an app"
        mock_pm_agent.expand_prompt_to_prd.assert_called_once_with("build a todo app")

    @pytest.mark.asyncio
    async def test_pm_node_handles_no_input(self, orchestrator):
        out = await orchestrator._pm_node({"user_input": "", "errors": []})
        assert out["prd"] == ""
        assert any("No user input" in err for err in out["errors"])

    @pytest.mark.asyncio
    async def test_pm_node_handles_agent_error(self, orchestrator, mock_pm_agent):
        # An exploding agent is recorded as an error, not propagated.
        mock_pm_agent.expand_prompt_to_prd.side_effect = RuntimeError("API down")
        out = await orchestrator._pm_node({"user_input": "build app", "errors": []})
        assert out["prd"] == ""
        assert any("PM agent error" in err for err in out["errors"])

    @pytest.mark.asyncio
    async def test_pm_node_without_agent(self):
        # No PM agent wired: the node falls back to a mock PRD.
        orch = AppFactoryOrchestrator()
        out = await orch._pm_node({"user_input": "build a todo app", "errors": []})
        assert "Mock PRD" in out["prd"]
|
||||
|
||||
|
||||
class TestTaskNode:
    """Task node: PRD parsing, iteration counting, refresh passes, budget cap."""

    @staticmethod
    def _state(tasks, iteration):
        """Build a minimal task-node state around *tasks* at *iteration*."""
        return {
            "prd": "PRD",
            "tasks": tasks,
            "iteration_count": iteration,
            "max_iterations": 50,
            "errors": [],
        }

    @pytest.mark.asyncio
    async def test_task_node_populates_tasks(self, orchestrator, mock_task_agent):
        state = {"prd": "some PRD", "tasks": [], "iteration_count": 0, "max_iterations": 50, "errors": []}
        result = await orchestrator._task_node(state)
        assert "tasks" in result
        assert len(result["tasks"]) == 2
        mock_task_agent.parse_prd.assert_called_once()

    @pytest.mark.asyncio
    async def test_task_node_increments_iteration(self, orchestrator):
        result = await orchestrator._task_node(self._state([], 5))
        assert result["iteration_count"] == 6

    @pytest.mark.asyncio
    async def test_task_node_refreshes_on_subsequent_passes(self, orchestrator, mock_task_agent):
        # Once tasks exist, the node must refresh unblocked tasks instead of
        # re-parsing the PRD.
        state = self._state(
            [{"id": 1, "title": "existing", "status": "pending", "dependencies": []}], 1
        )
        # Fix: the original bound the return value to an unused local `result`.
        await orchestrator._task_node(state)
        mock_task_agent.parse_prd.assert_not_called()
        mock_task_agent.get_unblocked_tasks.assert_called()

    @pytest.mark.asyncio
    async def test_task_node_stops_at_max_iterations(self, orchestrator):
        result = await orchestrator._task_node(self._state([], 49))
        assert result["iteration_count"] == 50
        assert any("Max iterations" in e for e in result["errors"])
|
||||
|
||||
|
||||
class TestDevDispatchNode:
    """Dev dispatch node: concurrency, clarification escalation, skip logic."""

    @staticmethod
    def _state(tasks, completed=None):
        """Build a minimal dispatch-node state around *tasks*."""
        return {
            "tasks": tasks,
            "completed_tasks": completed if completed is not None else [],
            "active_tasks": {},
            "errors": [],
            "clarification_requests": [],
            "global_architecture": "",
        }

    @pytest.mark.asyncio
    async def test_dev_dispatch_spawns_concurrent_tasks(self, orchestrator, mock_dev_manager, mock_workspace_manager):
        st = self._state([
            {"id": 1, "title": "Task 1", "status": "pending", "dependencies": []},
            {"id": 2, "title": "Task 2", "status": "pending", "dependencies": []},
        ])
        out = await orchestrator._dev_dispatch_node(st)
        # Both tasks executed, each with its own worktree.
        assert "1" in out["completed_tasks"]
        assert "2" in out["completed_tasks"]
        assert mock_dev_manager.execute_with_retry.call_count == 2
        assert mock_workspace_manager.create_worktree.call_count == 2

    @pytest.mark.asyncio
    async def test_dev_dispatch_handles_needs_clarification(self, orchestrator, mock_dev_manager):
        mock_dev_manager.execute_with_retry.return_value = {
            "status": "needs_clarification",
            "output": "Cannot figure out API format",
            "files_changed": [],
            "exit_code": -1,
        }
        st = self._state([{"id": 1, "title": "Task 1", "status": "pending", "dependencies": []}])
        out = await orchestrator._dev_dispatch_node(st)
        assert len(out["clarification_requests"]) == 1
        assert out["clarification_requests"][0]["task_id"] == "1"

    @pytest.mark.asyncio
    async def test_dev_dispatch_skips_completed(self, orchestrator):
        # Already-completed tasks produce no state delta.
        st = self._state(
            [{"id": 1, "title": "Task 1", "status": "pending", "dependencies": []}],
            completed=["1"],
        )
        assert await orchestrator._dev_dispatch_node(st) == {}

    @pytest.mark.asyncio
    async def test_dev_dispatch_without_agents(self):
        # No agents wired: the node still completes tasks via its mock path.
        orch = AppFactoryOrchestrator()
        st = self._state([{"id": 1, "title": "Task 1", "status": "pending", "dependencies": []}])
        out = await orch._dev_dispatch_node(st)
        assert "1" in out["completed_tasks"]
|
||||
|
||||
|
||||
class TestQANode:
    """QA node: review/merge outcomes and edge cases."""

    @staticmethod
    def _qa_state(tasks, active, completed):
        """Build a minimal QA-node state."""
        return {
            "tasks": tasks,
            "active_tasks": active,
            "completed_tasks": completed,
            "errors": [],
            "clarification_requests": [],
            "blocked_tasks": {},
        }

    @pytest.mark.asyncio
    async def test_qa_node_processes_results(self, orchestrator, mock_qa_agent, mock_task_agent):
        st = self._qa_state(
            [{"id": "1", "title": "Task 1"}],
            {"1": {"status": "success", "worktree_path": "/tmp/wt"}},
            ["1"],
        )
        out = await orchestrator._qa_node(st)
        mock_qa_agent.review_and_merge.assert_called_once()
        assert out["active_tasks"]["1"]["status"] == "merged"

    @pytest.mark.asyncio
    async def test_qa_node_no_tasks_for_qa(self, orchestrator):
        # Nothing awaiting review: no state delta.
        assert await orchestrator._qa_node(self._qa_state([], {}, [])) == {}

    @pytest.mark.asyncio
    async def test_qa_node_handles_qa_failure(self, orchestrator, mock_qa_agent):
        # Exhausted QA retries escalate to a clarification request.
        mock_qa_agent.review_and_merge.return_value = {
            "status": "tests_failed",
            "retry_count": 3,
        }
        st = self._qa_state(
            [{"id": "1", "title": "Task 1"}],
            {"1": {"status": "success", "worktree_path": "/tmp/wt"}},
            ["1"],
        )
        out = await orchestrator._qa_node(st)
        assert len(out["clarification_requests"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routing Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRoutingAfterTasks:
    """Conditional-edge selection after the task node."""

    @staticmethod
    def _state(tasks, completed=None, blocked=None, iteration=1):
        """Build a routing state; defaults model a fresh, unblocked pass."""
        return {
            "tasks": tasks,
            "completed_tasks": completed if completed is not None else [],
            "blocked_tasks": blocked if blocked is not None else {},
            "clarification_requests": [],
            "iteration_count": iteration,
            "max_iterations": 50,
        }

    def test_routes_to_dev_dispatch_when_unblocked(self, orchestrator):
        st = self._state([{"id": 1, "status": "pending", "dependencies": []}])
        assert orchestrator._should_continue_after_tasks(st) == "dev_dispatch"

    def test_routes_to_end_when_all_done(self, orchestrator):
        st = self._state([{"id": 1, "status": "done", "dependencies": []}], completed=["1"])
        assert orchestrator._should_continue_after_tasks(st) == "end"

    def test_routes_to_clarification_when_blocked(self, orchestrator):
        st = self._state(
            [{"id": 1, "status": "pending", "dependencies": [2]}],
            blocked={"1": "dependency not met"},
        )
        assert orchestrator._should_continue_after_tasks(st) == "clarification"

    def test_routes_to_end_at_max_iterations(self, orchestrator):
        # Budget exhausted beats pending work.
        st = self._state([{"id": 1, "status": "pending", "dependencies": []}], iteration=50)
        assert orchestrator._should_continue_after_tasks(st) == "end"

    def test_routes_to_end_when_no_tasks(self, orchestrator):
        assert orchestrator._should_continue_after_tasks(self._state([])) == "end"
|
||||
|
||||
|
||||
class TestRoutingAfterQA:
    """Conditional-edge selection after the QA node."""

    @staticmethod
    def _state(requests, iteration):
        return {
            "clarification_requests": requests,
            "iteration_count": iteration,
            "max_iterations": 50,
        }

    def test_routes_to_task_node(self, orchestrator):
        # Clean pass loops back for more tasks.
        assert orchestrator._should_continue_after_qa(self._state([], 1)) == "task_node"

    def test_routes_to_clarification(self, orchestrator):
        pending = self._state([{"question": "What API?"}], 1)
        assert orchestrator._should_continue_after_qa(pending) == "clarification"

    def test_routes_to_end_at_max_iterations(self, orchestrator):
        assert orchestrator._should_continue_after_qa(self._state([], 50)) == "end"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Iteration Safety Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterationSafety:
    """The iteration budget halts the graph even when tasks never finish."""

    @pytest.mark.asyncio
    async def test_iteration_limit_prevents_infinite_loop(self, orchestrator):
        near_limit = {
            "prd": "PRD",
            "tasks": [],
            "iteration_count": 49,
            "max_iterations": 50,
            "errors": [],
        }
        out = await orchestrator._task_node(near_limit)
        assert out["iteration_count"] == 50
        assert any("Max iterations" in err for err in out["errors"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State Persistence Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStatePersistence:
    """save_state / load_state round-tripping through JSON files."""

    def test_save_state(self, orchestrator):
        with tempfile.TemporaryDirectory() as tmpdir:
            target = os.path.join(tmpdir, "state.json")
            snapshot = {
                "user_input": "test",
                "prd": "PRD content",
                "tasks": [{"id": 1}],
                "active_tasks": {},
                "completed_tasks": ["1"],
                "blocked_tasks": {},
                "clarification_requests": [],
                "global_architecture": "",
                "iteration_count": 5,
                "max_iterations": 50,
                "errors": [],
            }
            orchestrator.save_state(snapshot, target)
            assert os.path.exists(target)
            # The file must contain valid JSON mirroring the snapshot.
            with open(target) as fh:
                restored = json.load(fh)
            assert restored["user_input"] == "test"
            assert restored["iteration_count"] == 5

    def test_load_state(self, orchestrator):
        with tempfile.TemporaryDirectory() as tmpdir:
            target = os.path.join(tmpdir, "state.json")
            with open(target, "w") as fh:
                json.dump({"user_input": "hello", "prd": "PRD", "iteration_count": 3}, fh)
            restored = orchestrator.load_state(target)
            assert restored["user_input"] == "hello"
            assert restored["iteration_count"] == 3

    def test_save_state_creates_directory(self, orchestrator):
        # Intermediate directories are created on demand.
        with tempfile.TemporaryDirectory() as tmpdir:
            nested = os.path.join(tmpdir, "nested", "dir", "state.json")
            orchestrator.save_state({"user_input": "test"}, nested)
            assert os.path.exists(nested)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Clarification Node Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClarificationNode:
    """Clarification node drains requests and unblocks tasks."""

    @pytest.mark.asyncio
    async def test_clarification_resolves_requests(self, orchestrator, mock_pm_agent):
        pending = {
            "clarification_requests": [
                {"requesting_agent": "dev", "task_id": "1", "question": "What format?", "context": ""},
            ],
            "blocked_tasks": {"1": "needs clarification"},
            "errors": [],
        }
        out = await orchestrator._clarification_node(pending)
        assert out["clarification_requests"] == []
        # The answered task is no longer blocked.
        assert "1" not in out["blocked_tasks"]
        mock_pm_agent.handle_clarification_request.assert_called_once()

    @pytest.mark.asyncio
    async def test_clarification_handles_empty_requests(self, orchestrator):
        out = await orchestrator._clarification_node(
            {"clarification_requests": [], "blocked_tasks": {}, "errors": []}
        )
        assert out["clarification_requests"] == []

    @pytest.mark.asyncio
    async def test_clarification_without_agent(self):
        # No PM agent: requests are still drained rather than left queued.
        orch = AppFactoryOrchestrator()
        pending = {
            "clarification_requests": [{"task_id": "1", "question": "?"}],
            "blocked_tasks": {"1": "blocked"},
            "errors": [],
        }
        out = await orch._clarification_node(pending)
        assert out["clarification_requests"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-End Run Test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEndToEnd:
    """Full orchestrator runs through the compiled graph."""

    @pytest.mark.asyncio
    async def test_run_executes_end_to_end(self, mock_pm_agent, mock_task_agent,
                                           mock_dev_manager, mock_qa_agent,
                                           mock_workspace_manager, mock_observability):
        """PM -> task -> dev -> QA loop terminates once no tasks remain."""
        # After first get_unblocked_tasks returns tasks, the second call returns empty
        # to terminate the loop.
        mock_task_agent.get_unblocked_tasks = AsyncMock(
            side_effect=[
                [{"id": 1, "title": "Task 1", "status": "pending", "dependencies": []}],
                [],  # No more tasks - triggers end
            ]
        )

        orch = AppFactoryOrchestrator(
            pm_agent=mock_pm_agent,
            task_agent=mock_task_agent,
            dev_manager=mock_dev_manager,
            qa_agent=mock_qa_agent,
            workspace_manager=mock_workspace_manager,
            observability=mock_observability,
        )

        # Fix: the original created a TemporaryDirectory and a state_path that
        # were never used, and bound the save_state patch to an unused name.
        # Patching save_state is all that is needed to keep the test hermetic.
        with patch.object(orch, "save_state"):
            result = await orch.run("build a todo app")

        assert result["prd"] != ""
        mock_pm_agent.expand_prompt_to_prd.assert_called_once_with("build a todo app")
        mock_task_agent.parse_prd.assert_called_once()

    @pytest.mark.asyncio
    async def test_run_with_no_agents(self):
        """With no agents wired, run() falls back to mock behavior."""
        orch = AppFactoryOrchestrator()
        # Fix: same dead TemporaryDirectory/state_path removed here.
        with patch.object(orch, "save_state"):
            result = await orch.run("build something")
        assert "Mock PRD" in result["prd"]
|
||||
456
tests/test_main.py
Normal file
456
tests/test_main.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Tests for main.py entry point, error handling, and integration."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from io import StringIO
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from main import (
|
||||
AppFactoryError,
|
||||
ClarificationTimeout,
|
||||
ConfigurationError,
|
||||
DockerDaemonError,
|
||||
GracefulShutdown,
|
||||
GitError,
|
||||
MCPConnectionError,
|
||||
main,
|
||||
parse_args,
|
||||
print_summary,
|
||||
run_factory,
|
||||
validate_environment,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_args tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseArgs:
    """Command-line argument parsing."""

    def test_prompt_required(self):
        # argparse exits when the mandatory --prompt is missing.
        with pytest.raises(SystemExit):
            parse_args([])

    def test_prompt_only(self):
        parsed = parse_args(["--prompt", "Build a REST API"])
        assert parsed.prompt == "Build a REST API"
        assert parsed.repo_path == os.getcwd()
        assert parsed.max_concurrent_tasks == 5
        assert parsed.debug is False
        assert parsed.dry_run is False

    def test_all_options(self):
        parsed = parse_args([
            "--prompt", "Build an app",
            "--repo-path", "/tmp/project",
            "--max-concurrent-tasks", "3",
            "--debug",
            "--dry-run",
        ])
        assert parsed.prompt == "Build an app"
        assert parsed.repo_path == "/tmp/project"
        assert parsed.max_concurrent_tasks == 3
        assert parsed.debug is True
        assert parsed.dry_run is True

    def test_max_concurrent_tasks_default(self):
        assert parse_args(["--prompt", "test"]).max_concurrent_tasks == 5

    def test_repo_path_default_is_cwd(self):
        assert parse_args(["--prompt", "test"]).repo_path == os.getcwd()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_environment tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateEnvironment:
    """Environment/dependency preflight checks in validate_environment."""

    def test_valid_config(self):
        env = {
            "ANTHROPIC_API_KEY": "sk-test-key",
            "LANGSMITH_API_KEY": "ls-key",
            "LANGSMITH_PROJECT": "my-project",
        }
        with patch.dict(os.environ, env, clear=False), \
                patch("main.subprocess.run") as mock_run, \
                patch("main.shutil.which", return_value="/usr/bin/git"):
            mock_run.return_value = MagicMock(returncode=0)
            config = validate_environment()

        assert config["api_key"] == "sk-test-key"
        assert config["auth_token"] == ""
        assert config["langsmith_api_key"] == "ls-key"
        assert config["langsmith_project"] == "my-project"

    def test_missing_api_key_still_works(self):
        """API key is optional (Claude Code OAuth supported)."""
        with patch.dict(os.environ, {}, clear=True), \
                patch("main.subprocess.run") as mock_run, \
                patch("main.shutil.which", return_value="/usr/bin/git"):
            mock_run.return_value = MagicMock(returncode=0)
            config = validate_environment()
        assert config["api_key"] == ""
        assert config["auth_token"] == ""

    def test_docker_not_running(self):
        # A nonzero `docker info` exit means the daemon is down.
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-test-key"}, clear=False), \
                patch("main.subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(returncode=1)
            with pytest.raises(DockerDaemonError, match="not running"):
                validate_environment()

    def test_docker_not_found(self):
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-test-key"}, clear=False), \
                patch("main.subprocess.run", side_effect=FileNotFoundError):
            with pytest.raises(DockerDaemonError, match="not found"):
                validate_environment()

    def test_docker_timeout(self):
        import subprocess as sp

        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-test-key"}, clear=False), \
                patch("main.subprocess.run", side_effect=sp.TimeoutExpired("docker", 10)):
            with pytest.raises(DockerDaemonError, match="not responding"):
                validate_environment()

    def test_git_not_found(self):
        with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "sk-test-key"}, clear=False), \
                patch("main.subprocess.run") as mock_run, \
                patch("main.shutil.which", return_value=None):
            mock_run.return_value = MagicMock(returncode=0)
            with pytest.raises(GitError, match="git not found"):
                validate_environment()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# print_summary tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPrintSummary:
    """Console output of print_summary."""

    @staticmethod
    def _result(completed, tasks, errors, iterations):
        """Assemble a run-result dict for print_summary."""
        return {
            "completed_tasks": completed,
            "tasks": tasks,
            "errors": errors,
            "iteration_count": iterations,
        }

    def test_basic_summary(self, capsys):
        print_summary(
            self._result(["1", "2"], [{"id": 1}, {"id": 2}, {"id": 3}], [], 5),
            datetime.now(timezone.utc),
        )
        text = capsys.readouterr().out
        assert "2 / 3" in text
        assert "Iterations" in text

    def test_summary_with_errors(self, capsys):
        print_summary(
            self._result([], [{"id": 1}], ["Error one", "Error two"], 1),
            datetime.now(timezone.utc),
        )
        text = capsys.readouterr().out
        assert "Errors" in text
        assert "Error one" in text

    def test_summary_truncates_many_errors(self, capsys):
        # Ten errors: only the first few are shown, the rest summarized.
        many = [f"Error {i}" for i in range(10)]
        print_summary(self._result([], [], many, 0), datetime.now(timezone.utc))
        assert "and 5 more" in capsys.readouterr().out

    def test_summary_with_langsmith(self, capsys):
        with patch.dict(os.environ, {"LANGSMITH_PROJECT": "test-proj"}):
            print_summary(self._result([], [], [], 0), datetime.now(timezone.utc))
        assert "test-proj" in capsys.readouterr().out

    def test_summary_empty_result(self, capsys):
        print_summary(self._result([], [], [], 0), datetime.now(timezone.utc))
        assert "0 / 0" in capsys.readouterr().out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GracefulShutdown tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGracefulShutdown:
    """Signal registration and two-stage shutdown behavior."""

    def test_initial_state(self):
        with patch("main.signal.signal"):
            shutdown = GracefulShutdown()
            assert shutdown.shutdown_requested is False
            assert shutdown.workspace_manager is None

    def test_registers_signals(self):
        with patch("main.signal.signal") as mock_signal:
            shutdown = GracefulShutdown()
            registered = [call[0] for call in mock_signal.call_args_list]
            assert (signal.SIGINT, shutdown._handler) in registered
            assert (signal.SIGTERM, shutdown._handler) in registered

    def test_first_signal_sets_flag(self):
        with patch("main.signal.signal"):
            shutdown = GracefulShutdown()
            # Simulate delivery of the first SIGINT.
            with patch("builtins.print"):
                shutdown._handler(signal.SIGINT, None)
            assert shutdown.shutdown_requested is True

    def test_second_signal_force_exits(self):
        with patch("main.signal.signal"):
            shutdown = GracefulShutdown()
            shutdown.shutdown_requested = True
            # A second signal while shutting down forces an immediate exit.
            with patch("builtins.print"), pytest.raises(SystemExit):
                shutdown._handler(signal.SIGINT, None)

    def test_first_signal_triggers_cleanup(self):
        ws = MagicMock()
        ws.cleanup_all = AsyncMock()
        with patch("main.signal.signal"):
            shutdown = GracefulShutdown(workspace_manager=ws)

        # Deliver the signal while a loop is running so cleanup can be scheduled.
        loop = asyncio.new_event_loop()

        async def _fire():
            with patch("builtins.print"):
                shutdown._handler(signal.SIGINT, None)

        loop.run_until_complete(_fire())
        loop.close()
        assert shutdown.shutdown_requested is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_factory tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunFactory:
    """run_factory wiring: every collaborator is constructed from the
    validated config and the orchestrator receives the user prompt."""

    @pytest.mark.asyncio
    async def test_initializes_all_components(self):
        mock_orchestrator_instance = MagicMock()
        mock_orchestrator_instance.run = AsyncMock(return_value={
            "completed_tasks": ["1"],
            "tasks": [{"id": 1}],
            "errors": [],
            "iteration_count": 1,
        })

        args = MagicMock()
        args.prompt = "Build a REST API"
        args.repo_path = "/tmp/test-repo"

        config = {
            "api_key": "sk-test-key",
            "auth_token": "",
            "langsmith_api_key": "",
            "langsmith_project": "app-factory",
        }

        # Patch each collaborator at its definition site; create=True on
        # GracefulShutdown lets the patch succeed even if main does not
        # define that attribute.
        with patch("app_factory.core.observability.ObservabilityManager") as mock_obs, \
                patch("app_factory.core.workspace.WorkspaceManager") as mock_ws, \
                patch("app_factory.core.architecture_tracker.ArchitectureTracker") as mock_arch, \
                patch("app_factory.agents.pm_agent.PMAgent") as mock_pm, \
                patch("app_factory.agents.task_agent.TaskMasterAgent") as mock_task, \
                patch("app_factory.agents.dev_agent.DevAgentManager") as mock_dev, \
                patch("app_factory.agents.qa_agent.QAAgent") as mock_qa, \
                patch("app_factory.core.graph.AppFactoryOrchestrator") as mock_orch, \
                patch("main.GracefulShutdown", create=True):
            mock_orch.return_value = mock_orchestrator_instance
            mock_ws.return_value = MagicMock(docker_client=MagicMock())

            result = await run_factory(args, config)

            assert result["completed_tasks"] == ["1"]
            mock_obs.assert_called_once()
            mock_ws.assert_called_once_with(repo_path="/tmp/test-repo")
            # The empty auth_token in config is expected to reach the agents
            # as None.
            mock_arch.assert_called_once_with(
                api_key="sk-test-key",
                auth_token=None,
                debug=False,
                observability=mock_obs.return_value,
            )
            mock_pm.assert_called_once_with(
                api_key="sk-test-key",
                auth_token=None,
                debug=False,
                observability=mock_obs.return_value,
            )
            mock_task.assert_called_once_with(project_root="/tmp/test-repo")
            mock_dev.assert_called_once()
            mock_qa.assert_called_once_with(
                repo_path="/tmp/test-repo",
                api_key="sk-test-key",
                auth_token=None,
                debug=False,
                observability=mock_obs.return_value,
            )
            mock_orch.assert_called_once()
            mock_orchestrator_instance.run.assert_awaited_once_with("Build a REST API")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# main() integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMainEntryPoint:
    """main() behaviour: dry-run short-circuit, error-to-exit-code mapping,
    debug logging setup, and the successful-run summary."""

    def test_dry_run_validates_without_executing(self):
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment") as mock_validate, \
                patch("builtins.print") as mock_print:
            mock_args.return_value = MagicMock(
                prompt="test", debug=False, dry_run=True,
            )
            mock_validate.return_value = {"api_key": "sk-test"}
            main()

            # Dry-run stops after validation and reports success only.
            mock_print.assert_called_with(
                "Dry-run: configuration is valid. All checks passed."
            )

    def test_configuration_error_exits(self):
        # A ConfigurationError from validation maps to exit code 1.
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment", side_effect=ConfigurationError("no key")), \
                patch("builtins.print"), \
                pytest.raises(SystemExit) as exc_info:
            mock_args.return_value = MagicMock(debug=False, dry_run=False)
            main()
        assert exc_info.value.code == 1

    def test_docker_error_exits(self):
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment", side_effect=DockerDaemonError("not running")), \
                patch("builtins.print"), \
                pytest.raises(SystemExit) as exc_info:
            mock_args.return_value = MagicMock(debug=False, dry_run=False)
            main()
        assert exc_info.value.code == 1

    def test_git_error_exits(self):
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment", side_effect=GitError("no git")), \
                patch("builtins.print"), \
                pytest.raises(SystemExit) as exc_info:
            mock_args.return_value = MagicMock(debug=False, dry_run=False)
            main()
        assert exc_info.value.code == 1

    def test_clarification_timeout_exits(self):
        # Errors raised during the async run (patched asyncio.run) also map
        # to exit code 1.
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment") as mock_validate, \
                patch("main.asyncio.run", side_effect=ClarificationTimeout("task 5")), \
                patch("builtins.print"), \
                pytest.raises(SystemExit) as exc_info:
            mock_args.return_value = MagicMock(
                prompt="test", debug=False, dry_run=False,
            )
            mock_validate.return_value = {"api_key": "sk-test"}
            main()
        assert exc_info.value.code == 1

    def test_generic_exception_exits(self):
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment") as mock_validate, \
                patch("main.asyncio.run", side_effect=RuntimeError("boom")), \
                patch("builtins.print"), \
                pytest.raises(SystemExit) as exc_info:
            mock_args.return_value = MagicMock(
                prompt="test", debug=False, dry_run=False,
            )
            mock_validate.return_value = {"api_key": "sk-test"}
            main()
        assert exc_info.value.code == 1

    def test_debug_flag_sets_logging(self):
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment") as mock_validate, \
                patch("main.logging.basicConfig") as mock_logging, \
                patch("builtins.print"):
            mock_args.return_value = MagicMock(
                prompt="test", debug=True, dry_run=True,
            )
            mock_validate.return_value = {"api_key": "sk-test"}
            main()

            mock_logging.assert_called_once()
            call_kwargs = mock_logging.call_args[1]
            assert call_kwargs["level"] == 10  # logging.DEBUG

    def test_successful_run(self):
        mock_result = {
            "completed_tasks": ["1"],
            "tasks": [{"id": 1}],
            "errors": [],
            "iteration_count": 3,
        }
        with patch("main.load_dotenv"), \
                patch("main.parse_args") as mock_args, \
                patch("main.validate_environment") as mock_validate, \
                patch("main.asyncio.run", return_value=mock_result), \
                patch("main.print_summary") as mock_summary:
            mock_args.return_value = MagicMock(
                prompt="test", debug=False, dry_run=False,
            )
            mock_validate.return_value = {"api_key": "sk-test"}
            main()

            mock_summary.assert_called_once()
            # Verify the result was passed to print_summary
            assert mock_summary.call_args[0][0] == mock_result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception hierarchy tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceptionHierarchy:
    """The app's exception taxonomy hangs off a single AppFactoryError root."""

    def test_all_exceptions_inherit_from_base(self):
        domain_exceptions = (
            ClarificationTimeout,
            DockerDaemonError,
            GitError,
            MCPConnectionError,
            ConfigurationError,
        )
        for exc_type in domain_exceptions:
            assert issubclass(exc_type, AppFactoryError)

    def test_base_inherits_from_exception(self):
        assert issubclass(AppFactoryError, Exception)
|
||||
426
tests/test_observability.py
Normal file
426
tests/test_observability.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Tests for ObservabilityManager."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory.core.observability import ObservabilityManager, _StructuredFormatter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def obs():
    """ObservabilityManager built with the LangSmith import forced to fail."""
    # Mapping the module name to None in sys.modules makes
    # `import langsmith` raise, exercising the degraded path.
    with patch.dict(sys.modules, {"langsmith": None}):
        tracker = ObservabilityManager("test-project")
        assert tracker._client is None
        return tracker
|
||||
|
||||
|
||||
@pytest.fixture
def obs_with_client():
    """ObservabilityManager wired to a MagicMock LangSmith client."""
    fake_client = MagicMock()
    fake_module = MagicMock()
    fake_module.Client.return_value = fake_client
    with patch.dict(sys.modules, {"langsmith": fake_module}):
        manager = ObservabilityManager("test-project")
        return manager, fake_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitialization:
    """Construction-time behaviour: project-name resolution, LangSmith
    degradation, logger setup, and Claude event-mode configuration."""

    def test_explicit_project_name(self, obs):
        assert obs.project_name == "test-project"

    def test_default_project_from_env(self):
        # With no explicit name, LANGSMITH_PROJECT from the environment wins.
        with patch.dict(os.environ, {"LANGSMITH_PROJECT": "env-project"}):
            with patch.dict(sys.modules, {"langsmith": None}):
                manager = ObservabilityManager()
                assert manager.project_name == "env-project"

    def test_fallback_project_name(self):
        # Strip the env var entirely to exercise the hard-coded default.
        env = os.environ.copy()
        env.pop("LANGSMITH_PROJECT", None)
        with patch.dict(os.environ, env, clear=True):
            with patch.dict(sys.modules, {"langsmith": None}):
                manager = ObservabilityManager()
                assert manager.project_name == "app-factory"

    def test_graceful_degradation_no_langsmith(self, obs):
        # Import failure must leave the manager usable, with no client.
        assert obs._client is None

    def test_logger_created(self, obs):
        assert obs.logger is not None
        assert obs.logger.level == logging.DEBUG
        # propagate=False keeps records out of the root logger's handlers.
        assert obs.logger.propagate is False

    def test_claude_event_mode_from_env(self):
        with patch.dict(os.environ, {"APP_FACTORY_CLAUDE_EVENT_MODE": "verbose"}):
            with patch.dict(sys.modules, {"langsmith": None}):
                manager = ObservabilityManager("test-project")
                assert manager._claude_event_mode == "verbose"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tracing lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTracing:
    """start_trace / end_trace lifecycle, with and without a LangSmith client."""

    def test_start_trace_returns_run_id(self, obs):
        run_id = obs.start_trace("pm_agent", "1.2")
        assert isinstance(run_id, str)
        assert len(run_id) == 32  # hex uuid

    def test_start_and_end_trace(self, obs):
        run_id = obs.start_trace("pm_agent", "1.2", inputs={"key": "val"})
        obs.end_trace(run_id, outputs={"result": "ok"})
        # A finished run is removed from the active-run table.
        assert run_id not in obs._active_runs

    def test_end_trace_with_error(self, obs):
        run_id = obs.start_trace("pm_agent", "1.2")
        obs.end_trace(run_id, error="something went wrong")
        metrics = obs.get_metrics()
        assert metrics["total_errors"] == 1

    def test_langsmith_create_run_called(self, obs_with_client):
        manager, mock_client = obs_with_client
        manager.start_trace("agent", "1.1", inputs={"x": 1})
        mock_client.create_run.assert_called_once()

    def test_langsmith_update_run_called(self, obs_with_client):
        manager, mock_client = obs_with_client
        run_id = manager.start_trace("agent", "1.1")
        manager.end_trace(run_id, outputs={"y": 2})
        mock_client.update_run.assert_called_once()

    def test_langsmith_failure_does_not_raise(self):
        # A LangSmith outage must never break tracing for the caller.
        mock_client = MagicMock()
        mock_client.create_run.side_effect = Exception("network error")
        mock_langsmith = MagicMock()
        mock_langsmith.Client.return_value = mock_client
        with patch.dict(sys.modules, {"langsmith": mock_langsmith}):
            manager = ObservabilityManager("test-project")
            # Should not raise
            run_id = manager.start_trace("agent", "1.1")
            assert isinstance(run_id, str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decorator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDecorator:
    """trace_agent_execution wraps sync and async callables transparently."""

    def test_sync_decorator(self, obs):
        @obs.trace_agent_execution("agent", "1.1")
        def doubler(value):
            return value * 2

        assert doubler(5) == 10
        assert obs.get_metrics()["total_traces"] == 1

    def test_async_decorator(self, obs):
        @obs.trace_agent_execution("agent", "1.1")
        async def incrementer(value):
            return value + 1

        assert asyncio.run(incrementer(10)) == 11
        assert obs.get_metrics()["total_traces"] == 1

    def test_decorator_records_error(self, obs):
        @obs.trace_agent_execution("agent", "1.1")
        def exploder():
            raise ValueError("boom")

        with pytest.raises(ValueError, match="boom"):
            exploder()
        assert obs.get_metrics()["total_errors"] == 1

    def test_decorator_preserves_function_name(self, obs):
        # The wrapper must carry over the wrapped function's __name__.
        @obs.trace_agent_execution("agent", "1.1")
        def my_named_func():
            pass

        assert my_named_func.__name__ == "my_named_func"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAsyncContextManager:
    """trace_context finalises the run on both clean and exceptional exit."""

    def test_trace_context_success(self, obs):
        async def exercise():
            async with obs.trace_context("agent", "1.1") as rid:
                assert isinstance(rid, str)
            return rid

        rid = asyncio.run(exercise())
        assert rid not in obs._active_runs
        assert obs.get_metrics()["total_traces"] == 1

    def test_trace_context_exception(self, obs):
        async def exercise():
            with pytest.raises(RuntimeError):
                async with obs.trace_context("agent", "1.1") as rid:
                    raise RuntimeError("test error")
            return rid

        asyncio.run(exercise())
        assert obs.get_metrics()["total_errors"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async trace_agent helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTraceAgent:
    """trace_agent drives the supplied coroutine factory and records a trace."""

    def test_trace_agent_success(self, obs):
        async def drive():
            async def payload():
                return 42
            return await obs.trace_agent("agent", "1.1", payload)

        assert asyncio.run(drive()) == 42
        assert obs.get_metrics()["total_traces"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLoggingFormat:
    """_StructuredFormatter output: field layout, fallbacks, and ANSI colour."""

    def test_structured_formatter(self):
        formatter = _StructuredFormatter(use_color=False)
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="hello world",
            args=(),
            exc_info=None,
        )
        # Extra attributes the formatter renders as bracketed fields.
        record.agent_name = "PM_AGENT"
        record.task_id = "task-1.2"
        formatted = formatter.format(record)
        assert "[PM_AGENT]" in formatted
        assert "[task-1.2]" in formatted
        assert "[INFO]" in formatted
        assert "hello world" in formatted
        # Check ISO timestamp pattern
        assert formatted.startswith("[20")
        # No ANSI escape sequences when colour is disabled.
        assert "\x1b[" not in formatted

    def test_formatter_defaults(self):
        # A record without agent/task attributes falls back to placeholders.
        formatter = _StructuredFormatter(use_color=False)
        record = logging.LogRecord(
            name="test",
            level=logging.WARNING,
            pathname="",
            lineno=0,
            msg="warn msg",
            args=(),
            exc_info=None,
        )
        formatted = formatter.format(record)
        assert "[SYSTEM]" in formatted
        assert "[-]" in formatted

    def test_formatter_can_render_colors(self):
        formatter = _StructuredFormatter(use_color=True)
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="Trace started: run_id=abc123",
            args=(),
            exc_info=None,
        )
        record.agent_name = "PM_AGENT"
        record.task_id = "task-1.2"
        formatted = formatter.format(record)
        # Colour mode must emit at least one ANSI escape sequence.
        assert "\x1b[" in formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging methods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLoggingMethods:
    """Logging helpers: state transitions, token usage, errors, and the
    filtered/verbose rendering of Claude SDK events."""

    class _CaptureHandler(logging.Handler):
        # Minimal handler that records raw messages for assertions.
        def __init__(self):
            super().__init__()
            self.messages: list[str] = []

        def emit(self, record):
            self.messages.append(record.getMessage())

    def test_log_state_transition(self, obs):
        capture = self._CaptureHandler()
        obs.logger.addHandler(capture)
        obs.log_state_transition("idle", "processing", {"reason": "new task"})
        obs.logger.removeHandler(capture)
        assert any("idle -> processing" in m for m in capture.messages)

    def test_log_token_usage(self, obs):
        obs.log_token_usage("pm_agent", "1.1", input_tokens=100, output_tokens=50, model="claude-3")
        metrics = obs.get_metrics()
        assert metrics["total_tokens"] == 150
        assert metrics["per_agent"]["pm_agent"]["tokens"] == 150

    def test_log_error(self, obs):
        try:
            raise ValueError("test error")
        except ValueError as e:
            obs.log_error("pm_agent", "1.1", e, context={"step": "parse"})
        metrics = obs.get_metrics()
        assert metrics["total_errors"] == 1
        assert metrics["per_agent"]["pm_agent"]["errors"] == 1

    def test_log_claude_event(self, obs):
        obs.log_claude_event("pm_agent", "1.1", "tool_use", {"tool_name": "Bash"})
        obs.log_claude_event("pm_agent", "1.1", "result_message", {"subtype": "success"})
        metrics = obs.get_metrics()
        # Both events count; only the tool_use counts as a tool call.
        assert metrics["total_claude_events"] == 2
        assert metrics["total_tool_calls"] == 1
        assert metrics["per_agent"]["pm_agent"]["claude_events"] == 2
        assert metrics["per_agent"]["pm_agent"]["tool_calls"] == 1

    def test_log_claude_event_readable_format(self, obs):
        capture = self._CaptureHandler()
        obs.logger.addHandler(capture)
        obs.log_claude_event(
            "pm_agent",
            "1.1",
            "tool_use",
            {"tool_name": "Bash", "tool_use_id": "toolu_1", "tool_input": {"command": "pwd"}},
        )
        obs.logger.removeHandler(capture)
        combined = "\n".join(capture.messages)
        assert "Claude tool call: Bash" in combined
        assert "command=pwd" in combined

    def test_log_claude_event_noisy_tool_result_is_suppressed_on_success(self, obs):
        capture = self._CaptureHandler()
        obs.logger.addHandler(capture)
        obs.log_claude_event(
            "pm_agent",
            "1.1",
            "tool_use",
            {"tool_name": "Read", "tool_use_id": "toolu_1", "tool_input": {"file_path": "README.md"}},
        )
        # Large successful result payload for the same tool_use_id.
        obs.log_claude_event(
            "pm_agent",
            "1.1",
            "tool_result",
            {"tool_use_id": "toolu_1", "is_error": False, "content": "x" * 1000},
        )
        obs.logger.removeHandler(capture)
        combined = "\n".join(capture.messages)
        # The call line is logged; the successful result body is not.
        assert "Claude tool call: Read path=README.md" in combined
        assert "Claude tool result: Read" not in combined

    def test_log_claude_event_noisy_tool_result_logs_errors_with_context(self, obs):
        capture = self._CaptureHandler()
        obs.logger.addHandler(capture)
        obs.log_claude_event(
            "pm_agent",
            "1.1",
            "tool_use",
            {"tool_name": "Read", "tool_use_id": "toolu_1", "tool_input": {"file_path": "README.md"}},
        )
        obs.log_claude_event(
            "pm_agent",
            "1.1",
            "tool_result",
            {"tool_use_id": "toolu_1", "is_error": True, "content": "permission denied"},
        )
        obs.logger.removeHandler(capture)
        combined = "\n".join(capture.messages)
        # Failed results are logged with the originating call's context.
        assert "Claude tool result: Read status=error path=README.md error=permission denied" in combined

    def test_log_claude_event_filters_noise_by_default(self, obs):
        capture = self._CaptureHandler()
        obs.logger.addHandler(capture)
        obs.log_claude_event("pm_agent", "1.1", "stream_event", {"event": "delta"})
        obs.logger.removeHandler(capture)
        combined = "\n".join(capture.messages)
        # stream_event is considered noise in the default mode.
        assert "stream_event" not in combined

    def test_log_claude_event_verbose_mode_includes_noise(self):
        with patch.dict(sys.modules, {"langsmith": None}):
            manager = ObservabilityManager("test-project", claude_event_mode="verbose")

        capture = self._CaptureHandler()
        manager.logger.addHandler(capture)
        manager.log_claude_event("pm_agent", "1.1", "stream_event", {"event": "delta"})
        manager.logger.removeHandler(capture)
        combined = "\n".join(capture.messages)
        assert "Claude event: type=stream_event" in combined
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMetrics:
    """get_metrics aggregates totals and per-agent counters."""

    def test_initial_metrics(self, obs):
        snapshot = obs.get_metrics()
        for counter in (
            "total_tokens",
            "total_traces",
            "total_errors",
            "total_claude_events",
            "total_tool_calls",
        ):
            assert snapshot[counter] == 0

    def test_metrics_accumulate(self, obs):
        obs.start_trace("a1", "1.1")
        obs.start_trace("a1", "1.2")
        obs.start_trace("a2", "2.1")
        obs.log_token_usage("a1", "1.1", 10, 20)
        obs.log_token_usage("a2", "2.1", 5, 5)

        snapshot = obs.get_metrics()
        assert snapshot["total_traces"] == 3
        assert snapshot["total_tokens"] == 40
        per_agent = snapshot["per_agent"]
        assert per_agent["a1"]["traces"] == 2
        assert per_agent["a2"]["traces"] == 1
        assert per_agent["a1"]["tokens"] == 30
        assert per_agent["a2"]["tokens"] == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graceful degradation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGracefulDegradation:
    """With LangSmith unavailable, every public method must keep working."""

    def test_all_operations_work_without_langsmith(self, obs):
        """Exercise the full public surface with _client=None."""
        assert obs._client is None

        rid = obs.start_trace("agent", "1.1", inputs={"a": 1})
        obs.end_trace(rid, outputs={"b": 2})
        # Ending an unknown run id must not raise.
        obs.end_trace("nonexistent-id", error="whatever")
        obs.log_state_transition("a", "b")
        obs.log_token_usage("agent", "1.1", 10, 20)
        try:
            raise RuntimeError("test")
        except RuntimeError as caught:
            obs.log_error("agent", "1.1", caught)
        stats = obs.get_metrics()
        assert stats["total_traces"] == 1
|
||||
257
tests/test_pm_agent.py
Normal file
257
tests/test_pm_agent.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Tests for PMAgent."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory.agents.pm_agent import PMAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_api_response(text, input_tokens=10, output_tokens=20):
|
||||
"""Build a fake Claude SDK completion response."""
|
||||
return SimpleNamespace(text=text, input_tokens=input_tokens, output_tokens=output_tokens)
|
||||
|
||||
|
||||
def _build_agent(**kwargs):
    """Construct a PMAgent whose Claude SDK client is an AsyncMock.

    Returns the (agent, mock_client) pair so tests can stub completions.
    """
    fake_client = AsyncMock()
    with patch("app_factory.agents.pm_agent.ClaudeSDKClient") as patched_cls:
        patched_cls.return_value = fake_client
        agent = PMAgent(api_key="test-key", **kwargs)
        agent.client = fake_client
        return agent, fake_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitialization:
    """PMAgent construction: client wiring, default and custom models."""

    def test_no_api_key_uses_default_client(self):
        # One clearing patch.dict is enough: the original nested a second
        # patch over a copy of the already-empty environment and popped a
        # key that could not be present. Clearing once guarantees
        # ANTHROPIC_API_KEY cannot leak in from the host.
        with patch.dict(os.environ, {}, clear=True), \
                patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_mod:
            mock_mod.return_value = AsyncMock()
            PMAgent()
            # With no key anywhere, the SDK client is built without creds.
            mock_mod.assert_called_once_with(
                api_key=None,
                auth_token=None,
                enable_debug=False,
            )

    def test_api_key_from_param(self):
        agent, _ = _build_agent()
        # Default model when none is supplied.
        assert agent.model == "claude-opus-4-6"

    def test_custom_model(self):
        # Use a value distinct from the default so this test can actually
        # detect a PMAgent that ignores the `model` argument (the original
        # passed the default value, making the assertion vacuous).
        agent, _ = _build_agent(model="claude-sonnet-4-5")
        assert agent.model == "claude-sonnet-4-5"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTemplateLoading:
    """_load_template reads prompt templates and fails loudly when missing."""

    def test_load_prd_template(self):
        agent, _ = _build_agent()
        content = agent._load_template("pm_prd_expansion.txt")
        for marker in ("Product Manager", "Objective"):
            assert marker in content

    def test_load_clarification_template(self):
        agent, _ = _build_agent()
        content = agent._load_template("pm_clarification.txt")
        for marker in ("{requesting_agent}", "ESCALATE_TO_HUMAN"):
            assert marker in content

    def test_load_missing_template_raises(self):
        agent, _ = _build_agent()
        with pytest.raises(FileNotFoundError):
            agent._load_template("nonexistent.txt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PRD expansion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExpandPromptToPrd:
    """expand_prompt_to_prd: returned PRD sections and token accounting."""

    @pytest.mark.asyncio
    async def test_returns_prd_markdown(self):
        agent, mock_client = _build_agent()
        prd_text = (
            "# Objective\nBuild a todo app\n"
            "# Core Requirements\n1. Add tasks\n"
            "# Technical Architecture\nMonolith\n"
            "# Tech Stack\nPython, FastAPI\n"
            "# Success Criteria\nAll tests pass\n"
            "# Non-Functional Requirements\n<1s response"
        )
        mock_client.complete = AsyncMock(
            return_value=_make_api_response(prd_text, 15, 100)
        )

        result = await agent.expand_prompt_to_prd("Build a todo app")

        # All six canonical PRD section titles must survive into the result.
        assert "Objective" in result
        assert "Core Requirements" in result
        assert "Technical Architecture" in result
        assert "Tech Stack" in result
        assert "Success Criteria" in result
        assert "Non-Functional Requirements" in result
        mock_client.complete.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_tracks_token_usage(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("prd content", 50, 200)
        )

        await agent.expand_prompt_to_prd("some input")

        # Token counts from the mocked response are accumulated on the agent.
        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 50
        assert usage["output_tokens"] == 200
        assert usage["total_tokens"] == 250
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Clarification handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHandleClarification:
    """handle_clarification_request: auto-resolution, human escalation, and
    token accounting."""

    @pytest.mark.asyncio
    async def test_auto_resolve(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("Use PostgreSQL for the database.")
        )

        result = await agent.handle_clarification_request({
            "requesting_agent": "dev_agent",
            "task_id": "2.1",
            "question": "Which database should I use?",
            "context": "PRD says relational DB",
        })

        # The model answer is returned verbatim to the requesting agent.
        assert result == "Use PostgreSQL for the database."

    @pytest.mark.asyncio
    async def test_escalate_to_human(self):
        agent, mock_client = _build_agent()
        # The ESCALATE_TO_HUMAN sentinel routes the question to input().
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("ESCALATE_TO_HUMAN")
        )

        with patch("builtins.input", return_value="Use MySQL") as mock_input:
            result = await agent.handle_clarification_request({
                "requesting_agent": "dev_agent",
                "task_id": "3.1",
                "question": "Which vendor should we pick?",
                "context": "",
            })

        # The human's reply becomes the clarification answer.
        assert result == "Use MySQL"
        mock_input.assert_called_once()

    @pytest.mark.asyncio
    async def test_tracks_tokens(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("answer", 5, 10)
        )

        await agent.handle_clarification_request({
            "requesting_agent": "qa",
            "task_id": "1.1",
            "question": "q",
            "context": "c",
        })

        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 5
        assert usage["output_tokens"] == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PRD updates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdatePrd:
    """update_prd: appended, versioned PRD history on disk."""

    def test_appends_with_version_header(self):
        agent, _ = _build_agent()

        with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as handle:
            handle.write("# Original PRD\nSome content\n")
            prd_path = handle.name

        try:
            agent.update_prd(prd_path, "Added authentication requirement.")

            with open(prd_path) as handle:
                updated = handle.read()

            # Original body is preserved and the update is appended under a
            # dated "## PRD Update -" header.
            assert "# Original PRD" in updated
            assert "## PRD Update -" in updated
            assert "Added authentication requirement." in updated
        finally:
            os.unlink(prd_path)

    def test_multiple_updates_maintain_history(self):
        agent, _ = _build_agent()

        with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as handle:
            handle.write("# PRD v1\n")
            prd_path = handle.name

        try:
            agent.update_prd(prd_path, "Update 1")
            agent.update_prd(prd_path, "Update 2")

            with open(prd_path) as handle:
                updated = handle.read()

            # Each call adds its own header; earlier updates are never lost.
            assert updated.count("## PRD Update -") == 2
            assert "Update 1" in updated
            assert "Update 2" in updated
        finally:
            os.unlink(prd_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenUsage:
    """Cumulative token accounting on the agent."""

    def test_initial_usage_zero(self):
        # A fresh agent has made no API calls yet.
        agent, _ = _build_agent()
        usage = agent.get_token_usage()
        assert usage == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    @pytest.mark.asyncio
    async def test_accumulates_across_calls(self):
        # Usage accumulates across successive calls; total is their sum.
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            side_effect=[
                _make_api_response("r1", 10, 20),
                _make_api_response("r2", 30, 40),
            ]
        )

        await agent.expand_prompt_to_prd("first call")
        await agent.expand_prompt_to_prd("second call")

        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 40
        assert usage["output_tokens"] == 60
        assert usage["total_tokens"] == 100
|
||||
550
tests/test_qa_agent.py
Normal file
550
tests/test_qa_agent.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""Tests for QAAgent."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import git as gitmod
|
||||
import pytest
|
||||
|
||||
from app_factory.agents.qa_agent import QAAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_review_response(text, input_tokens=10, output_tokens=20):
|
||||
"""Build a fake Claude SDK completion response."""
|
||||
return SimpleNamespace(text=text, input_tokens=input_tokens, output_tokens=output_tokens)
|
||||
|
||||
|
||||
def _build_agent(repo_path="/fake/repo", **kwargs):
    """Create a QAAgent with mocked git.Repo and Claude SDK client.

    Returns (agent, mock_repo, mock_client) so tests can assert on the mocks.
    """
    # Patch both collaborators for the duration of construction so QAAgent
    # never touches a real repository or the network.
    with patch("app_factory.agents.qa_agent.git.Repo") as mock_repo_cls, \
        patch("app_factory.agents.qa_agent.ClaudeSDKClient") as mock_sdk_client:
        mock_repo = MagicMock()
        mock_repo_cls.return_value = mock_repo
        mock_client = AsyncMock()
        mock_sdk_client.return_value = mock_client
        agent = QAAgent(repo_path=repo_path, api_key="test-key", **kwargs)
        # Re-attach the mocks explicitly so they remain in place after the
        # patches above have been exited.
        agent.client = mock_client
        agent.repo = mock_repo
        return agent, mock_repo, mock_client
|
||||
|
||||
|
||||
# Canned review texts in the line-oriented format _parse_review_response
# expects: an APPROVED flag, an ISSUES list with "[severity: ...]" tags,
# and a trailing SUMMARY line.
APPROVED_REVIEW = """\
APPROVED: true
ISSUES:
- [severity: info] Minor style suggestion
SUMMARY: Code looks good overall."""

REJECTED_REVIEW = """\
APPROVED: false
ISSUES:
- [severity: critical] SQL injection in query builder
- [severity: warning] Missing input validation
SUMMARY: Critical security issue found."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitialization:
    """Construction-time behaviour of QAAgent."""

    def test_no_api_key_uses_default_client(self):
        # With a fully cleared environment (hence no ANTHROPIC_API_KEY), the
        # SDK client must be constructed with explicit None credentials.
        # NOTE: the original test nested a second patch.dict over a copy of
        # the already-cleared environment — a no-op, removed here.
        with patch("app_factory.agents.qa_agent.git.Repo"), \
                patch("app_factory.agents.qa_agent.ClaudeSDKClient") as mock_sdk_client, \
                patch.dict(os.environ, {}, clear=True):
            mock_sdk_client.return_value = AsyncMock()
            QAAgent(repo_path="/fake")
            mock_sdk_client.assert_called_once_with(
                api_key=None,
                auth_token=None,
                enable_debug=False,
            )

    def test_creates_with_api_key(self):
        agent, _, _ = _build_agent()
        assert agent.max_retries == 3  # default retry budget

    def test_custom_max_retries(self):
        agent, _, _ = _build_agent(max_retries=5)
        assert agent.max_retries == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rebase
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRebaseOntoMain:
    """rebase_onto_main: rebasing a task worktree onto the main branch."""

    @pytest.mark.asyncio
    async def test_rebase_success(self):
        # Clean rebase: success, no conflicts, `git rebase main` issued once.
        agent, _, _ = _build_agent()
        mock_wt_repo = MagicMock()
        with patch("app_factory.agents.qa_agent.git.Repo", return_value=mock_wt_repo):
            result = await agent.rebase_onto_main("/worktree/path", "task-1")

        assert result["success"] is True
        assert result["conflicts"] == []
        mock_wt_repo.git.rebase.assert_called_once_with("main")

    @pytest.mark.asyncio
    async def test_rebase_conflict_unresolvable(self):
        # A conflicting rebase that auto-resolution cannot fix reports the
        # conflicted files (taken from the "UU" entries of `git status`).
        agent, _, _ = _build_agent()
        mock_wt_repo = MagicMock()
        mock_wt_repo.git.rebase.side_effect = gitmod.GitCommandError("rebase", "CONFLICT")
        mock_wt_repo.git.status.return_value = "UU conflicted_file.py"

        with patch("app_factory.agents.qa_agent.git.Repo", return_value=mock_wt_repo), \
            patch.object(agent, "auto_resolve_conflicts", return_value=False):
            result = await agent.rebase_onto_main("/worktree/path", "task-1")

        assert result["success"] is False
        assert "conflicted_file.py" in result["conflicts"]

    @pytest.mark.asyncio
    async def test_rebase_conflict_auto_resolved(self):
        # If auto_resolve_conflicts succeeds, the rebase counts as successful.
        agent, _, _ = _build_agent()
        mock_wt_repo = MagicMock()
        mock_wt_repo.git.rebase.side_effect = gitmod.GitCommandError("rebase", "CONFLICT")
        mock_wt_repo.git.status.return_value = "UU file.py"

        with patch("app_factory.agents.qa_agent.git.Repo", return_value=mock_wt_repo), \
            patch.object(agent, "auto_resolve_conflicts", return_value=True):
            result = await agent.rebase_onto_main("/worktree/path", "task-1")

        assert result["success"] is True

    @pytest.mark.asyncio
    async def test_fetch_failure_continues(self):
        """If fetch fails (no remote), rebase should still be attempted."""
        agent, _, _ = _build_agent()
        mock_wt_repo = MagicMock()
        mock_wt_repo.git.fetch.side_effect = gitmod.GitCommandError("fetch", "No remote")

        with patch("app_factory.agents.qa_agent.git.Repo", return_value=mock_wt_repo):
            result = await agent.rebase_onto_main("/worktree/path", "task-1")

        assert result["success"] is True
        mock_wt_repo.git.rebase.assert_called_once_with("main")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Linter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunLinter:
    """run_linter: wrapping `ruff check` and normalizing its outcome."""

    def test_lint_passes(self):
        agent, _, _ = _build_agent()
        mock_result = subprocess.CompletedProcess(
            args=["ruff", "check", "."],
            returncode=0,
            stdout="All checks passed!\n",
            stderr="",
        )
        with patch("app_factory.agents.qa_agent.subprocess.run", return_value=mock_result):
            result = agent.run_linter("/worktree/path")

        assert result["passed"] is True
        assert result["errors"] == []

    def test_lint_fails_with_errors(self):
        # Two ruff findings yield two entries in result["errors"]; the
        # trailing "Found 2 errors." summary line is not counted as one.
        agent, _, _ = _build_agent()
        ruff_output = (
            "app/main.py:10:1: E501 Line too long (120 > 88 characters)\n"
            "app/main.py:15:5: F841 Local variable 'x' is assigned but never used\n"
            "Found 2 errors.\n"
        )
        mock_result = subprocess.CompletedProcess(
            args=["ruff", "check", "."],
            returncode=1,
            stdout=ruff_output,
            stderr="",
        )
        with patch("app_factory.agents.qa_agent.subprocess.run", return_value=mock_result):
            result = agent.run_linter("/worktree/path")

        assert result["passed"] is False
        assert len(result["errors"]) == 2

    def test_lint_ruff_not_found(self):
        # A missing ruff binary is tolerated: lint passes with a warning.
        agent, _, _ = _build_agent()
        with patch("app_factory.agents.qa_agent.subprocess.run", side_effect=FileNotFoundError):
            result = agent.run_linter("/worktree/path")

        assert result["passed"] is True
        assert "ruff not found" in result["warnings"][0]

    def test_lint_timeout(self):
        # A hung linter is treated as a failure, not silently skipped.
        agent, _, _ = _build_agent()
        with patch("app_factory.agents.qa_agent.subprocess.run",
                   side_effect=subprocess.TimeoutExpired(cmd="ruff", timeout=120)):
            result = agent.run_linter("/worktree/path")

        assert result["passed"] is False
        assert "timed out" in result["errors"][0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunTests:
    """run_tests: wrapping pytest and parsing its summary line."""

    def test_all_tests_pass(self):
        agent, _, _ = _build_agent()
        pytest_output = (
            "tests/test_foo.py::test_one PASSED\n"
            "tests/test_foo.py::test_two PASSED\n"
            "========================= 2 passed =========================\n"
        )
        mock_result = subprocess.CompletedProcess(
            args=["python", "-m", "pytest"],
            returncode=0,
            stdout=pytest_output,
            stderr="",
        )
        with patch("app_factory.agents.qa_agent.subprocess.run", return_value=mock_result):
            result = agent.run_tests("/worktree/path")

        assert result["passed"] is True
        assert result["total"] == 2
        assert result["failures"] == 0
        assert result["errors"] == 0

    def test_some_tests_fail(self):
        agent, _, _ = _build_agent()
        pytest_output = (
            "tests/test_foo.py::test_one PASSED\n"
            "tests/test_foo.py::test_two FAILED\n"
            "=================== 1 failed, 1 passed ====================\n"
        )
        mock_result = subprocess.CompletedProcess(
            args=["python", "-m", "pytest"],
            returncode=1,
            stdout=pytest_output,
            stderr="",
        )
        with patch("app_factory.agents.qa_agent.subprocess.run", return_value=mock_result):
            result = agent.run_tests("/worktree/path")

        assert result["passed"] is False
        assert result["total"] == 2
        assert result["failures"] == 1

    def test_pytest_not_found(self):
        # Unlike a missing ruff, a missing pytest is a hard failure.
        agent, _, _ = _build_agent()
        with patch("app_factory.agents.qa_agent.subprocess.run", side_effect=FileNotFoundError):
            result = agent.run_tests("/worktree/path")

        assert result["passed"] is False
        assert "pytest not found" in result["output"]

    def test_pytest_timeout(self):
        agent, _, _ = _build_agent()
        with patch("app_factory.agents.qa_agent.subprocess.run",
                   side_effect=subprocess.TimeoutExpired(cmd="pytest", timeout=300)):
            result = agent.run_tests("/worktree/path")

        assert result["passed"] is False
        assert "timed out" in result["output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code Review
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCodeReview:
    """code_review: LLM-backed diff review and its parsed verdict."""

    @pytest.mark.asyncio
    async def test_review_approved(self):
        agent, _, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_review_response(APPROVED_REVIEW)
        )

        result = await agent.code_review("diff content", task={"id": "1", "title": "Add feature"})

        assert result["approved"] is True
        assert len(result["issues"]) == 1
        assert result["issues"][0]["severity"] == "info"
        assert result["summary"] != ""

    @pytest.mark.asyncio
    async def test_review_rejected(self):
        agent, _, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_review_response(REJECTED_REVIEW)
        )

        result = await agent.code_review("diff with issues")

        assert result["approved"] is False
        assert len(result["issues"]) == 2
        assert result["issues"][0]["severity"] == "critical"
        assert "security" in result["summary"].lower()

    @pytest.mark.asyncio
    async def test_review_no_task_context(self):
        # task=None is allowed; the review still runs exactly once.
        agent, _, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_review_response(APPROVED_REVIEW)
        )

        result = await agent.code_review("diff content", task=None)

        assert result["approved"] is True
        mock_client.complete.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_review_loads_template(self):
        # The review prompt is built from the on-disk template, which must
        # include the checklist and OWASP sections.
        agent, _, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_review_response(APPROVED_REVIEW)
        )

        await agent.code_review("some diff")

        call_args = mock_client.complete.call_args
        prompt_text = call_args.kwargs["prompt"]
        assert "Review Checklist" in prompt_text
        assert "OWASP" in prompt_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Merge to main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMergeToMain:
    """merge_to_main: non-fast-forward merge of the task branch into main."""

    def test_merge_success(self):
        agent, mock_repo, _ = _build_agent()
        mock_repo.head.commit.hexsha = "abc123def456"

        outcome = agent.merge_to_main("/worktree/path", "42")

        assert outcome["success"] is True
        assert outcome["commit_sha"] == "abc123def456"
        # main is checked out first, then the branch is merged with --no-ff.
        mock_repo.git.checkout.assert_called_once_with("main")
        mock_repo.git.merge.assert_called_once_with(
            "--no-ff", "feature/task-42", m="Merge feature/task-42"
        )

    def test_merge_failure(self):
        agent, mock_repo, _ = _build_agent()
        mock_repo.git.merge.side_effect = gitmod.GitCommandError("merge", "conflict")

        outcome = agent.merge_to_main("/worktree/path", "42")

        assert outcome["success"] is False
        assert outcome["commit_sha"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full pipeline: review_and_merge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReviewAndMerge:
    """review_and_merge: the full QA pipeline.

    Stage order: rebase -> lint -> tests -> LLM review -> merge. Each
    failure test patches the preceding stages to succeed and makes exactly
    one stage fail, checking that the pipeline short-circuits there.
    """

    @pytest.mark.asyncio
    async def test_happy_path(self):
        agent, mock_repo, _ = _build_agent()
        mock_repo.head.commit.hexsha = "merged123"

        mock_wt_repo = MagicMock()
        mock_wt_repo.git.diff.return_value = "diff --git a/file.py"

        with patch.object(agent, "rebase_onto_main", new_callable=AsyncMock,
                          return_value={"success": True, "conflicts": []}), \
                patch.object(agent, "run_linter",
                             return_value={"passed": True, "errors": [], "warnings": []}), \
                patch.object(agent, "run_tests",
                             return_value={"passed": True, "total": 5, "failures": 0, "errors": 0, "output": "ok"}), \
                patch.object(agent, "code_review", new_callable=AsyncMock,
                             return_value={"approved": True, "issues": [], "summary": "All good"}), \
                patch.object(agent, "merge_to_main",
                             return_value={"success": True, "commit_sha": "merged123"}), \
                patch("app_factory.agents.qa_agent.git.Repo", return_value=mock_wt_repo):

            result = await agent.review_and_merge("task-1", "/worktree/path")

        assert result["status"] == "merged"
        assert result["commit_sha"] == "merged123"

    @pytest.mark.asyncio
    async def test_rebase_failure(self):
        # A failed rebase short-circuits the pipeline and bumps the retry count.
        agent, _, _ = _build_agent()

        with patch.object(agent, "rebase_onto_main", new_callable=AsyncMock,
                          return_value={"success": False, "conflicts": ["file.py"]}):
            result = await agent.review_and_merge("task-2", "/worktree/path")

        assert result["status"] == "rebase_failed"
        assert "file.py" in result["conflicts"]
        assert result["retry_count"] == 1

    @pytest.mark.asyncio
    async def test_lint_failure(self):
        agent, _, _ = _build_agent()

        with patch.object(agent, "rebase_onto_main", new_callable=AsyncMock,
                          return_value={"success": True, "conflicts": []}), \
                patch.object(agent, "run_linter",
                             return_value={"passed": False, "errors": ["E501 line too long"], "warnings": []}):
            result = await agent.review_and_merge("task-3", "/worktree/path")

        assert result["status"] == "lint_failed"
        assert len(result["errors"]) == 1

    @pytest.mark.asyncio
    async def test_test_failure(self):
        # The pipeline stops before the diff/review stages, so no worktree
        # repo mock is needed here (the original built one and never used it).
        agent, _, _ = _build_agent()

        with patch.object(agent, "rebase_onto_main", new_callable=AsyncMock,
                          return_value={"success": True, "conflicts": []}), \
                patch.object(agent, "run_linter",
                             return_value={"passed": True, "errors": [], "warnings": []}), \
                patch.object(agent, "run_tests",
                             return_value={"passed": False, "total": 3, "failures": 1, "errors": 0, "output": "FAILED"}):
            result = await agent.review_and_merge("task-4", "/worktree/path")

        assert result["status"] == "tests_failed"
        assert result["failures"] == 1

    @pytest.mark.asyncio
    async def test_review_failure(self):
        agent, _, _ = _build_agent()
        mock_wt_repo = MagicMock()
        mock_wt_repo.git.diff.return_value = "diff"

        with patch.object(agent, "rebase_onto_main", new_callable=AsyncMock,
                          return_value={"success": True, "conflicts": []}), \
                patch.object(agent, "run_linter",
                             return_value={"passed": True, "errors": [], "warnings": []}), \
                patch.object(agent, "run_tests",
                             return_value={"passed": True, "total": 3, "failures": 0, "errors": 0, "output": "ok"}), \
                patch.object(agent, "code_review", new_callable=AsyncMock,
                             return_value={"approved": False, "issues": [{"severity": "critical", "description": "vuln"}], "summary": "Bad"}), \
                patch("app_factory.agents.qa_agent.git.Repo", return_value=mock_wt_repo):
            result = await agent.review_and_merge("task-5", "/worktree/path")

        assert result["status"] == "review_failed"
        assert len(result["issues"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse test results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseTestResults:
    """parse_test_results: extracting counts from a pytest summary line."""

    def test_all_passed(self):
        agent, _, _ = _build_agent()
        summary = "========================= 5 passed =========================\n"
        parsed = agent.parse_test_results(summary)
        assert parsed["passed"] is True
        assert (parsed["total"], parsed["failures"], parsed["errors"]) == (5, 0, 0)

    def test_mixed_results(self):
        agent, _, _ = _build_agent()
        summary = "================ 1 failed, 4 passed, 1 error ================\n"
        parsed = agent.parse_test_results(summary)
        assert parsed["passed"] is False
        # total counts every category: 1 failed + 4 passed + 1 error.
        assert parsed["total"] == 6
        assert parsed["failures"] == 1
        assert parsed["errors"] == 1

    def test_all_failed(self):
        agent, _, _ = _build_agent()
        summary = "========================= 3 failed =========================\n"
        parsed = agent.parse_test_results(summary)
        assert parsed["passed"] is False
        assert parsed["total"] == 3
        assert parsed["failures"] == 3

    def test_no_tests(self):
        agent, _, _ = _build_agent()
        parsed = agent.parse_test_results("no tests ran\n")
        assert parsed["passed"] is False
        assert parsed["total"] == 0

    def test_errors_only(self):
        agent, _, _ = _build_agent()
        summary = "========================= 2 error =========================\n"
        parsed = agent.parse_test_results(summary)
        assert parsed["passed"] is False
        assert parsed["errors"] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Retry counter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRetryCounter:
    """Per-task retry bookkeeping on the agent."""

    def test_initial_count_zero(self):
        agent, _, _ = _build_agent()
        assert agent.get_retry_count("task-1") == 0

    def test_increment_and_get(self):
        agent, _, _ = _build_agent()
        for expected in (1, 2):
            agent._increment_retry("task-1")
            assert agent.get_retry_count("task-1") == expected

    def test_separate_task_counters(self):
        agent, _, _ = _build_agent()
        agent._increment_retry("task-1")
        agent._increment_retry("task-1")
        agent._increment_retry("task-2")
        assert agent.get_retry_count("task-1") == 2
        assert agent.get_retry_count("task-2") == 1

    @pytest.mark.asyncio
    async def test_pipeline_failure_increments_counter(self):
        # Every failed pipeline run for a task bumps its retry counter.
        agent, _, _ = _build_agent()

        for expected_count in (1, 2):
            with patch.object(agent, "rebase_onto_main", new_callable=AsyncMock,
                              return_value={"success": False, "conflicts": ["a.py"]}):
                await agent.review_and_merge("task-99", "/wt")
            assert agent.get_retry_count("task-99") == expected_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Review response parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseReviewResponse:
    """_parse_review_response: structured parsing of the review text format."""

    def test_approved_with_info(self):
        agent, _, _ = _build_agent()
        parsed = agent._parse_review_response(APPROVED_REVIEW)
        assert parsed["approved"] is True
        issues = parsed["issues"]
        assert len(issues) == 1
        assert issues[0]["severity"] == "info"
        assert "good" in parsed["summary"].lower()

    def test_rejected_with_critical(self):
        agent, _, _ = _build_agent()
        parsed = agent._parse_review_response(REJECTED_REVIEW)
        assert parsed["approved"] is False
        # Issue order and severities follow the ISSUES list in the text.
        severities = [issue["severity"] for issue in parsed["issues"]]
        assert severities == ["critical", "warning"]

    def test_empty_response(self):
        # An empty response is conservatively treated as not approved.
        agent, _, _ = _build_agent()
        parsed = agent._parse_review_response("")
        assert parsed["approved"] is False
        assert parsed["issues"] == []
        assert parsed["summary"] == ""
|
||||
273
tests/test_task_agent.py
Normal file
273
tests/test_task_agent.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Tests for TaskMasterAgent."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory.agents.task_agent import TaskMasterAgent
|
||||
|
||||
|
||||
@pytest.fixture
def agent(tmp_path):
    """TaskMasterAgent rooted in a fresh temporary project directory."""
    return TaskMasterAgent(project_root=str(tmp_path))
|
||||
|
||||
|
||||
def _cli_result(data, returncode=0, stderr=""):
|
||||
"""Build a mock subprocess.CompletedProcess returning JSON data."""
|
||||
result = MagicMock(spec=subprocess.CompletedProcess)
|
||||
result.returncode = returncode
|
||||
result.stdout = json.dumps(data) if isinstance(data, dict) else data
|
||||
result.stderr = stderr
|
||||
return result
|
||||
|
||||
|
||||
# --- parse_prd ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_prd_writes_file_and_calls_cli(agent, tmp_path):
    # parse_prd must persist the PRD text under .taskmaster/docs and then
    # invoke the task-master CLI with the requested task count.
    prd_text = "# My PRD\nBuild a thing."
    cli_output = {"tasks": [{"id": 1, "title": "Task 1"}], "count": 1}

    with patch("subprocess.run", return_value=_cli_result(cli_output)) as mock_run:
        result = await agent.parse_prd(prd_text, num_tasks=5)

    prd_path = tmp_path / ".taskmaster" / "docs" / "prd.md"
    assert prd_path.exists()
    assert prd_path.read_text() == prd_text

    mock_run.assert_called_once()
    call_args = mock_run.call_args[0][0]
    assert call_args[0] == "task-master"
    assert "parse-prd" in call_args
    assert "--num-tasks" in call_args
    assert "5" in call_args

    # The CLI's JSON output is returned verbatim.
    assert result == cli_output
|
||||
|
||||
|
||||
# --- get_unblocked_tasks ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_unblocked_tasks_filters_correctly(agent):
    # A task is unblocked iff its status is "pending" AND every dependency
    # is already "done".
    tasks_data = {
        "tasks": [
            {"id": 1, "title": "Done task", "status": "done", "dependencies": []},
            {"id": 2, "title": "Pending, no deps", "status": "pending", "dependencies": []},
            {"id": 3, "title": "Pending, dep done", "status": "pending", "dependencies": [1]},
            {"id": 4, "title": "Pending, dep not done", "status": "pending", "dependencies": [5]},
            {"id": 5, "title": "In-progress", "status": "in-progress", "dependencies": []},
        ]
    }

    with patch("subprocess.run", return_value=_cli_result(tasks_data)):
        unblocked = await agent.get_unblocked_tasks()

    ids = [t["id"] for t in unblocked]
    assert 2 in ids
    assert 3 in ids
    assert 4 not in ids  # dependency 5 is not done
    assert 1 not in ids  # status is done, not pending
    assert 5 not in ids  # status is in-progress
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_unblocked_tasks_empty(agent):
    """An empty task list yields no unblocked tasks."""
    with patch("subprocess.run", return_value=_cli_result({"tasks": []})):
        result = await agent.get_unblocked_tasks()

    assert result == []
|
||||
|
||||
|
||||
# --- update_task_status ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_task_status_without_notes(agent):
    # Without notes only one CLI call is made: set-status with id and status.
    with patch("subprocess.run", return_value=_cli_result({})) as mock_run:
        await agent.update_task_status("3", "done")

    assert mock_run.call_count == 1
    call_args = mock_run.call_args[0][0]
    assert "--id=3" in call_args
    assert "--status=done" in call_args
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_task_status_with_notes(agent):
    # Notes trigger a second CLI call: set-status first, then update-subtask
    # carrying the notes text as the prompt.
    with patch("subprocess.run", return_value=_cli_result({})) as mock_run:
        await agent.update_task_status("3", "in-progress", notes="Started work")

    assert mock_run.call_count == 2
    first_call = mock_run.call_args_list[0][0][0]
    second_call = mock_run.call_args_list[1][0][0]
    assert "set-status" in first_call
    assert "update-subtask" in second_call
    assert "--prompt=Started work" in second_call
|
||||
|
||||
|
||||
# --- get_task_details ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_task_details_returns_correct_structure(agent):
    # Every field of the CLI's "task" object must be passed through untouched.
    task_data = {
        "task": {
            "id": 2,
            "title": "Auth system",
            "description": "Implement JWT auth",
            "details": "Use bcrypt for hashing",
            "testStrategy": "Unit tests for auth",
            "dependencies": [1],
            "subtasks": [{"id": "2.1", "title": "Setup JWT"}],
            "status": "pending",
            "priority": "high",
        }
    }

    with patch("subprocess.run", return_value=_cli_result(task_data)):
        result = await agent.get_task_details("2")

    assert result["id"] == 2
    assert result["title"] == "Auth system"
    assert result["description"] == "Implement JWT auth"
    assert result["details"] == "Use bcrypt for hashing"
    assert result["testStrategy"] == "Unit tests for auth"
    assert result["dependencies"] == [1]
    assert len(result["subtasks"]) == 1
    assert result["status"] == "pending"
    assert result["priority"] == "high"
|
||||
|
||||
|
||||
# --- get_next_task ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_next_task_uses_cli(agent):
    """A successful `next` CLI call returns the task it reports."""
    payload = {"task": {"id": 3, "title": "Next task", "status": "pending"}}

    with patch("subprocess.run", return_value=_cli_result(payload)):
        task = await agent.get_next_task()

    assert task["id"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_next_task_fallback_when_cli_fails(agent):
    # When `task-master next` fails, the agent falls back to the local task
    # list and picks the highest-priority unblocked pending task.
    # (The original tracked a `call_count` nonlocal that was never asserted;
    # that dead bookkeeping is removed here.)
    tasks_data = {
        "tasks": [
            {"id": 1, "title": "Done", "status": "done", "dependencies": []},
            {"id": 2, "title": "Low", "status": "pending", "dependencies": [], "priority": "low"},
            {"id": 3, "title": "High", "status": "pending", "dependencies": [], "priority": "high"},
        ]
    }

    fail_result = _cli_result("", returncode=1, stderr="error")

    def side_effect(*args, **kwargs):
        # Fail only the `next` subcommand; let the list command succeed.
        cmd = args[0]
        if "next" in cmd:
            return fail_result
        return _cli_result(tasks_data)

    with patch("subprocess.run", side_effect=side_effect):
        result = await agent.get_next_task()

    assert result["id"] == 3  # high priority comes first
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_next_task_returns_none_when_all_done(agent):
    """Fallback path: nothing pending means there is no next task."""
    tasks_data = {
        "tasks": [
            {"id": 1, "title": "Done", "status": "done", "dependencies": []},
        ]
    }
    failing = _cli_result("", returncode=1, stderr="error")

    def fake_run(*args, **kwargs):
        # `next` fails, forcing the fallback onto the full task list.
        return failing if "next" in args[0] else _cli_result(tasks_data)

    with patch("subprocess.run", side_effect=fake_run):
        assert await agent.get_next_task() is None
|
||||
|
||||
|
||||
# --- expand_task ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_expand_task(agent):
    """expand_task forwards id/num/--force to the CLI and returns its JSON."""
    expand_data = {"subtasks": [{"id": "1.1"}, {"id": "1.2"}]}

    with patch("subprocess.run", return_value=_cli_result(expand_data)) as mock_run:
        result = await agent.expand_task("1", num_subtasks=2)

    cli_cmd = mock_run.call_args[0][0]
    for flag in ("--id=1", "--num=2", "--force"):
        assert flag in cli_cmd
    assert result == expand_data
|
||||
|
||||
|
||||
# --- retry logic ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retry_succeeds_after_failures(agent):
    """Two transient CLI failures followed by a success should still return data."""
    agent.base_delay = 0.01  # keep the backoff waits negligible

    bad = _cli_result("", returncode=1, stderr="transient error")
    good = _cli_result({"tasks": []})

    with patch("subprocess.run", side_effect=[bad, bad, good]):
        tasks = await agent.get_unblocked_tasks()

    assert tasks == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retry_exhausted_raises(agent):
    """When every attempt fails, a RuntimeError reports the exhausted retries."""
    agent.base_delay = 0.01

    always_fail = _cli_result("", returncode=1, stderr="persistent error")

    with patch("subprocess.run", return_value=always_fail):
        with pytest.raises(RuntimeError, match="All 3 attempts failed"):
            await agent.get_unblocked_tasks()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retry_exponential_backoff(agent):
    """Backoff delays should double on each retry: base*2^0, base*2^1, ..."""
    agent.base_delay = 0.01

    bad = _cli_result("", returncode=1, stderr="error")
    good = _cli_result({"tasks": []})

    recorded = []

    async def instant_sleep(duration):
        # Capture the requested delay but return immediately.
        recorded.append(duration)

    with patch("subprocess.run", side_effect=[bad, bad, good]):
        with patch("asyncio.sleep", side_effect=instant_sleep):
            await agent.get_unblocked_tasks()

    assert len(recorded) == 2
    assert recorded == pytest.approx([0.01, 0.02])
|
||||
269
tests/test_workspace.py
Normal file
269
tests/test_workspace.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Tests for WorkspaceManager."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import docker as docker_mod
|
||||
import git
|
||||
import pytest
|
||||
|
||||
from app_factory.core.workspace import (
|
||||
DockerProvisionError,
|
||||
GitWorktreeError,
|
||||
WorkspaceError,
|
||||
WorkspaceManager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_git_repo(tmp_path):
    """Initialize a throwaway git repository with a single commit on `main`."""
    repo_dir = tmp_path / "repo"
    repo_dir.mkdir()
    repo = git.Repo.init(repo_dir)

    (repo_dir / "README.md").write_text("# Test Repo")
    repo.index.add(["README.md"])
    repo.index.commit("Initial commit")

    # Normalize the default branch name across git versions.
    if repo.active_branch.name != "main":
        repo.git.branch("-M", "main")
    return str(repo_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_docker_client():
    """Docker client stand-in whose create() yields a container with a fixed id."""
    container = MagicMock()
    container.id = "abc123container"

    client = MagicMock()
    client.images.pull.return_value = MagicMock()
    client.containers.create.return_value = container
    return client
|
||||
|
||||
|
||||
@pytest.fixture
def workspace_manager(temp_git_repo, mock_docker_client):
    """WorkspaceManager backed by a real git repo but a stubbed Docker daemon."""
    docker_patch = patch(
        "app_factory.core.workspace.docker.from_env",
        return_value=mock_docker_client,
    )
    with docker_patch:
        manager = WorkspaceManager(temp_git_repo)
    return manager
|
||||
|
||||
|
||||
class TestWorkspaceManagerInit:
    """Constructor validation: repo checks, image selection, Docker connectivity."""

    @staticmethod
    def _patched_env(client):
        # Stub out docker.from_env so no daemon is needed during construction.
        return patch("app_factory.core.workspace.docker.from_env", return_value=client)

    def test_init_valid_repo(self, temp_git_repo, mock_docker_client):
        with self._patched_env(mock_docker_client):
            wm = WorkspaceManager(temp_git_repo)
            assert wm.repo_path == Path(temp_git_repo).resolve()
            assert wm.docker_image == "python:3.11-slim"
            assert wm.active_workspaces == {}

    def test_init_custom_image(self, temp_git_repo, mock_docker_client):
        with self._patched_env(mock_docker_client):
            wm = WorkspaceManager(temp_git_repo, docker_image="node:20-slim")
            assert wm.docker_image == "node:20-slim"

    def test_init_invalid_repo(self, tmp_path, mock_docker_client):
        # A plain directory that is not a git repo must be rejected.
        with self._patched_env(mock_docker_client):
            with pytest.raises(GitWorktreeError, match="Invalid git repository"):
                WorkspaceManager(str(tmp_path))

    def test_init_nonexistent_path(self, mock_docker_client):
        with self._patched_env(mock_docker_client):
            with pytest.raises(GitWorktreeError, match="Repository path not found"):
                WorkspaceManager("/nonexistent/path/to/repo")

    def test_init_docker_unavailable(self, temp_git_repo):
        broken_env = patch(
            "app_factory.core.workspace.docker.from_env",
            side_effect=docker_mod.errors.DockerException("Connection refused"),
        )
        with broken_env:
            with pytest.raises(DockerProvisionError, match="Failed to connect to Docker"):
                WorkspaceManager(temp_git_repo)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestCreateWorktree:
    """Worktree creation: happy path plus the guarded failure modes."""

    async def test_create_worktree_success(self, workspace_manager):
        manager = workspace_manager
        worktree_path = await manager.create_worktree("task-001")

        assert os.path.isdir(worktree_path)
        assert "task-001" in worktree_path
        branch_names = [branch.name for branch in manager.repo.branches]
        assert "feature/task-task-001" in branch_names

        # Remove the worktree so the temp repo can be deleted cleanly.
        manager.repo.git.worktree("remove", worktree_path, "--force")

    async def test_create_worktree_invalid_base_branch(self, workspace_manager):
        with pytest.raises(GitWorktreeError, match="does not exist"):
            await workspace_manager.create_worktree("task-002", base_branch="nonexistent")

    async def test_create_worktree_branch_already_exists(self, workspace_manager):
        # Pre-create the branch the manager would want to make.
        workspace_manager.repo.git.branch("feature/task-task-003")
        with pytest.raises(GitWorktreeError, match="Branch already exists"):
            await workspace_manager.create_worktree("task-003")

    async def test_create_worktree_path_already_exists(self, workspace_manager):
        # Pre-create the directory the worktree would occupy.
        target = workspace_manager.repo_path.parent / "worktrees" / "task-004"
        target.mkdir(parents=True)
        with pytest.raises(GitWorktreeError, match="Worktree path already exists"):
            await workspace_manager.create_worktree("task-004")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestSpinUpCleanRoom:
    """Container provisioning against the mocked Docker client."""

    async def test_spin_up_clean_room_success(self, workspace_manager, mock_docker_client):
        manager = workspace_manager
        container = await manager.spin_up_clean_room("/tmp/fake_worktree", "task-010")

        assert container.id == "abc123container"
        mock_docker_client.images.pull.assert_called_once_with("python:3.11-slim")
        expected_kwargs = dict(
            image="python:3.11-slim",
            name="appfactory-task-task-010",
            volumes={"/tmp/fake_worktree": {"bind": "/workspace", "mode": "rw"}},
            working_dir="/workspace",
            network_mode="none",
            auto_remove=False,
            detach=True,
            command="sleep infinity",
        )
        mock_docker_client.containers.create.assert_called_once_with(**expected_kwargs)

        # The workspace registry should now track the new container.
        assert "task-010" in manager.active_workspaces
        entry = manager.active_workspaces["task-010"]
        assert entry["worktree_path"] == "/tmp/fake_worktree"
        assert entry["container_id"] == "abc123container"

    async def test_spin_up_image_pull_failure(self, workspace_manager, mock_docker_client):
        mock_docker_client.images.pull.side_effect = docker_mod.errors.APIError("pull failed")
        with pytest.raises(DockerProvisionError, match="Failed to pull image"):
            await workspace_manager.spin_up_clean_room("/tmp/fake", "task-011")

    async def test_spin_up_container_create_failure(self, workspace_manager, mock_docker_client):
        mock_docker_client.containers.create.side_effect = docker_mod.errors.APIError("create failed")
        with pytest.raises(DockerProvisionError, match="Failed to create container"):
            await workspace_manager.spin_up_clean_room("/tmp/fake", "task-012")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestCleanupWorkspace:
    """Tear-down of a single workspace: container stop/remove plus worktree removal."""

    @staticmethod
    def _register(manager, task_id, worktree_path, container_id, container):
        # Seed the registry entry exactly as spin_up_clean_room would.
        manager.active_workspaces[task_id] = {
            "task_id": task_id,
            "worktree_path": worktree_path,
            "container_id": container_id,
            "container": container,
        }

    async def test_cleanup_workspace_success(self, workspace_manager):
        manager = workspace_manager
        worktree_path = await manager.create_worktree("task-020")
        container = MagicMock()
        self._register(manager, "task-020", worktree_path, "cont123", container)

        await manager.cleanup_workspace("task-020")

        container.stop.assert_called_once_with(timeout=5)
        container.remove.assert_called_once_with(force=True)
        assert "task-020" not in manager.active_workspaces
        assert not os.path.exists(worktree_path)

    async def test_cleanup_already_stopped_container(self, workspace_manager):
        manager = workspace_manager
        container = MagicMock()
        container.stop.side_effect = Exception("already stopped")
        container.remove.return_value = None
        self._register(manager, "task-021", "/tmp/nonexistent", "cont456", container)

        # A stop failure on an already-dead container must not propagate.
        await manager.cleanup_workspace("task-021")
        assert "task-021" not in manager.active_workspaces

    async def test_cleanup_no_container(self, workspace_manager):
        manager = workspace_manager
        worktree_path = await manager.create_worktree("task-022")
        self._register(manager, "task-022", worktree_path, None, None)

        await manager.cleanup_workspace("task-022")
        assert "task-022" not in manager.active_workspaces
|
||||
|
||||
|
||||
class TestGetActiveWorkspaces:
    """Read-only listing of currently tracked workspaces."""

    def test_get_empty(self, workspace_manager):
        assert workspace_manager.get_active_workspaces() == []

    def test_get_with_workspaces(self, workspace_manager):
        manager = workspace_manager
        for task_id, path, cid in (("t1", "/path/t1", "c1"), ("t2", "/path/t2", "c2")):
            manager.active_workspaces[task_id] = {
                "task_id": task_id,
                "worktree_path": path,
                "container_id": cid,
                "container": MagicMock(),
            }

        listing = manager.get_active_workspaces()

        assert len(listing) == 2
        assert {"task_id": "t1", "worktree_path": "/path/t1", "container_id": "c1"} in listing
        assert {"task_id": "t2", "worktree_path": "/path/t2", "container_id": "c2"} in listing
        # Raw container handles must never leak through the public listing.
        for entry in listing:
            assert "container" not in entry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestCleanupAll:
    """Bulk tear-down across every tracked workspace."""

    @staticmethod
    def _track(manager, task_id, worktree_path, container_id, container):
        manager.active_workspaces[task_id] = {
            "task_id": task_id,
            "worktree_path": worktree_path,
            "container_id": container_id,
            "container": container,
        }

    async def test_cleanup_all_success(self, workspace_manager):
        manager = workspace_manager
        first, second = MagicMock(), MagicMock()
        self._track(manager, "t1", "/tmp/nonexistent1", "c1", first)
        self._track(manager, "t2", "/tmp/nonexistent2", "c2", second)

        await manager.cleanup_all()

        first.stop.assert_called_once()
        second.stop.assert_called_once()
        assert manager.active_workspaces == {}

    async def test_cleanup_all_with_errors(self, workspace_manager):
        manager = workspace_manager
        broken = MagicMock()
        broken.stop.side_effect = Exception("stop failed")
        broken.remove.side_effect = Exception("remove failed")
        self._track(manager, "t1", "/tmp/nonexistent", "c1", broken)

        # Individual failures are aggregated into one WorkspaceError at the end,
        # but the registry is still emptied.
        with pytest.raises(WorkspaceError, match="Cleanup all completed with errors"):
            await manager.cleanup_all()

        assert manager.active_workspaces == {}
|
||||
Reference in New Issue
Block a user