Files
ai_ops2/tests/test_claude_client.py
2026-02-25 23:49:54 -05:00

346 lines
12 KiB
Python

"""Tests for Claude Agent SDK client adapter."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
from app_factory.core.claude_client import ClaudeSDKClient
class _FakeOptions:
def __init__(self, **kwargs):
self.kwargs = kwargs
class _FakeProcessError(RuntimeError):
def __init__(self, message: str, stderr: str):
super().__init__(message)
self.stderr = stderr
class TestClaudeSDKClient:
    """Tests for ``ClaudeSDKClient.complete``.

    Every test replaces ``client._options_cls`` with ``_FakeOptions`` and
    ``client._query`` with a local async generator, so the messages seen by
    ``complete`` are fully controlled by the test.
    """

    @pytest.mark.asyncio
    async def test_complete_uses_result_text_and_usage_tokens(self):
        """The final message's ``result`` and ``usage`` win over earlier content text."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                content=[SimpleNamespace(text="intermediate answer")],
                model="claude-sonnet-4-6",
                usage=None,
                is_error=False,
            )
            yield SimpleNamespace(
                result="final answer",
                usage={"input_tokens": 12, "output_tokens": 34},
                is_error=False,
            )

        client._query = _fake_query

        result = await client.complete(
            "hello",
            model="claude-sonnet-4-6",
            system_prompt="You are a tester.",
            max_turns=1,
        )

        assert result.text == "final answer"
        assert result.input_tokens == 12
        assert result.output_tokens == 34

    @pytest.mark.asyncio
    async def test_complete_raises_on_error_result(self):
        """A message with ``is_error=True`` raises RuntimeError matching its result text."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                result="something went wrong",
                usage={"input_tokens": 1, "output_tokens": 1},
                is_error=True,
            )

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="something went wrong"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_falls_back_to_content_blocks(self):
        """Without a ``result`` attribute, content-block texts are joined by newlines.

        The only input-side usage key supplied is ``cache_creation_input_tokens``,
        and the test expects it to be reported as ``input_tokens``.
        """
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                content=[SimpleNamespace(text="first"), SimpleNamespace(text="second")],
                model="claude-sonnet-4-6",
                usage={"cache_creation_input_tokens": 7, "output_tokens": 2},
                is_error=False,
            )

        client._query = _fake_query

        result = await client.complete("hello")

        assert result.text == "first\nsecond"
        assert result.input_tokens == 7
        assert result.output_tokens == 2

    @pytest.mark.asyncio
    async def test_complete_rejects_error_subtype(self):
        """An error ``subtype`` raises even though ``is_error`` is False."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                subtype="error_during_execution",
                is_error=False,
                result=None,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="error_during_execution"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_ignores_user_content_blocks(self):
        """Text from a non-assistant content block must not become the result text."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            # Simulates a user/system message content block that should not be
            # treated as model output (note: no ``model`` attribute).
            yield SimpleNamespace(
                content=[SimpleNamespace(text="[Request interrupted by user]")],
                is_error=False,
                usage=None,
            )
            yield SimpleNamespace(
                result="actual assistant result",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 1, "output_tokens": 2},
            )

        client._query = _fake_query

        result = await client.complete("hello")

        assert result.text == "actual assistant result"

    @pytest.mark.asyncio
    async def test_complete_wraps_query_exception_with_hint(self):
        """A failure raised by the query surfaces as RuntimeError carrying an auth hint."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            raise RuntimeError("Command failed with exit code 1")
            # Unreachable, but its presence makes this function an async generator.
            yield  # pragma: no cover

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="Hint: verify Claude auth is available"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_passes_debug_to_stderr_flag(self):
        """With ``enable_debug=True`` the options carry the debug flag and a stderr stream."""
        client = ClaudeSDKClient(enable_debug=True)
        client._options_cls = _FakeOptions
        captured_options = None

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query

        await client.complete("hello")

        assert captured_options is not None
        assert captured_options.kwargs["extra_args"] == {"debug-to-stderr": None}
        # Only checks that the value is file-like (exposes ``fileno``).
        assert hasattr(captured_options.kwargs["debug_stderr"], "fileno")

    @pytest.mark.asyncio
    async def test_complete_omits_debug_stderr_without_debug_flag(self):
        """With ``enable_debug=False`` neither debug option key is passed at all."""
        client = ClaudeSDKClient(enable_debug=False)
        client._options_cls = _FakeOptions
        captured_options = None

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query

        await client.complete("hello")

        assert captured_options is not None
        assert "extra_args" not in captured_options.kwargs
        assert "debug_stderr" not in captured_options.kwargs

    @pytest.mark.asyncio
    async def test_complete_includes_exception_stderr_detail(self):
        """The ``stderr`` text attached to a query exception appears in the raised error."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions

        async def _fake_query(*, prompt, options):
            raise _FakeProcessError(
                "Command failed with exit code 1",
                "permission denied writing ~/.claude.json",
            )
            # Unreachable, but its presence makes this function an async generator.
            yield  # pragma: no cover

        client._query = _fake_query

        with pytest.raises(RuntimeError, match="permission denied writing ~/.claude.json"):
            await client.complete("hello")

    @pytest.mark.asyncio
    async def test_complete_uses_fallback_home_when_default_home_not_writable(self, tmp_path):
        """When the home check reports not-writable, ``HOME`` in the options env is the fallback."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        captured_options = None

        cwd_dir = tmp_path / "workspace"
        cwd_dir.mkdir(parents=True, exist_ok=True)
        expected_home = cwd_dir / ".app_factory" / "claude_home"
        expected_home.mkdir(parents=True, exist_ok=True)

        # Force the fallback path and pin the directory it resolves to.
        client._claude_home_is_writable = lambda home: False
        client._prepare_fallback_claude_home = (
            lambda source_home, fallback_home: expected_home
        )

        async def _fake_query(*, prompt, options):
            nonlocal captured_options
            captured_options = options
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 0, "output_tokens": 0},
            )

        client._query = _fake_query

        await client.complete("hello", cwd=str(cwd_dir))

        assert captured_options is not None
        assert captured_options.kwargs["env"]["HOME"] == str(expected_home)

    @pytest.mark.asyncio
    async def test_complete_retries_rate_limit_event_then_succeeds(self):
        """One ``rate_limit_event`` failure is retried after a single 0.2 sleep."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        attempts = 0

        async def _fake_query(*, prompt, options):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise RuntimeError("Unknown message type: rate_limit_event")
            yield SimpleNamespace(
                result="ok",
                subtype="success",
                is_error=False,
                usage={"input_tokens": 2, "output_tokens": 3},
            )

        client._query = _fake_query

        # Replace asyncio.sleep in the client module so the test does not wait.
        with patch("app_factory.core.claude_client.asyncio.sleep", new=AsyncMock()) as mock_sleep:
            result = await client.complete("hello")

        assert result.text == "ok"
        assert attempts == 2
        mock_sleep.assert_awaited_once_with(0.2)

    @pytest.mark.asyncio
    async def test_complete_retries_rate_limit_event_exhausts_then_fails(self):
        """Persistent ``rate_limit_event`` failures give up after four attempts."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        attempts = 0

        async def _fake_query(*, prompt, options):
            nonlocal attempts
            attempts += 1
            raise RuntimeError("Unknown message type: rate_limit_event")
            # Unreachable, but its presence makes this function an async generator.
            yield  # pragma: no cover

        client._query = _fake_query

        with patch("app_factory.core.claude_client.asyncio.sleep", new=AsyncMock()) as mock_sleep:
            with pytest.raises(RuntimeError, match="rate_limit_event"):
                await client.complete("hello")

        # Four attempts means three sleeps between them, with these delays.
        assert attempts == 4
        assert mock_sleep.await_args_list == [call(0.2), call(0.8), call(4.0)]

    @pytest.mark.asyncio
    async def test_complete_emits_observability_events(self):
        """The expected event types are logged and the ``api_key`` tool input is redacted."""
        client = ClaudeSDKClient()
        client._options_cls = _FakeOptions
        mock_observability = MagicMock()

        async def _fake_query(*, prompt, options):
            yield SimpleNamespace(
                content=[
                    SimpleNamespace(
                        id="toolu_1",
                        name="Bash",
                        input={"command": "echo hi", "api_key": "secret"},
                    ),
                    SimpleNamespace(tool_use_id="toolu_1", content="done", is_error=False),
                    SimpleNamespace(text="Tool finished"),
                ],
                model="claude-sonnet-4-6",
                session_id="session-123",
                usage=None,
                is_error=False,
            )
            yield SimpleNamespace(
                subtype="success",
                duration_ms=250,
                duration_api_ms=200,
                is_error=False,
                num_turns=1,
                session_id="session-123",
                usage={"input_tokens": 3, "output_tokens": 4},
                result="ok",
            )

        client._query = _fake_query

        result = await client.complete(
            "hello",
            observability=mock_observability,
            agent_name="pm_agent",
            task_id="expand_prd",
        )

        assert result.text == "ok"

        logged_calls = mock_observability.log_claude_event.call_args_list

        # The loop variable is deliberately not named ``call``: that would
        # shadow ``unittest.mock.call`` imported at module level.
        event_types = [logged.kwargs["event_type"] for logged in logged_calls]
        assert "request_start" in event_types
        assert "tool_use" in event_types
        assert "tool_result" in event_types
        assert "result_message" in event_types
        assert "request_complete" in event_types

        tool_use_calls = [
            logged.kwargs
            for logged in logged_calls
            if logged.kwargs.get("event_type") == "tool_use"
        ]
        assert tool_use_calls
        assert tool_use_calls[0]["payload"]["tool_input"]["api_key"] == "[REDACTED]"