diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index a9b55f526e..32b2cfa9d2 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -328,28 +328,14 @@ async def _prepare_session( run_request: AgentRunRequest, runner: Runner, ): - - session_id = run_request.session_id - # create a new session if not exists - user_id = run_request.user_id - session = await runner.session_service.get_session( - app_name=runner.app_name, - user_id=user_id, - session_id=session_id, + session = await runner._get_or_create_session( + user_id=run_request.user_id, + session_id=run_request.session_id, ) - if session is None: - session = await runner.session_service.create_session( - app_name=runner.app_name, - user_id=user_id, - state={}, - session_id=session_id, - ) - # Update run_request with the new session_id - run_request.session_id = session.id - + run_request.session_id = session.id return session - def _check_new_version_extension(self, context: RequestContext): + def _check_new_version_extension(self, context: RequestContext) -> bool: """Check if the extension for the new version is requested and activate it.""" if _NEW_A2A_ADK_INTEGRATION_EXTENSION in context.requested_extensions: context.add_activated_extension(_NEW_A2A_ADK_INTEGRATION_EXTENSION) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py index 320af124df..228f03c606 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py @@ -18,6 +18,7 @@ from datetime import timezone import inspect import logging +from typing import Any from typing import Awaitable from typing import Callable from typing import Optional @@ -281,29 +282,22 @@ async def _resolve_session( run_request: AgentRunRequest, runner: Runner, ): - session_id = run_request.session_id - # create a new session if not exists - user_id = run_request.user_id - session = await runner.session_service.get_session( - app_name=runner.app_name, - user_id=user_id, - session_id=session_id, - # Checking existence doesn't require event history. - config=base_session_service.GetSessionConfig(num_recent_events=0), + if not run_request.user_id: + raise ValueError('user_id must be set in AgentRunRequest') + if not run_request.session_id: + raise ValueError('session_id must be set in AgentRunRequest') + session = await runner._get_or_create_session( + user_id=run_request.user_id, + session_id=run_request.session_id, + get_session_config=base_session_service.GetSessionConfig( + num_recent_events=0 + ), ) - if session is None: - session = await runner.session_service.create_session( - app_name=runner.app_name, - user_id=user_id, - state={}, - session_id=session_id, - ) - # Update run_request with the new session_id - run_request.session_id = session.id + run_request.session_id = session.id def _get_invocation_metadata( self, executor_context: ExecutorContext - ) -> dict[str, str]: + ) -> dict[str, Any]: return { _get_adk_metadata_key('app_name'): executor_context.app_name, _get_adk_metadata_key('user_id'): executor_context.user_id, diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 3e8ed461e2..d104afaf9b 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -14,6 +14,7 @@ from __future__ import annotations +from contextlib import AbstractAsyncContextManager from contextlib import asynccontextmanager import logging from typing import AsyncIterator @@ -85,7 +86,9 @@ def to_a2a( agent_card: Optional[Union[AgentCard, str]] = None, push_config_store: Optional[PushNotificationConfigStore] = None, runner: Optional[Runner] = None, - lifespan: Optional[Callable[[Starlette], AsyncIterator[None]]] = None, + lifespan: Optional[ + Callable[[Starlette], AbstractAsyncContextManager[None]] + ] = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -142,6 +145,7 @@ async def create_runner() -> Runner: session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), credential_service=InMemoryCredentialService(), + auto_create_session=True, ) # Create A2A components @@ -170,7 +174,7 @@ async def create_runner() -> Runner: ) # Build the agent card and configure A2A routes - async def setup_a2a(app: Starlette): + async def setup_a2a(app: Starlette) -> None: # Use provided agent card or build one asynchronously if provided_agent_card is not None: final_agent_card = provided_agent_card diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 0906e9a6ba..c95fa06d2e 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -41,7 +41,7 @@ # Strong references to fire-and-forget tasks to prevent garbage collection. # See https://docs.python.org/3/library/asyncio-task.html#creating-tasks -_background_tasks: set[asyncio.Task] = set() +_background_tasks: set[asyncio.Task[None]] = set() _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS = frozenset({ 'disable_consolidation', @@ -565,7 +565,7 @@ def _get_api_client(self) -> vertexai.AsyncClient: return vertexai.Client(project=self._project, location=self._location).aio -def _log_ingest_task_error(task: asyncio.Task) -> None: +def _log_ingest_task_error(task: asyncio.Task[None]) -> None: """Logs errors from fire-and-forget ingest_events tasks.""" if task.cancelled(): return diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 4f44e1363c..df0af8b944 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -224,7 +224,7 @@ async def mock_run_async(**kwargs): @pytest.mark.asyncio async def test_prepare_session_new_session(self): - """Test session preparation when session doesn't exist.""" + """Test session preparation delegates to runner._get_or_create_session.""" run_args = AgentRunRequest( user_id="test-user", session_id=None, @@ -232,11 +232,9 @@ async def test_prepare_session_new_session(self): run_config=Mock(spec=RunConfig), ) - # Mock session service - self.mock_runner.session_service.get_session = AsyncMock(return_value=None) mock_session = Mock() mock_session.id = "new-session-id" - self.mock_runner.session_service.create_session = AsyncMock( + self.mock_runner._get_or_create_session = AsyncMock( return_value=mock_session ) @@ -245,10 +243,13 @@ async def test_prepare_session_new_session(self): self.mock_context, run_args, self.mock_runner ) - # Verify session was created + # Verify session was returned and run_request updated assert result == mock_session - assert run_args.session_id is not None - self.mock_runner.session_service.create_session.assert_called_once() + assert run_args.session_id == "new-session-id" + self.mock_runner._get_or_create_session.assert_called_once_with( + user_id="test-user", + session_id=None, + ) @pytest.mark.asyncio async def test_prepare_session_existing_session(self): @@ -260,10 +261,9 @@ async def test_prepare_session_existing_session(self): run_config=Mock(spec=RunConfig), ) - # Mock session service mock_session = Mock() mock_session.id = "existing-session" - self.mock_runner.session_service.get_session = AsyncMock( + self.mock_runner._get_or_create_session = AsyncMock( return_value=mock_session ) @@ -274,7 +274,10 @@ async def test_prepare_session_existing_session(self): # Verify existing session was returned assert result == mock_session - self.mock_runner.session_service.create_session.assert_not_called() + self.mock_runner._get_or_create_session.assert_called_once_with( + user_id="test-user", + session_id="existing-session", + ) def test_constructor_with_callable_runner(self): """Test constructor with callable runner.""" diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py index 940b79a0b9..11a7b709c2 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py @@ -97,10 +97,10 @@ async def test_execute_success_new_task(self): new_message=Mock(spec=Content), run_config=Mock(spec=RunConfig), ) - # Mock session service + # Mock _get_or_create_session mock_session = Mock() mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( + self.mock_runner._get_or_create_session = AsyncMock( return_value=mock_session ) @@ -200,10 +200,10 @@ async def test_execute_existing_task(self): run_config=Mock(spec=RunConfig), ) - # Mock session service + # Mock _get_or_create_session mock_session = Mock() mock_session.id = "test-session" - self.mock_runner.session_service.get_session = AsyncMock( + self.mock_runner._get_or_create_session = AsyncMock( return_value=mock_session ) @@ -616,7 +616,7 @@ async def test_execute_missing_user_input(self, mock_handle_user_input): ) mock_handle_user_input.return_value = missing_event - self.mock_runner.session_service.get_session = AsyncMock( + self.mock_runner._get_or_create_session = AsyncMock( return_value=Mock(id="test-session") ) self.mock_request_converter.return_value = AgentRunRequest( @@ -638,12 +638,10 @@ async def test_execute_missing_user_input(self, mock_handle_user_input): @pytest.mark.asyncio async def test_resolve_session_creates_new_session(self): - """Test that _resolve_session creates a new session if it doesn't exist.""" - self.mock_runner.session_service.get_session = AsyncMock(return_value=None) - + """Test that _resolve_session delegates to runner._get_or_create_session.""" new_session = Mock() new_session.id = "new-session-id" - self.mock_runner.session_service.create_session = AsyncMock( + self.mock_runner._get_or_create_session = AsyncMock( return_value=new_session ) @@ -656,17 +654,12 @@ async def test_resolve_session_creates_new_session(self): await self.executor._resolve_session(run_request, self.mock_runner) - self.mock_runner.session_service.get_session.assert_called_once_with( - app_name=self.mock_runner.app_name, + self.mock_runner._get_or_create_session.assert_called_once_with( user_id="test-user", session_id="old-session-id", - config=GetSessionConfig(num_recent_events=0, after_timestamp=None), - ) - self.mock_runner.session_service.create_session.assert_called_once_with( - app_name=self.mock_runner.app_name, - user_id="test-user", - state={}, - session_id="old-session-id", + get_session_config=GetSessionConfig( + num_recent_events=0, after_timestamp=None + ), ) assert run_request.session_id == "new-session-id" diff --git a/tests/unittests/a2a/integration/server.py b/tests/unittests/a2a/integration/server.py index c965a71091..aa0c9d6314 100644 --- a/tests/unittests/a2a/integration/server.py +++ b/tests/unittests/a2a/integration/server.py @@ -44,6 +44,7 @@ def __init__(self, run_async_fn): app_name="FakeApp", agent=agent, session_service=session_service, + auto_create_session=True, ) self.run_async_fn = run_async_fn diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index a9e2458ebd..8dd69b7552 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -327,6 +327,7 @@ async def test_create_runner_function_creates_runner_correctly( session_service=mock_runner_class.call_args[1]["session_service"], memory_service=mock_runner_class.call_args[1]["memory_service"], credential_service=mock_runner_class.call_args[1]["credential_service"], + auto_create_session=True, ) # Verify the services are of the correct types @@ -391,6 +392,7 @@ async def test_create_runner_function_with_agent_without_name( session_service=mock_runner_class.call_args[1]["session_service"], memory_service=mock_runner_class.call_args[1]["memory_service"], credential_service=mock_runner_class.call_args[1]["credential_service"], + auto_create_session=True, ) @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")