first commit
This commit is contained in:
257
tests/test_pm_agent.py
Normal file
257
tests/test_pm_agent.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Tests for PMAgent."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory.agents.pm_agent import PMAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_api_response(text, input_tokens=10, output_tokens=20):
|
||||
"""Build a fake Claude SDK completion response."""
|
||||
return SimpleNamespace(text=text, input_tokens=input_tokens, output_tokens=output_tokens)
|
||||
|
||||
|
||||
def _build_agent(**kwargs):
|
||||
"""Create a PMAgent with a mocked Claude SDK client."""
|
||||
with patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_mod:
|
||||
mock_client = AsyncMock()
|
||||
mock_mod.return_value = mock_client
|
||||
agent = PMAgent(api_key="test-key", **kwargs)
|
||||
agent.client = mock_client
|
||||
return agent, mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInitialization:
|
||||
def test_no_api_key_uses_default_client(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
env = os.environ.copy()
|
||||
env.pop("ANTHROPIC_API_KEY", None)
|
||||
with patch.dict(os.environ, env, clear=True), \
|
||||
patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_mod:
|
||||
mock_mod.return_value = AsyncMock()
|
||||
agent = PMAgent()
|
||||
mock_mod.assert_called_once_with(
|
||||
api_key=None,
|
||||
auth_token=None,
|
||||
enable_debug=False,
|
||||
)
|
||||
|
||||
def test_api_key_from_param(self):
|
||||
agent, _ = _build_agent()
|
||||
assert agent.model == "claude-opus-4-6"
|
||||
|
||||
def test_custom_model(self):
|
||||
agent, _ = _build_agent(model="claude-opus-4-6")
|
||||
assert agent.model == "claude-opus-4-6"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTemplateLoading:
|
||||
def test_load_prd_template(self):
|
||||
agent, _ = _build_agent()
|
||||
template = agent._load_template("pm_prd_expansion.txt")
|
||||
assert "Product Manager" in template
|
||||
assert "Objective" in template
|
||||
|
||||
def test_load_clarification_template(self):
|
||||
agent, _ = _build_agent()
|
||||
template = agent._load_template("pm_clarification.txt")
|
||||
assert "{requesting_agent}" in template
|
||||
assert "ESCALATE_TO_HUMAN" in template
|
||||
|
||||
def test_load_missing_template_raises(self):
|
||||
agent, _ = _build_agent()
|
||||
with pytest.raises(FileNotFoundError):
|
||||
agent._load_template("nonexistent.txt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PRD expansion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExpandPromptToPrd:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_prd_markdown(self):
|
||||
agent, mock_client = _build_agent()
|
||||
prd_text = (
|
||||
"# Objective\nBuild a todo app\n"
|
||||
"# Core Requirements\n1. Add tasks\n"
|
||||
"# Technical Architecture\nMonolith\n"
|
||||
"# Tech Stack\nPython, FastAPI\n"
|
||||
"# Success Criteria\nAll tests pass\n"
|
||||
"# Non-Functional Requirements\n<1s response"
|
||||
)
|
||||
mock_client.complete = AsyncMock(
|
||||
return_value=_make_api_response(prd_text, 15, 100)
|
||||
)
|
||||
|
||||
result = await agent.expand_prompt_to_prd("Build a todo app")
|
||||
|
||||
assert "Objective" in result
|
||||
assert "Core Requirements" in result
|
||||
assert "Technical Architecture" in result
|
||||
assert "Tech Stack" in result
|
||||
assert "Success Criteria" in result
|
||||
assert "Non-Functional Requirements" in result
|
||||
mock_client.complete.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracks_token_usage(self):
|
||||
agent, mock_client = _build_agent()
|
||||
mock_client.complete = AsyncMock(
|
||||
return_value=_make_api_response("prd content", 50, 200)
|
||||
)
|
||||
|
||||
await agent.expand_prompt_to_prd("some input")
|
||||
|
||||
usage = agent.get_token_usage()
|
||||
assert usage["input_tokens"] == 50
|
||||
assert usage["output_tokens"] == 200
|
||||
assert usage["total_tokens"] == 250
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Clarification handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHandleClarification:
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_resolve(self):
|
||||
agent, mock_client = _build_agent()
|
||||
mock_client.complete = AsyncMock(
|
||||
return_value=_make_api_response("Use PostgreSQL for the database.")
|
||||
)
|
||||
|
||||
result = await agent.handle_clarification_request({
|
||||
"requesting_agent": "dev_agent",
|
||||
"task_id": "2.1",
|
||||
"question": "Which database should I use?",
|
||||
"context": "PRD says relational DB",
|
||||
})
|
||||
|
||||
assert result == "Use PostgreSQL for the database."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_escalate_to_human(self):
|
||||
agent, mock_client = _build_agent()
|
||||
mock_client.complete = AsyncMock(
|
||||
return_value=_make_api_response("ESCALATE_TO_HUMAN")
|
||||
)
|
||||
|
||||
with patch("builtins.input", return_value="Use MySQL") as mock_input:
|
||||
result = await agent.handle_clarification_request({
|
||||
"requesting_agent": "dev_agent",
|
||||
"task_id": "3.1",
|
||||
"question": "Which vendor should we pick?",
|
||||
"context": "",
|
||||
})
|
||||
|
||||
assert result == "Use MySQL"
|
||||
mock_input.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracks_tokens(self):
|
||||
agent, mock_client = _build_agent()
|
||||
mock_client.complete = AsyncMock(
|
||||
return_value=_make_api_response("answer", 5, 10)
|
||||
)
|
||||
|
||||
await agent.handle_clarification_request({
|
||||
"requesting_agent": "qa",
|
||||
"task_id": "1.1",
|
||||
"question": "q",
|
||||
"context": "c",
|
||||
})
|
||||
|
||||
usage = agent.get_token_usage()
|
||||
assert usage["input_tokens"] == 5
|
||||
assert usage["output_tokens"] == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PRD updates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdatePrd:
|
||||
def test_appends_with_version_header(self):
|
||||
agent, _ = _build_agent()
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as f:
|
||||
f.write("# Original PRD\nSome content\n")
|
||||
prd_path = f.name
|
||||
|
||||
try:
|
||||
agent.update_prd(prd_path, "Added authentication requirement.")
|
||||
|
||||
with open(prd_path) as f:
|
||||
content = f.read()
|
||||
|
||||
assert "# Original PRD" in content
|
||||
assert "## PRD Update -" in content
|
||||
assert "Added authentication requirement." in content
|
||||
finally:
|
||||
os.unlink(prd_path)
|
||||
|
||||
def test_multiple_updates_maintain_history(self):
|
||||
agent, _ = _build_agent()
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as f:
|
||||
f.write("# PRD v1\n")
|
||||
prd_path = f.name
|
||||
|
||||
try:
|
||||
agent.update_prd(prd_path, "Update 1")
|
||||
agent.update_prd(prd_path, "Update 2")
|
||||
|
||||
with open(prd_path) as f:
|
||||
content = f.read()
|
||||
|
||||
assert content.count("## PRD Update -") == 2
|
||||
assert "Update 1" in content
|
||||
assert "Update 2" in content
|
||||
finally:
|
||||
os.unlink(prd_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token usage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTokenUsage:
|
||||
def test_initial_usage_zero(self):
|
||||
agent, _ = _build_agent()
|
||||
usage = agent.get_token_usage()
|
||||
assert usage == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accumulates_across_calls(self):
|
||||
agent, mock_client = _build_agent()
|
||||
mock_client.complete = AsyncMock(
|
||||
side_effect=[
|
||||
_make_api_response("r1", 10, 20),
|
||||
_make_api_response("r2", 30, 40),
|
||||
]
|
||||
)
|
||||
|
||||
await agent.expand_prompt_to_prd("first call")
|
||||
await agent.expand_prompt_to_prd("second call")
|
||||
|
||||
usage = agent.get_token_usage()
|
||||
assert usage["input_tokens"] == 40
|
||||
assert usage["output_tokens"] == 60
|
||||
assert usage["total_tokens"] == 100
|
||||
Reference in New Issue
Block a user