19 changes: 19 additions & 0 deletions src/google/adk/models/anthropic_llm.py
@@ -377,6 +377,22 @@ def _resolve_model_name(self, model: Optional[str]) -> str:
      return match.group(1)
    return model

  def _get_generation_kwargs(
      self, llm_request: LlmRequest
  ) -> dict[str, Any]:
    generation_kwargs: dict[str, Any] = {}

    if llm_request.config.temperature is not None:
      generation_kwargs["temperature"] = llm_request.config.temperature
    if llm_request.config.top_p is not None:
      generation_kwargs["top_p"] = llm_request.config.top_p
    if llm_request.config.top_k is not None:
      generation_kwargs["top_k"] = llm_request.config.top_k
    if llm_request.config.stop_sequences:
      generation_kwargs["stop_sequences"] = llm_request.config.stop_sequences

    return generation_kwargs

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
@@ -401,6 +417,7 @@ async def generate_content_async(
        if llm_request.tools_dict
        else NOT_GIVEN
    )
    generation_kwargs = self._get_generation_kwargs(llm_request)

    if not stream:
      message = await self._anthropic_client.messages.create(
@@ -410,6 +427,7 @@
          tools=tools,
          tool_choice=tool_choice,
          max_tokens=self.max_tokens,
          **generation_kwargs,
      )
      yield message_to_generate_content_response(message)
    else:
@@ -439,6 +457,7 @@ async def _generate_content_streaming(
        tool_choice=tool_choice,
        max_tokens=self.max_tokens,
        stream=True,
        **self._get_generation_kwargs(llm_request),
    )

    # Track content blocks being built during streaming.
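Taken together, these changes thread per-request sampling parameters from the GenAI-style `GenerateContentConfig` through the new `_get_generation_kwargs` helper into both the non-streaming and streaming `messages.create` calls. A minimal sketch of the resulting call path follows; the import paths are inferred from the file layout in this diff and the model name is illustrative, so treat both as assumptions:

```python
import asyncio

from google.adk.models.anthropic_llm import AnthropicLlm  # path inferred from this diff
from google.adk.models.llm_request import LlmRequest  # path assumed
from google.genai import types
from google.genai.types import Content, Part

llm = AnthropicLlm(model="claude-sonnet-4-20250514")

# Fields left unset here never enter generation_kwargs, so the Anthropic
# SDK falls back to its own defaults for them (top_k is omitted below).
llm_request = LlmRequest(
    model="claude-sonnet-4-20250514",
    contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
    config=types.GenerateContentConfig(
        temperature=0.2,
        top_p=0.9,
        stop_sequences=["DONE"],
    ),
)


async def main():
  # Each setting above is forwarded verbatim to messages.create.
  async for response in llm.generate_content_async(llm_request, stream=False):
    print(response)


asyncio.run(main())
```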
105 changes: 105 additions & 0 deletions tests/unittests/models/test_anthropic_llm.py
@@ -1350,3 +1350,108 @@ async def test_non_streaming_does_not_pass_stream_param():
  mock_client.messages.create.assert_called_once()
  _, kwargs = mock_client.messages.create.call_args
  assert "stream" not in kwargs


@pytest.mark.asyncio
async def test_non_streaming_forwards_generation_config_kwargs():
  llm = AnthropicLlm(model="claude-sonnet-4-20250514")

  mock_message = anthropic_types.Message(
      id="msg_test",
      content=[
          anthropic_types.TextBlock(text="Hello!", type="text", citations=None)
      ],
      model="claude-sonnet-4-20250514",
      role="assistant",
      stop_reason="end_turn",
      stop_sequence=None,
      type="message",
      usage=anthropic_types.Usage(
          input_tokens=5,
          output_tokens=2,
          cache_creation_input_tokens=0,
          cache_read_input_tokens=0,
          server_tool_use=None,
          service_tier=None,
      ),
  )

  mock_client = MagicMock()
  mock_client.messages.create = AsyncMock(return_value=mock_message)

  llm_request = LlmRequest(
      model="claude-sonnet-4-20250514",
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
      config=types.GenerateContentConfig(
          system_instruction="Test",
          temperature=0.0,
          top_p=0.8,
          top_k=12,
          stop_sequences=["DONE"],
      ),
  )

  with mock.patch.object(llm, "_anthropic_client", mock_client):
    _ = [r async for r in llm.generate_content_async(llm_request, stream=False)]

  _, kwargs = mock_client.messages.create.call_args
  assert kwargs["temperature"] == 0.0
  assert kwargs["top_p"] == 0.8
  assert kwargs["top_k"] == 12
  assert kwargs["stop_sequences"] == ["DONE"]


@pytest.mark.asyncio
async def test_streaming_forwards_generation_config_kwargs():
  llm = AnthropicLlm(model="claude-sonnet-4-20250514")

  events = [
      MagicMock(
          type="message_start",
          message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
      ),
      MagicMock(
          type="content_block_start",
          index=0,
          content_block=anthropic_types.TextBlock(text="", type="text"),
      ),
      MagicMock(
          type="content_block_delta",
          index=0,
          delta=anthropic_types.TextDelta(text="Hi", type="text_delta"),
      ),
      MagicMock(type="content_block_stop", index=0),
      MagicMock(
          type="message_delta",
          delta=MagicMock(stop_reason="end_turn"),
          usage=MagicMock(output_tokens=1),
      ),
      MagicMock(type="message_stop"),
  ]

  mock_client = MagicMock()
  mock_client.messages.create = AsyncMock(
      return_value=_make_mock_stream_events(events)
  )

  llm_request = LlmRequest(
      model="claude-sonnet-4-20250514",
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
      config=types.GenerateContentConfig(
          system_instruction="Test",
          temperature=0.2,
          top_p=0.7,
          top_k=8,
          stop_sequences=["STOP_HERE"],
      ),
  )

  with mock.patch.object(llm, "_anthropic_client", mock_client):
    _ = [r async for r in llm.generate_content_async(llm_request, stream=True)]

  _, kwargs = mock_client.messages.create.call_args
  assert kwargs["stream"] is True
  assert kwargs["temperature"] == 0.2
  assert kwargs["top_p"] == 0.7
  assert kwargs["top_k"] == 8
  assert kwargs["stop_sequences"] == ["STOP_HERE"]