346 lines
12 KiB
Python
346 lines
12 KiB
Python
"""Tests for Claude Agent SDK client adapter."""
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
|
|
|
import pytest
|
|
|
|
from app_factory.core.claude_client import ClaudeSDKClient
|
|
|
|
|
|
class _FakeOptions:
    """Lightweight substitute for the SDK options class.

    Stores whatever keyword arguments it receives on ``kwargs`` so a test can
    look at exactly what the client under test passed in.
    """

    def __init__(self, **options):
        self.kwargs = options
|
|
|
|
|
|
class _FakeProcessError(RuntimeError):
    """``RuntimeError`` that also carries a captured ``stderr`` string.

    ``message`` becomes the exception's text. ``stderr`` is kept as a separate
    attribute, which lets a test check that the client folds it into the error
    it re-raises.
    """

    def __init__(self, message: str, stderr: str):
        self.stderr = stderr
        super().__init__(message)
|
|
|
|
|
|
class TestClaudeSDKClient:
    """Behavioral tests for ``ClaudeSDKClient.complete``.

    Each test replaces the client's ``_options_cls`` with ``_FakeOptions`` and
    its ``_query`` with a local async generator. The generator yields exactly
    the messages the test wants the client to see.
    """

    @staticmethod
    def _make_client(**client_kwargs):
        """Return a ``ClaudeSDKClient`` whose options class is ``_FakeOptions``.

        Args:
            **client_kwargs: Forwarded unchanged to ``ClaudeSDKClient``
                (e.g. ``enable_debug``).
        """
        client = ClaudeSDKClient(**client_kwargs)
        client._options_cls = _FakeOptions
        return client

    @staticmethod
    def _success_result(result="ok", *, input_tokens=0, output_tokens=0):
        """Return a fake result message with ``subtype="success"``.

        Args:
            result: Text placed on the message's ``result`` attribute.
            input_tokens: Value for ``usage["input_tokens"]``.
            output_tokens: Value for ``usage["output_tokens"]``.
        """
        return SimpleNamespace(
            result=result,
            subtype="success",
            is_error=False,
            usage={"input_tokens": input_tokens, "output_tokens": output_tokens},
        )

    @pytest.mark.asyncio
    async def test_complete_uses_result_text_and_usage_tokens(self):
        client = self._make_client()

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                content=[SimpleNamespace(text="intermediate answer")],
                model="claude-sonnet-4-6",
                usage=None,
                is_error=False,
            )
            # Deliberately has no ``subtype``: this covers a bare result message.
            yield SimpleNamespace(
                result="final answer",
                usage={"input_tokens": 12, "output_tokens": 34},
                is_error=False,
            )

        client._query = _fake_query
        result = await client.complete(
            "hello",
            model="claude-sonnet-4-6",
            system_prompt="You are a tester.",
            max_turns=1,
        )

        assert result.text == "final answer"
        assert result.input_tokens == 12
        assert result.output_tokens == 34

    @pytest.mark.asyncio
    async def test_complete_raises_on_error_result(self):
        client = self._make_client()

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                result="something went wrong",
                usage={"input_tokens": 1, "output_tokens": 1},
                is_error=True,
            )

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="something went wrong"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_falls_back_to_content_blocks(self):
        client = self._make_client()

        async def _fake_query(*, prompt, options):
            # No result message at all: the text must come from content blocks.
            yield SimpleNamespace(
                content=[SimpleNamespace(text="first"), SimpleNamespace(text="second")],
                model="claude-sonnet-4-6",
                usage={"cache_creation_input_tokens": 7, "output_tokens": 2},
                is_error=False,
            )

        client._query = _fake_query
        result = await client.complete("hello")

        assert result.text == "first\nsecond"
        assert result.input_tokens == 7
        assert result.output_tokens == 2

    @pytest.mark.asyncio
    async def test_complete_rejects_error_subtype(self):
        client = self._make_client()

        async def _fake_query(*, prompt, options):
            # ``is_error`` is False, so only the subtype signals the failure.
            yield SimpleNamespace(
                subtype="error_during_execution",
                is_error=False,
                result=None,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="error_during_execution"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_ignores_user_content_blocks(self):
        client = self._make_client()

        async def _fake_query(*, prompt, options):
            # Simulates a user/system message content block that should not be treated as model output.
            yield SimpleNamespace(
                content=[SimpleNamespace(text="[Request interrupted by user]")],
                is_error=False,
                usage=None,
            )
            yield self._success_result(
                "actual assistant result", input_tokens=1, output_tokens=2
            )

        client._query = _fake_query
        result = await client.complete("hello")

        assert result.text == "actual assistant result"

    @pytest.mark.asyncio
    async def test_complete_wraps_query_exception_with_hint(self):
        client = self._make_client()

        async def _fake_query(*, prompt, options):
            raise RuntimeError("Command failed with exit code 1")
            # Unreachable, but its presence makes this an async generator.
            yield  # pragma: no cover

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="Hint: verify Claude auth is available"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_passes_debug_to_stderr_flag(self):
        client = self._make_client(enable_debug=True)
        captured_options = None

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield self._success_result()

        client._query = _fake_query
        await client.complete("hello")

        assert captured_options is not None
        assert captured_options.kwargs["extra_args"] == {"debug-to-stderr": None}
        assert hasattr(captured_options.kwargs["debug_stderr"], "fileno")

    @pytest.mark.asyncio
    async def test_complete_omits_debug_stderr_without_debug_flag(self):
        client = self._make_client(enable_debug=False)
        captured_options = None

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield self._success_result()

        client._query = _fake_query
        await client.complete("hello")

        assert captured_options is not None
        assert "extra_args" not in captured_options.kwargs
        assert "debug_stderr" not in captured_options.kwargs

    @pytest.mark.asyncio
    async def test_complete_includes_exception_stderr_detail(self):
        client = self._make_client()

        async def _fake_query(*, prompt, options):
            raise _FakeProcessError(
                "Command failed with exit code 1",
                "permission denied writing ~/.claude.json",
            )
            # Unreachable, but its presence makes this an async generator.
            yield  # pragma: no cover

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="permission denied writing ~/.claude.json"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_uses_fallback_home_when_default_home_not_writable(self, tmp_path):
        client = self._make_client()
        captured_options = None

        cwd_dir = tmp_path / "workspace"
        cwd_dir.mkdir(parents=True, exist_ok=True)
        expected_home = cwd_dir / ".app_factory" / "claude_home"
        expected_home.mkdir(parents=True, exist_ok=True)

        # Force the "home not writable" branch, and stub fallback preparation
        # so it returns a directory this test controls.
        client._claude_home_is_writable = lambda home: False
        client._prepare_fallback_claude_home = (
            lambda source_home, fallback_home: expected_home
        )

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield self._success_result()

        client._query = _fake_query
        await client.complete("hello", cwd=str(cwd_dir))

        assert captured_options is not None
        assert captured_options.kwargs["env"]["HOME"] == str(expected_home)

    @pytest.mark.asyncio
    async def test_complete_retries_rate_limit_event_then_succeeds(self):
        client = self._make_client()
        attempts = 0

        async def _fake_query(*, prompt, options):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise RuntimeError("Unknown message type: rate_limit_event")
            yield self._success_result(input_tokens=2, output_tokens=3)

        client._query = _fake_query

        # Patch the sleep the client awaits between retries so the test does not wait.
        with patch("app_factory.core.claude_client.asyncio.sleep", new=AsyncMock()) as mock_sleep:
            result = await client.complete("hello")

        assert result.text == "ok"
        assert attempts == 2
        mock_sleep.assert_awaited_once_with(0.2)

    @pytest.mark.asyncio
    async def test_complete_retries_rate_limit_event_exhausts_then_fails(self):
        client = self._make_client()
        attempts = 0

        async def _fake_query(*, prompt, options):
            nonlocal attempts
            attempts += 1
            raise RuntimeError("Unknown message type: rate_limit_event")
            # Unreachable, but its presence makes this an async generator.
            yield  # pragma: no cover

        client._query = _fake_query

        with patch("app_factory.core.claude_client.asyncio.sleep", new=AsyncMock()) as mock_sleep:
            with pytest.raises(RuntimeError, match="rate_limit_event"):
                await client.complete("hello")

        # One initial attempt plus three retries, each retry preceded by a backoff sleep.
        assert attempts == 4
        assert mock_sleep.await_args_list == [call(0.2), call(0.8), call(4.0)]

    @pytest.mark.asyncio
    async def test_complete_emits_observability_events(self):
        client = self._make_client()
        mock_observability = MagicMock()

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                content=[
                    SimpleNamespace(
                        id="toolu_1",
                        name="Bash",
                        input={"command": "echo hi", "api_key": "secret"},
                    ),
                    SimpleNamespace(tool_use_id="toolu_1", content="done", is_error=False),
                    SimpleNamespace(text="Tool finished"),
                ],
                model="claude-sonnet-4-6",
                session_id="session-123",
                usage=None,
                is_error=False,
            )
            yield SimpleNamespace(
                subtype="success",
                duration_ms=250,
                duration_api_ms=200,
                is_error=False,
                num_turns=1,
                session_id="session-123",
                usage={"input_tokens": 3, "output_tokens": 4},
                result="ok",
            )

        client._query = _fake_query
        result = await client.complete(
            "hello",
            observability=mock_observability,
            agent_name="pm_agent",
            task_id="expand_prd",
        )

        assert result.text == "ok"

        # Collect the logged kwargs once; the loop variable is named ``c`` so it
        # does not shadow the module-level ``unittest.mock.call`` import.
        logged = [c.kwargs for c in mock_observability.log_claude_event.call_args_list]
        event_types = [kwargs["event_type"] for kwargs in logged]
        for expected in (
            "request_start",
            "tool_use",
            "tool_result",
            "result_message",
            "request_complete",
        ):
            assert expected in event_types, expected

        tool_use_calls = [kwargs for kwargs in logged if kwargs.get("event_type") == "tool_use"]
        assert tool_use_calls
        assert tool_use_calls[0]["payload"]["tool_input"]["api_key"] == "[REDACTED]"
|