From 5a56f91c0b1c5bc999766ff06b9a2ef30b8b54d3 Mon Sep 17 00:00:00 2001 From: Nishar Date: Sat, 18 Apr 2026 19:12:44 -0500 Subject: [PATCH] fix: make temp: state keys accessible in session.state during invocation Fixes #3047 temp: prefixed state keys (e.g. set via output_key='temp:...') were being dropped from session.state before lifecycle callbacks could read them. Root cause: BaseSessionService.__update_session_state() explicitly skipped keys starting with 'temp:', preventing them from ever reaching session.state. Changes: - base_session_service.py: Remove the temp: skip in __update_session_state() so all keys flow into the in-memory session.state. - in_memory_session_service.py: Strip temp: keys from event.actions.state_delta after updating the caller's session but before writing to the storage session, preventing temp keys from leaking into persisted state. - test_session_service.py: Update existing tests to reflect that temp: keys are now accessible in-memory, and add a dedicated regression test. Temp keys remain non-persistent: - DatabaseSessionService uses _extract_state_delta() which already filters temp: keys before writing to the database. - InMemorySessionService now strips temp: keys before updating shared state stores and the storage session. - Runner creates a deepcopy of session at invocation start, preventing temp keys from leaking across invocations. --- .../adk/sessions/base_session_service.py | 5 +- .../adk/sessions/in_memory_session_service.py | 8 +++ .../sessions/test_session_service.py | 51 +++++++++++++++++-- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index 25e46ba199..3786af15bb 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -103,7 +103,4 @@ def __update_session_state(self, session: Session, event: Event) -> None: """Updates the session state based on the event.""" if not event.actions or not event.actions.state_delta: return - for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue - session.state.update({key: value}) + session.state.update(event.actions.state_delta) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index b2a84effc3..08fc92fce4 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -259,6 +259,14 @@ async def append_event(self, session: Session, event: Event) -> Event: await super().append_event(session=session, event=event) session.last_update_time = event.timestamp + # Strip temp: keys before persisting to storage. + if event.actions and event.actions.state_delta: + event.actions.state_delta = { + k: v + for k, v in event.actions.state_delta.items() + if not k.startswith(State.TEMP_PREFIX) + } + # Update the storage session app_name = session.app_name user_id = session.user_id diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index ec93caafbb..e7337bdf19 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -156,11 +156,11 @@ async def test_session_state(service_type): ) await session_service.append_event(session=session_11, event=event) - # User and app state is stored, temp state is filtered. + # User and app state is stored, temp state is accessible in-memory. assert session_11.state.get('app:key') == 'value' assert session_11.state.get('key11') == 'value11_new' assert session_11.state.get('user:key1') == 'value1' - assert not session_11.state.get('temp:key') + assert session_11.state.get('temp:key') == 'temp' session_12 = await session_service.get_session( app_name=app_name, user_id=user_id_1, session_id=session_id_12 @@ -218,11 +218,11 @@ async def test_create_new_session_will_merge_states(service_type): ) await session_service.append_event(session=session_1, event=event) - # User and app state is stored, temp state is filtered. + # User and app state is stored, temp state is accessible in-memory. assert session_1.state.get('app:key') == 'value' assert session_1.state.get('key1') == 'value1' assert session_1.state.get('user:key1') == 'value1' - assert not session_1.state.get('temp:key') + assert session_1.state.get('temp:key') == 'temp' session_2 = await session_service.create_session( app_name=app_name, user_id=user_id, state={}, session_id=session_id_2 @@ -377,3 +377,46 @@ async def test_get_session_with_config(service_type): ) events = session.events assert len(events) == num_test_events - after_timestamp + 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] +) +async def test_temp_state_accessible_in_session_during_invocation(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + user_id = 'test_user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + + event = Event( + invocation_id='invocation_1', + author='test_agent', + content=types.Content( + role='model', parts=[types.Part(text='Hello from agent')] + ), + actions=EventActions( + state_delta={ + 'temp:agent_output': 'Hello from agent', + 'temp:oauth_token': 'bearer_abc123', + 'persistent_key': 'should_persist', + } + ), + ) + await session_service.append_event(session=session, event=event) + + # temp: keys are accessible in-memory during the same invocation. + assert session.state.get('temp:agent_output') == 'Hello from agent' + assert session.state.get('temp:oauth_token') == 'bearer_abc123' + assert session.state.get('persistent_key') == 'should_persist' + + # temp: keys are not persisted to storage. + refetched = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert not refetched.state.get('temp:agent_output') + assert not refetched.state.get('temp:oauth_token') + assert refetched.state.get('persistent_key') == 'should_persist'