diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index 25e46ba199..3786af15bb 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -103,7 +103,4 @@ def __update_session_state(self, session: Session, event: Event) -> None: """Updates the session state based on the event.""" if not event.actions or not event.actions.state_delta: return - for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue - session.state.update({key: value}) + session.state.update(event.actions.state_delta) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index b2a84effc3..08fc92fce4 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -259,6 +259,14 @@ async def append_event(self, session: Session, event: Event) -> Event: await super().append_event(session=session, event=event) session.last_update_time = event.timestamp + # Strip temp: keys before persisting to storage. + if event.actions and event.actions.state_delta: + event.actions.state_delta = { + k: v + for k, v in event.actions.state_delta.items() + if not k.startswith(State.TEMP_PREFIX) + } + # Update the storage session app_name = session.app_name user_id = session.user_id diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index ec93caafbb..e7337bdf19 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -156,11 +156,11 @@ async def test_session_state(service_type): ) await session_service.append_event(session=session_11, event=event) - # User and app state is stored, temp state is filtered. + # User and app state is stored, temp state is accessible in-memory. assert session_11.state.get('app:key') == 'value' assert session_11.state.get('key11') == 'value11_new' assert session_11.state.get('user:key1') == 'value1' - assert not session_11.state.get('temp:key') + assert session_11.state.get('temp:key') == 'temp' session_12 = await session_service.get_session( app_name=app_name, user_id=user_id_1, session_id=session_id_12 @@ -218,11 +218,11 @@ async def test_create_new_session_will_merge_states(service_type): ) await session_service.append_event(session=session_1, event=event) - # User and app state is stored, temp state is filtered. + # User and app state is stored, temp state is accessible in-memory. assert session_1.state.get('app:key') == 'value' assert session_1.state.get('key1') == 'value1' assert session_1.state.get('user:key1') == 'value1' - assert not session_1.state.get('temp:key') + assert session_1.state.get('temp:key') == 'temp' session_2 = await session_service.create_session( app_name=app_name, user_id=user_id, state={}, session_id=session_id_2 @@ -377,3 +377,46 @@ async def test_get_session_with_config(service_type): ) events = session.events assert len(events) == num_test_events - after_timestamp + 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] +) +async def test_temp_state_accessible_in_session_during_invocation(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + user_id = 'test_user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + + event = Event( + invocation_id='invocation_1', + author='test_agent', + content=types.Content( + role='model', parts=[types.Part(text='Hello from agent')] + ), + actions=EventActions( + state_delta={ + 'temp:agent_output': 'Hello from agent', + 'temp:oauth_token': 'bearer_abc123', + 'persistent_key': 'should_persist', + } + ), + ) + await session_service.append_event(session=session, event=event) + + # temp: keys are accessible in-memory during the same invocation. + assert session.state.get('temp:agent_output') == 'Hello from agent' + assert session.state.get('temp:oauth_token') == 'bearer_abc123' + assert session.state.get('persistent_key') == 'should_persist' + + # temp: keys are not persisted to storage. + refetched = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert not refetched.state.get('temp:agent_output') + assert not refetched.state.get('temp:oauth_token') + assert refetched.state.get('persistent_key') == 'should_persist'