diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 3eb2fe5b3f..0509680090 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -187,7 +187,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: """ text = '' - tool_call_parts = [] async with Aclosing(self._gemini_session.receive()) as agen: # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate # partial content and emit responses as needed. @@ -327,14 +326,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: if text: yield self.__build_full_text_response(text) text = '' - if tool_call_parts: - logger.debug('Returning aggregated tool_call_parts') - yield LlmResponse( - content=types.Content(role='model', parts=tool_call_parts), - model_version=self._model_version, - live_session_id=live_session_id, - ) - tool_call_parts = [] yield LlmResponse( turn_complete=True, interrupted=message.server_content.interrupted, @@ -362,10 +353,21 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: if text: yield self.__build_full_text_response(text) text = '' - tool_call_parts.extend([ + # Yield tool calls immediately. gemini-3.1-flash-live-preview does + # not send turn_complete until AFTER it receives the tool response, + # so buffering tool calls until turn_complete deadlocks run_live() + # on that model. Earlier versions of this method (<= ADK 1.27) + # yielded immediately as well; the accumulation pattern introduced + # in 1.28 broke 3.1 Live compatibility. + parts = [ types.Part(function_call=function_call) for function_call in message.tool_call.function_calls - ]) + ] + yield LlmResponse( + content=types.Content(role='model', parts=parts), + model_version=self._model_version, + live_session_id=live_session_id, + ) if message.session_resumption_update: logger.debug('Received session resumption message: %s', message) yield ( @@ -383,14 +385,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: live_session_id=live_session_id, ) - if tool_call_parts: - logger.debug('Exited loop with pending tool_call_parts') - yield LlmResponse( - content=types.Content(role='model', parts=tool_call_parts), - model_version=self._model_version, - live_session_id=self._gemini_session.session_id, - ) - async def close(self): """Closes the llm server connection.""" diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index d56ecde591..045ffc20f3 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -1023,35 +1023,40 @@ async def mock_receive_generator(): assert len(responses) == 3 - # First response: the audio content and grounding metadata - assert responses[0].grounding_metadata == grounding_metadata - assert responses[0].content == mock_content + # First response: the tool call — yielded immediately on arrival + # (no longer buffered until turn_complete, which would deadlock on 3.1) assert responses[0].content is not None assert responses[0].content.parts is not None - assert responses[0].content.parts[0].inline_data == audio_blob - - # Second response: the tool call, buffered until turn_complete - assert responses[1].content is not None - assert responses[1].content.parts is not None - assert responses[1].content.parts[0].function_call is not None + assert responses[0].content.parts[0].function_call is not None assert ( - responses[1].content.parts[0].function_call.name + responses[0].content.parts[0].function_call.name == 'enterprise_web_search' ) - assert responses[1].content.parts[0].function_call.args == { + assert responses[0].content.parts[0].function_call.args == { 'query': 'Google stock price today' } - assert responses[1].grounding_metadata is None + + # Second response: the audio content and grounding metadata + assert responses[1].grounding_metadata == grounding_metadata + assert responses[1].content == mock_content + assert responses[1].content is not None + assert responses[1].content.parts is not None + assert responses[1].content.parts[0].inline_data == audio_blob # Third response: the turn_complete assert responses[2].turn_complete is True @pytest.mark.asyncio -async def test_receive_multiple_tool_calls_buffered_until_turn_complete( +async def test_receive_multiple_tool_call_messages_yielded_immediately( gemini_connection, mock_gemini_session ): - """Test receive buffers multiple tool call messages until turn complete.""" + """Test receive yields each tool_call message immediately (no buffering). + + Tool calls MUST be yielded the moment they arrive, not accumulated until + turn_complete. gemini-3.1-flash-live-preview does not send turn_complete + until AFTER it receives the tool response — buffering causes a deadlock. + """ # First tool call message mock_tool_call_msg1 = mock.create_autospec( types.LiveServerMessage, instance=True @@ -1120,20 +1125,71 @@ async def mock_receive_generator(): responses = [resp async for resp in gemini_connection.receive()] - # Expected: One LlmResponse with both tool calls, then one with turn_complete - assert len(responses) == 2 + # Expected: one response per tool_call message, then one for turn_complete. + assert len(responses) == 3 - # First response: single LlmResponse carrying both function calls + # First response: tool_1 yielded immediately (not waiting for turn_complete) + assert responses[0].content is not None + assert len(responses[0].content.parts) == 1 + assert responses[0].content.parts[0].function_call.name == 'tool_1' + assert responses[0].content.parts[0].function_call.args == {'arg': 'value1'} + + # Second response: tool_2 yielded immediately + assert responses[1].content is not None + assert len(responses[1].content.parts) == 1 + assert responses[1].content.parts[0].function_call.name == 'tool_2' + assert responses[1].content.parts[0].function_call.args == {'arg': 'value2'} + + # Third response: turn_complete True + assert responses[2].turn_complete is True + + +@pytest.mark.asyncio +async def test_receive_tool_call_yielded_without_turn_complete( + gemini_connection, mock_gemini_session +): + """Regression test for the Gemini 3.1 Flash Live deadlock. + + Scenario: model sends a tool_call message but NOT turn_complete (as + gemini-3.1-flash-live-preview does — it waits for the tool response + before sending turn_complete). receive() must yield the tool_call so + the flow layer can execute the tool and send the response back. + + Before the fix: receive() buffered the tool_call internally and only + yielded on turn_complete, causing run_live() to hang indefinitely on + 3.1 models (tool never dispatched -> response never sent -> server + never completes the turn -> WebSocket eventually times out). + """ + function_call = types.FunctionCall(name='get_weather', args={'city': 'Paris'}) + mock_tool_call = mock.create_autospec( + types.LiveServerToolCall, instance=True + ) + mock_tool_call.function_calls = [function_call] + + mock_msg = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_msg.usage_metadata = None + mock_msg.server_content = None + mock_msg.tool_call = mock_tool_call + mock_msg.session_resumption_update = None + mock_msg.go_away = None + + async def mock_receive_generator(): + yield mock_msg + # NOTE: deliberately no turn_complete — mimics Gemini 3.1 Live behavior. + # The generator simply exhausts after the tool_call message. + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + # Must yield the tool_call even without turn_complete. + assert len(responses) == 1 assert responses[0].content is not None - parts = responses[0].content.parts - assert len(parts) == 2 - assert parts[0].function_call.name == 'tool_1' - assert parts[0].function_call.args == {'arg': 'value1'} - assert parts[1].function_call.name == 'tool_2' - assert parts[1].function_call.args == {'arg': 'value2'} - - # Second response: turn_complete True - assert responses[1].turn_complete is True + assert len(responses[0].content.parts) == 1 + assert responses[0].content.parts[0].function_call.name == 'get_weather' + assert responses[0].content.parts[0].function_call.args == {'city': 'Paris'} + assert responses[0].model_version == MODEL_VERSION @pytest.mark.asyncio