diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
index a14c767f23..d35b4db128 100644
--- a/src/google/adk/models/anthropic_llm.py
+++ b/src/google/adk/models/anthropic_llm.py
@@ -377,6 +377,22 @@ def _resolve_model_name(self, model: Optional[str]) -> str:
       return match.group(1)
     return model
 
+  def _get_generation_kwargs(
+      self, llm_request: LlmRequest
+  ) -> dict[str, Any]:
+    generation_kwargs: dict[str, Any] = {}
+
+    if llm_request.config.temperature is not None:
+      generation_kwargs["temperature"] = llm_request.config.temperature
+    if llm_request.config.top_p is not None:
+      generation_kwargs["top_p"] = llm_request.config.top_p
+    if llm_request.config.top_k is not None:
+      generation_kwargs["top_k"] = llm_request.config.top_k
+    if llm_request.config.stop_sequences:
+      generation_kwargs["stop_sequences"] = llm_request.config.stop_sequences
+
+    return generation_kwargs
+
   @override
   async def generate_content_async(
       self, llm_request: LlmRequest, stream: bool = False
@@ -401,6 +417,7 @@ async def generate_content_async(
         if llm_request.tools_dict
         else NOT_GIVEN
     )
+    generation_kwargs = self._get_generation_kwargs(llm_request)
 
     if not stream:
       message = await self._anthropic_client.messages.create(
@@ -410,6 +427,7 @@ async def generate_content_async(
           tools=tools,
           tool_choice=tool_choice,
           max_tokens=self.max_tokens,
+          **generation_kwargs,
       )
       yield message_to_generate_content_response(message)
     else:
@@ -439,6 +457,7 @@ async def _generate_content_streaming(
         tool_choice=tool_choice,
         max_tokens=self.max_tokens,
         stream=True,
+        **self._get_generation_kwargs(llm_request),
     )
 
     # Track content blocks being built during streaming.
diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py
index fb44d5c8e7..a7c4b8472a 100644
--- a/tests/unittests/models/test_anthropic_llm.py
+++ b/tests/unittests/models/test_anthropic_llm.py
@@ -1350,3 +1350,108 @@ async def test_non_streaming_does_not_pass_stream_param():
   mock_client.messages.create.assert_called_once()
   _, kwargs = mock_client.messages.create.call_args
   assert "stream" not in kwargs
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_forwards_generation_config_kwargs():
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  mock_message = anthropic_types.Message(
+      id="msg_test",
+      content=[
+          anthropic_types.TextBlock(text="Hello!", type="text", citations=None)
+      ],
+      model="claude-sonnet-4-20250514",
+      role="assistant",
+      stop_reason="end_turn",
+      stop_sequence=None,
+      type="message",
+      usage=anthropic_types.Usage(
+          input_tokens=5,
+          output_tokens=2,
+          cache_creation_input_tokens=0,
+          cache_read_input_tokens=0,
+          server_tool_use=None,
+          service_tier=None,
+      ),
+  )
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(return_value=mock_message)
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(
+          system_instruction="Test",
+          temperature=0.0,
+          top_p=0.8,
+          top_k=12,
+          stop_sequences=["DONE"],
+      ),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    _ = [r async for r in llm.generate_content_async(llm_request, stream=False)]
+
+  _, kwargs = mock_client.messages.create.call_args
+  assert kwargs["temperature"] == 0.0
+  assert kwargs["top_p"] == 0.8
+  assert kwargs["top_k"] == 12
+  assert kwargs["stop_sequences"] == ["DONE"]
+
+
+@pytest.mark.asyncio
+async def test_streaming_forwards_generation_config_kwargs():
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  events = [
+      MagicMock(
+          type="message_start",
+          message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
+      ),
+      MagicMock(
+          type="content_block_start",
+          index=0,
+          content_block=anthropic_types.TextBlock(text="", type="text"),
+      ),
+      MagicMock(
+          type="content_block_delta",
+          index=0,
+          delta=anthropic_types.TextDelta(text="Hi", type="text_delta"),
+      ),
+      MagicMock(type="content_block_stop", index=0),
+      MagicMock(
+          type="message_delta",
+          delta=MagicMock(stop_reason="end_turn"),
+          usage=MagicMock(output_tokens=1),
+      ),
+      MagicMock(type="message_stop"),
+  ]
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(
+      return_value=_make_mock_stream_events(events)
+  )
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(
+          system_instruction="Test",
+          temperature=0.2,
+          top_p=0.7,
+          top_k=8,
+          stop_sequences=["STOP_HERE"],
+      ),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    _ = [r async for r in llm.generate_content_async(llm_request, stream=True)]
+
+  _, kwargs = mock_client.messages.create.call_args
+  assert kwargs["stream"] is True
+  assert kwargs["temperature"] == 0.2
+  assert kwargs["top_p"] == 0.7
+  assert kwargs["top_k"] == 8
+  assert kwargs["stop_sequences"] == ["STOP_HERE"]
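
A note on the shape of `_get_generation_kwargs`: the scalar parameters are gated on `is not None` rather than truthiness, because `temperature=0.0` is a deliberate setting rather than an absent one, and the non-streaming test pins exactly that case. Below is a self-contained sketch of the same filtering rule; `build_generation_kwargs` and the `SimpleNamespace` config are illustrative stand-ins, not part of the patch or of ADK.

```python
# Illustration only: mirrors the filtering rule the patch applies in
# _get_generation_kwargs, with stand-in names (build_generation_kwargs,
# SimpleNamespace config) in place of the real ADK types.
from types import SimpleNamespace
from typing import Any


def build_generation_kwargs(config: Any) -> dict[str, Any]:
  """Forward only the sampling parameters the caller actually set."""
  kwargs: dict[str, Any] = {}
  # `is not None`, not truthiness: temperature=0.0 and top_p=0.0 are real
  # settings that a bare `if config.temperature:` would silently drop.
  if config.temperature is not None:
    kwargs["temperature"] = config.temperature
  if config.top_p is not None:
    kwargs["top_p"] = config.top_p
  if config.top_k is not None:
    kwargs["top_k"] = config.top_k
  # Truthiness is fine here: None and [] both mean "no stop sequences".
  if config.stop_sequences:
    kwargs["stop_sequences"] = config.stop_sequences
  return kwargs


# temperature=0.0 survives the round trip; unset fields are omitted entirely,
# so the API's own defaults apply rather than explicit None values being sent.
config = SimpleNamespace(temperature=0.0, top_p=None, top_k=None, stop_sequences=[])
assert build_generation_kwargs(config) == {"temperature": 0.0}
```

Omitting unset keys entirely, rather than passing `None`, also matches how the surrounding code already handles optional parameters such as `tools` and `tool_choice` via the Anthropic SDK's `NOT_GIVEN` sentinel.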