"""Tests for PMAgent."""

import os
import tempfile
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest

from app_factory.agents.pm_agent import PMAgent


# Model PMAgent is expected to use when no ``model`` argument is given.
_DEFAULT_MODEL = "claude-opus-4-6"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_api_response(text, input_tokens=10, output_tokens=20):
    """Build a fake Claude SDK completion response.

    Only the attributes these tests rely on are populated: ``text``,
    ``input_tokens`` and ``output_tokens``.
    """
    return SimpleNamespace(
        text=text, input_tokens=input_tokens, output_tokens=output_tokens,
    )


def _build_agent(**kwargs):
    """Create a PMAgent whose Claude SDK client is an ``AsyncMock``.

    Extra keyword arguments are forwarded to ``PMAgent``. Returns
    ``(agent, mock_client)``; tests configure ``mock_client.complete``.
    """
    with patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_client_cls:
        mock_client = AsyncMock()
        mock_client_cls.return_value = mock_client
        agent = PMAgent(api_key="test-key", **kwargs)
        # Pin the mock explicitly rather than relying on how PMAgent stores
        # the client it constructed.
        agent.client = mock_client
    return agent, mock_client


# ---------------------------------------------------------------------------
# Initialization
# ---------------------------------------------------------------------------

class TestInitialization:
    def test_no_api_key_uses_default_client(self):
        # Clear the whole environment (not just ANTHROPIC_API_KEY) so no
        # credentials from the developer's shell or CI reach the client.
        with patch.dict(os.environ, {}, clear=True), \
                patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_client_cls:
            mock_client_cls.return_value = AsyncMock()
            PMAgent()
            mock_client_cls.assert_called_once_with(
                api_key=None, auth_token=None, enable_debug=False,
            )

    def test_api_key_from_param(self):
        # Empty environment: the key can only have come from the parameter.
        with patch.dict(os.environ, {}, clear=True), \
                patch("app_factory.agents.pm_agent.ClaudeSDKClient") as mock_client_cls:
            mock_client_cls.return_value = AsyncMock()
            agent = PMAgent(api_key="test-key")
            assert mock_client_cls.call_args.kwargs["api_key"] == "test-key"
        assert agent.model == _DEFAULT_MODEL

    def test_custom_model(self):
        # Must differ from the default, otherwise this cannot distinguish an
        # honoured override from the fallback value.
        agent, _ = _build_agent(model="claude-sonnet-4-5")
        assert agent.model == "claude-sonnet-4-5"


# ---------------------------------------------------------------------------
# Template loading
# ---------------------------------------------------------------------------

class TestTemplateLoading:
    def test_load_prd_template(self):
        agent, _ = _build_agent()
        template = agent._load_template("pm_prd_expansion.txt")
        assert "Product Manager" in template
        assert "Objective" in template

    def test_load_clarification_template(self):
        agent, _ = _build_agent()
        template = agent._load_template("pm_clarification.txt")
        assert "{requesting_agent}" in template
        assert "ESCALATE_TO_HUMAN" in template

    def test_load_missing_template_raises(self):
        agent, _ = _build_agent()
        with pytest.raises(FileNotFoundError):
            agent._load_template("nonexistent.txt")


# ---------------------------------------------------------------------------
# PRD expansion
# ---------------------------------------------------------------------------

# Section headings the expanded PRD is expected to contain.
_PRD_SECTIONS = (
    "Objective",
    "Core Requirements",
    "Technical Architecture",
    "Tech Stack",
    "Success Criteria",
    "Non-Functional Requirements",
)


class TestExpandPromptToPrd:
    @pytest.mark.asyncio
    async def test_returns_prd_markdown(self):
        agent, mock_client = _build_agent()
        prd_text = (
            "# Objective\nBuild a todo app\n"
            "# Core Requirements\n1. Add tasks\n"
            "# Technical Architecture\nMonolith\n"
            "# Tech Stack\nPython, FastAPI\n"
            "# Success Criteria\nAll tests pass\n"
            "# Non-Functional Requirements\n<1s response"
        )
        mock_client.complete = AsyncMock(
            return_value=_make_api_response(prd_text, 15, 100)
        )

        result = await agent.expand_prompt_to_prd("Build a todo app")

        for section in _PRD_SECTIONS:
            assert section in result
        mock_client.complete.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_tracks_token_usage(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("prd content", 50, 200)
        )

        await agent.expand_prompt_to_prd("some input")

        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 50
        assert usage["output_tokens"] == 200
        assert usage["total_tokens"] == 250


# ---------------------------------------------------------------------------
# Clarification handling
# ---------------------------------------------------------------------------

class TestHandleClarification:
    @pytest.mark.asyncio
    async def test_auto_resolve(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("Use PostgreSQL for the database.")
        )

        result = await agent.handle_clarification_request({
            "requesting_agent": "dev_agent",
            "task_id": "2.1",
            "question": "Which database should I use?",
            "context": "PRD says relational DB",
        })

        assert result == "Use PostgreSQL for the database."

    @pytest.mark.asyncio
    async def test_escalate_to_human(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("ESCALATE_TO_HUMAN")
        )

        # The escalation path is expected to prompt via input(); stub it.
        with patch("builtins.input", return_value="Use MySQL") as mock_input:
            result = await agent.handle_clarification_request({
                "requesting_agent": "dev_agent",
                "task_id": "3.1",
                "question": "Which vendor should we pick?",
                "context": "",
            })

        assert result == "Use MySQL"
        mock_input.assert_called_once()

    @pytest.mark.asyncio
    async def test_tracks_tokens(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            return_value=_make_api_response("answer", 5, 10)
        )

        await agent.handle_clarification_request({
            "requesting_agent": "qa",
            "task_id": "1.1",
            "question": "q",
            "context": "c",
        })

        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 5
        assert usage["output_tokens"] == 10
        assert usage["total_tokens"] == 15


# ---------------------------------------------------------------------------
# PRD updates
# ---------------------------------------------------------------------------

class TestUpdatePrd:
    def test_appends_with_version_header(self):
        agent, _ = _build_agent()
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False, encoding="utf-8",
        ) as f:
            f.write("# Original PRD\nSome content\n")
            prd_path = f.name

        try:
            agent.update_prd(prd_path, "Added authentication requirement.")
            with open(prd_path, encoding="utf-8") as f:
                content = f.read()
            assert "# Original PRD" in content
            assert "## PRD Update -" in content
            assert "Added authentication requirement." in content
        finally:
            os.unlink(prd_path)

    def test_multiple_updates_maintain_history(self):
        agent, _ = _build_agent()
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False, encoding="utf-8",
        ) as f:
            f.write("# PRD v1\n")
            prd_path = f.name

        try:
            agent.update_prd(prd_path, "Update 1")
            agent.update_prd(prd_path, "Update 2")
            with open(prd_path, encoding="utf-8") as f:
                content = f.read()
            # Each update adds its own header; earlier updates are kept.
            assert content.count("## PRD Update -") == 2
            assert "Update 1" in content
            assert "Update 2" in content
        finally:
            os.unlink(prd_path)


# ---------------------------------------------------------------------------
# Token usage
# ---------------------------------------------------------------------------

class TestTokenUsage:
    def test_initial_usage_zero(self):
        agent, _ = _build_agent()
        usage = agent.get_token_usage()
        assert usage == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

    @pytest.mark.asyncio
    async def test_accumulates_across_calls(self):
        agent, mock_client = _build_agent()
        mock_client.complete = AsyncMock(
            side_effect=[
                _make_api_response("r1", 10, 20),
                _make_api_response("r2", 30, 40),
            ]
        )

        await agent.expand_prompt_to_prd("first call")
        await agent.expand_prompt_to_prd("second call")

        usage = agent.get_token_usage()
        assert usage["input_tokens"] == 40
        assert usage["output_tokens"] == 60
        assert usage["total_tokens"] == 100