258 lines
8.6 KiB
Python
258 lines
8.6 KiB
Python
"""Tests for PMAgent."""
|
|
|
|
import os
|
|
import tempfile
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from app_factory.agents.pm_agent import PMAgent
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_api_response(text, input_tokens=10, output_tokens=20):
|
|
"""Build a fake Claude SDK completion response."""
|
|
return SimpleNamespace(text=text, input_tokens=input_tokens, output_tokens=output_tokens)
|
|
|
|
|
|
def _build_agent(**kwargs):
    """Construct a PMAgent whose Claude SDK client is an ``AsyncMock``.

    ``ClaudeSDKClient`` is patched only while the agent is constructed. The
    mock is then also assigned to ``agent.client`` explicitly, so tests can
    rely on that attribute being the mock however the constructor stores it.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``PMAgent`` alongside
            ``api_key="test-key"``.

    Returns:
        tuple: ``(agent, mock_client)``.
    """
    fake_client = AsyncMock()
    with patch(
        "app_factory.agents.pm_agent.ClaudeSDKClient", return_value=fake_client
    ):
        agent = PMAgent(api_key="test-key", **kwargs)
    agent.client = fake_client
    return agent, fake_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Initialization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestInitialization:
    """Construction-time behaviour of PMAgent and its SDK client."""

    def test_no_api_key_uses_default_client(self):
        # A single cleared patch.dict guarantees no credential variables
        # (e.g. ANTHROPIC_API_KEY) are visible, so the client must be built
        # with None for both credentials.
        with patch.dict(os.environ, {}, clear=True), \
                patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_mod:
            mock_mod.return_value = AsyncMock()
            PMAgent()
            mock_mod.assert_called_once_with(
                api_key=None,
                auth_token=None,
                enable_debug=False,
            )

    def test_api_key_from_param(self):
        # An explicit api_key argument must be forwarded to the SDK client.
        with patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_mod:
            mock_mod.return_value = AsyncMock()
            PMAgent(api_key="test-key")
        mock_mod.assert_called_once()
        assert mock_mod.call_args.kwargs.get("api_key") == "test-key"

    def test_default_model(self):
        agent, _ = _build_agent()
        assert agent.model == "claude-opus-4-6"

    def test_custom_model(self):
        # Use an id that differs from the default ("claude-opus-4-6") so the
        # assertion proves the argument is honoured rather than coinciding
        # with the default.
        agent, _ = _build_agent(model="claude-sonnet-4-5")
        assert agent.model == "claude-sonnet-4-5"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Template loading
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestTemplateLoading:
    """PMAgent._load_template reads prompt templates by file name."""

    @staticmethod
    def _load(name):
        """Build a fresh agent and return the text of template *name*."""
        agent, _ = _build_agent()
        return agent._load_template(name)

    def test_load_prd_template(self):
        text = self._load("pm_prd_expansion.txt")
        for marker in ("Product Manager", "Objective"):
            assert marker in text

    def test_load_clarification_template(self):
        text = self._load("pm_clarification.txt")
        for marker in ("{requesting_agent}", "ESCALATE_TO_HUMAN"):
            assert marker in text

    def test_load_missing_template_raises(self):
        # The agent is built outside pytest.raises so that only the template
        # lookup itself is expected to raise.
        agent, _ = _build_agent()
        with pytest.raises(FileNotFoundError):
            agent._load_template("nonexistent.txt")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PRD expansion
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestExpandPromptToPrd:
    """expand_prompt_to_prd returns the model's PRD text and records usage."""

    @pytest.mark.asyncio
    async def test_returns_prd_markdown(self):
        agent, mock_client = _build_agent()
        canned_prd = (
            "# Objective\nBuild a todo app\n"
            "# Core Requirements\n1. Add tasks\n"
            "# Technical Architecture\nMonolith\n"
            "# Tech Stack\nPython, FastAPI\n"
            "# Success Criteria\nAll tests pass\n"
            "# Non-Functional Requirements\n<1s response"
        )
        reply = _make_api_response(canned_prd, 15, 100)
        mock_client.complete = AsyncMock(return_value=reply)

        prd = await agent.expand_prompt_to_prd("Build a todo app")

        expected_headings = (
            "Objective",
            "Core Requirements",
            "Technical Architecture",
            "Tech Stack",
            "Success Criteria",
            "Non-Functional Requirements",
        )
        for heading in expected_headings:
            assert heading in prd
        mock_client.complete.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_tracks_token_usage(self):
        agent, mock_client = _build_agent()
        reply = _make_api_response("prd content", 50, 200)
        mock_client.complete = AsyncMock(return_value=reply)

        await agent.expand_prompt_to_prd("some input")

        usage = agent.get_token_usage()
        observed = (
            usage["input_tokens"],
            usage["output_tokens"],
            usage["total_tokens"],
        )
        assert observed == (50, 200, 250)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Clarification handling
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestHandleClarification:
    """handle_clarification_request answers directly or defers to a human."""

    @pytest.mark.asyncio
    async def test_auto_resolve(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("Use PostgreSQL for the database.")
        )

        # A regular model answer must be returned without prompting a human.
        with patch("builtins.input") as mock_input:
            result = await agent.handle_clarification_request({
                "requesting_agent": "dev_agent",
                "task_id": "2.1",
                "question": "Which database should I use?",
                "context": "PRD says relational DB",
            })

        assert result == "Use PostgreSQL for the database."
        mock_input.assert_not_called()

    @pytest.mark.asyncio
    async def test_escalate_to_human(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("ESCALATE_TO_HUMAN")
        )

        # The ESCALATE_TO_HUMAN sentinel should route the question to input().
        with patch("builtins.input", return_value="Use MySQL") as mock_input:
            result = await agent.handle_clarification_request({
                "requesting_agent": "dev_agent",
                "task_id": "3.1",
                "question": "Which vendor should we pick?",
                "context": "",
            })

        assert result == "Use MySQL"
        mock_input.assert_called_once()

    @pytest.mark.asyncio
    async def test_tracks_tokens(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("answer", 5, 10)
        )

        await agent.handle_clarification_request({
            "requesting_agent": "qa",
            "task_id": "1.1",
            "question": "q",
            "context": "c",
        })

        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 5
        assert usage["output_tokens"] == 10
        # Keep parity with the other token-tracking tests.
        assert usage["total_tokens"] == 15
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PRD updates
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestUpdatePrd:
    """update_prd appends '## PRD Update -' sections to an existing PRD file.

    Uses pytest's ``tmp_path`` fixture so each test gets an isolated file that
    is cleaned up automatically, instead of hand-managed temp files.
    """

    def test_appends_with_version_header(self, tmp_path):
        agent, _ = _build_agent()
        prd_file = tmp_path / "prd.md"
        prd_file.write_text("# Original PRD\nSome content\n", encoding="utf-8")

        agent.update_prd(str(prd_file), "Added authentication requirement.")

        content = prd_file.read_text(encoding="utf-8")
        assert "# Original PRD" in content
        assert "## PRD Update -" in content
        assert "Added authentication requirement." in content

    def test_multiple_updates_maintain_history(self, tmp_path):
        agent, _ = _build_agent()
        prd_file = tmp_path / "prd.md"
        prd_file.write_text("# PRD v1\n", encoding="utf-8")

        agent.update_prd(str(prd_file), "Update 1")
        agent.update_prd(str(prd_file), "Update 2")

        content = prd_file.read_text(encoding="utf-8")
        # The original document must survive successive updates.
        assert "# PRD v1" in content
        assert content.count("## PRD Update -") == 2
        assert "Update 1" in content
        assert "Update 2" in content
        # History is appended in order: the older update precedes the newer.
        assert content.index("Update 1") < content.index("Update 2")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Token usage
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestTokenUsage:
    """get_token_usage starts at zero and sums counts across completions."""

    def test_initial_usage_zero(self):
        agent, _ = _build_agent()
        expected = dict.fromkeys(("input_tokens", "output_tokens", "total_tokens"), 0)
        assert agent.get_token_usage() == expected

    @pytest.mark.asyncio
    async def test_accumulates_across_calls(self):
        agent, mock_client = _build_agent()
        responses = [
            _make_api_response("r1", 10, 20),
            _make_api_response("r2", 30, 40),
        ]
        mock_client.complete = AsyncMock(side_effect=responses)

        for prompt in ("first call", "second call"):
            await agent.expand_prompt_to_prd(prompt)

        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 10 + 30
        assert usage["output_tokens"] == 20 + 40
        assert usage["total_tokens"] == 10 + 30 + 20 + 40
|