Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""

text = ''
tool_call_parts = []
async with Aclosing(self._gemini_session.receive()) as agen:
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
# partial content and emit responses as needed.
Expand Down Expand Up @@ -327,14 +326,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
if text:
yield self.__build_full_text_response(text)
text = ''
if tool_call_parts:
logger.debug('Returning aggregated tool_call_parts')
yield LlmResponse(
content=types.Content(role='model', parts=tool_call_parts),
model_version=self._model_version,
live_session_id=live_session_id,
)
tool_call_parts = []
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
Expand Down Expand Up @@ -362,10 +353,21 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
if text:
yield self.__build_full_text_response(text)
text = ''
tool_call_parts.extend([
# Yield tool calls immediately. gemini-3.1-flash-live-preview does
# not send turn_complete until AFTER it receives the tool response,
# so buffering tool calls until turn_complete deadlocks run_live()
# on that model. Earlier versions of this method (<= ADK 1.27)
# yielded immediately as well; the accumulation pattern introduced
# in 1.28 broke 3.1 Live compatibility.
parts = [
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
])
]
yield LlmResponse(
content=types.Content(role='model', parts=parts),
model_version=self._model_version,
live_session_id=live_session_id,
)
if message.session_resumption_update:
logger.debug('Received session resumption message: %s', message)
yield (
Expand All @@ -383,14 +385,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
live_session_id=live_session_id,
)

if tool_call_parts:
logger.debug('Exited loop with pending tool_call_parts')
yield LlmResponse(
content=types.Content(role='model', parts=tool_call_parts),
model_version=self._model_version,
live_session_id=self._gemini_session.session_id,
)

async def close(self):
"""Closes the llm server connection."""

Expand Down
108 changes: 82 additions & 26 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,35 +1023,40 @@ async def mock_receive_generator():

assert len(responses) == 3

# First response: the audio content and grounding metadata
assert responses[0].grounding_metadata == grounding_metadata
assert responses[0].content == mock_content
# First response: the tool call, yielded immediately on arrival rather than
# buffered until turn_complete (buffering deadlocks on
# gemini-3.1-flash-live-preview, which sends turn_complete only after it
# receives the tool response).
assert responses[0].content is not None
assert responses[0].content.parts is not None
assert responses[0].content.parts[0].inline_data == audio_blob

# Second response: the tool call, buffered until turn_complete
assert responses[1].content is not None
assert responses[1].content.parts is not None
assert responses[1].content.parts[0].function_call is not None
assert responses[0].content.parts[0].function_call is not None
assert (
responses[1].content.parts[0].function_call.name
responses[0].content.parts[0].function_call.name
== 'enterprise_web_search'
)
assert responses[1].content.parts[0].function_call.args == {
assert responses[0].content.parts[0].function_call.args == {
'query': 'Google stock price today'
}
assert responses[1].grounding_metadata is None

# Second response: the audio content and grounding metadata
assert responses[1].grounding_metadata == grounding_metadata
assert responses[1].content == mock_content
assert responses[1].content is not None
assert responses[1].content.parts is not None
assert responses[1].content.parts[0].inline_data == audio_blob

# Third response: the turn_complete
assert responses[2].turn_complete is True


@pytest.mark.asyncio
async def test_receive_multiple_tool_calls_buffered_until_turn_complete(
async def test_receive_multiple_tool_call_messages_yielded_immediately(
gemini_connection, mock_gemini_session
):
"""Test receive buffers multiple tool call messages until turn complete."""
"""Test receive yields each tool_call message immediately (no buffering).

Tool calls MUST be yielded the moment they arrive, not accumulated until
turn_complete. gemini-3.1-flash-live-preview does not send turn_complete
until AFTER it receives the tool response — buffering causes a deadlock.
"""
# First tool call message
mock_tool_call_msg1 = mock.create_autospec(
types.LiveServerMessage, instance=True
Expand Down Expand Up @@ -1120,20 +1125,71 @@ async def mock_receive_generator():

responses = [resp async for resp in gemini_connection.receive()]

# Expected: One LlmResponse with both tool calls, then one with turn_complete
assert len(responses) == 2
# Expected: one response per tool_call message, then one for turn_complete.
assert len(responses) == 3

# First response: single LlmResponse carrying both function calls
# First response: tool_1 yielded immediately (not waiting for turn_complete)
assert responses[0].content is not None
assert len(responses[0].content.parts) == 1
assert responses[0].content.parts[0].function_call.name == 'tool_1'
assert responses[0].content.parts[0].function_call.args == {'arg': 'value1'}

# Second response: tool_2 yielded immediately
assert responses[1].content is not None
assert len(responses[1].content.parts) == 1
assert responses[1].content.parts[0].function_call.name == 'tool_2'
assert responses[1].content.parts[0].function_call.args == {'arg': 'value2'}

# Third response: turn_complete True
assert responses[2].turn_complete is True


@pytest.mark.asyncio
async def test_receive_tool_call_yielded_without_turn_complete(
gemini_connection, mock_gemini_session
):
"""Regression test for the Gemini 3.1 Flash Live deadlock.

Scenario: model sends a tool_call message but NOT turn_complete (as
gemini-3.1-flash-live-preview does — it waits for the tool response
before sending turn_complete). receive() must yield the tool_call so
the flow layer can execute the tool and send the response back.

Before the fix: receive() buffered the tool_call internally and only
yielded on turn_complete, causing run_live() to hang indefinitely on
3.1 models (tool never dispatched -> response never sent -> server
never completes the turn -> WebSocket eventually times out).
"""
function_call = types.FunctionCall(name='get_weather', args={'city': 'Paris'})
mock_tool_call = mock.create_autospec(
types.LiveServerToolCall, instance=True
)
mock_tool_call.function_calls = [function_call]

mock_msg = mock.create_autospec(types.LiveServerMessage, instance=True)
mock_msg.usage_metadata = None
mock_msg.server_content = None
mock_msg.tool_call = mock_tool_call
mock_msg.session_resumption_update = None
mock_msg.go_away = None

async def mock_receive_generator():
yield mock_msg
# NOTE: deliberately no turn_complete — mimics Gemini 3.1 Live behavior.
# The generator simply exhausts after the tool_call message.

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

# Must yield the tool_call even without turn_complete.
assert len(responses) == 1
assert responses[0].content is not None
parts = responses[0].content.parts
assert len(parts) == 2
assert parts[0].function_call.name == 'tool_1'
assert parts[0].function_call.args == {'arg': 'value1'}
assert parts[1].function_call.name == 'tool_2'
assert parts[1].function_call.args == {'arg': 'value2'}

# Second response: turn_complete True
assert responses[1].turn_complete is True
assert len(responses[0].content.parts) == 1
assert responses[0].content.parts[0].function_call.name == 'get_weather'
assert responses[0].content.parts[0].function_call.args == {'city': 'Paris'}
assert responses[0].model_version == MODEL_VERSION


@pytest.mark.asyncio
Expand Down
Loading