Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions src/google/adk/a2a/executor/a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,28 +328,14 @@ async def _prepare_session(
run_request: AgentRunRequest,
runner: Runner,
):

session_id = run_request.session_id
# create a new session if not exists
user_id = run_request.user_id
session = await runner.session_service.get_session(
app_name=runner.app_name,
user_id=user_id,
session_id=session_id,
session = await runner._get_or_create_session(
user_id=run_request.user_id,
session_id=run_request.session_id,
)
if session is None:
session = await runner.session_service.create_session(
app_name=runner.app_name,
user_id=user_id,
state={},
session_id=session_id,
)
# Update run_request with the new session_id
run_request.session_id = session.id

run_request.session_id = session.id
return session

def _check_new_version_extension(self, context: RequestContext):
def _check_new_version_extension(self, context: RequestContext) -> bool:
"""Check if the extension for the new version is requested and activate it."""
if _NEW_A2A_ADK_INTEGRATION_EXTENSION in context.requested_extensions:
context.add_activated_extension(_NEW_A2A_ADK_INTEGRATION_EXTENSION)
Expand Down
32 changes: 13 additions & 19 deletions src/google/adk/a2a/executor/a2a_agent_executor_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datetime import timezone
import inspect
import logging
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Optional
Expand Down Expand Up @@ -281,29 +282,22 @@ async def _resolve_session(
run_request: AgentRunRequest,
runner: Runner,
):
session_id = run_request.session_id
# create a new session if not exists
user_id = run_request.user_id
session = await runner.session_service.get_session(
app_name=runner.app_name,
user_id=user_id,
session_id=session_id,
# Checking existence doesn't require event history.
config=base_session_service.GetSessionConfig(num_recent_events=0),
if not run_request.user_id:
raise ValueError('user_id must be set in AgentRunRequest')
if not run_request.session_id:
raise ValueError('session_id must be set in AgentRunRequest')
session = await runner._get_or_create_session(
user_id=run_request.user_id,
session_id=run_request.session_id,
get_session_config=base_session_service.GetSessionConfig(
num_recent_events=0
),
)
if session is None:
session = await runner.session_service.create_session(
app_name=runner.app_name,
user_id=user_id,
state={},
session_id=session_id,
)
# Update run_request with the new session_id
run_request.session_id = session.id
run_request.session_id = session.id

def _get_invocation_metadata(
self, executor_context: ExecutorContext
) -> dict[str, str]:
) -> dict[str, Any]:
return {
_get_adk_metadata_key('app_name'): executor_context.app_name,
_get_adk_metadata_key('user_id'): executor_context.user_id,
Expand Down
8 changes: 6 additions & 2 deletions src/google/adk/a2a/utils/agent_to_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from contextlib import AbstractAsyncContextManager
from contextlib import asynccontextmanager
import logging
from typing import AsyncIterator
Expand Down Expand Up @@ -85,7 +86,9 @@ def to_a2a(
agent_card: Optional[Union[AgentCard, str]] = None,
push_config_store: Optional[PushNotificationConfigStore] = None,
runner: Optional[Runner] = None,
lifespan: Optional[Callable[[Starlette], AsyncIterator[None]]] = None,
lifespan: Optional[
Callable[[Starlette], AbstractAsyncContextManager[None]]
] = None,
) -> Starlette:
"""Convert an ADK agent to a A2A Starlette application.

Expand Down Expand Up @@ -142,6 +145,7 @@ async def create_runner() -> Runner:
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
credential_service=InMemoryCredentialService(),
auto_create_session=True,
)

# Create A2A components
Expand Down Expand Up @@ -170,7 +174,7 @@ async def create_runner() -> Runner:
)

# Build the agent card and configure A2A routes
async def setup_a2a(app: Starlette):
async def setup_a2a(app: Starlette) -> None:
# Use provided agent card or build one asynchronously
if provided_agent_card is not None:
final_agent_card = provided_agent_card
Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/memory/vertex_ai_memory_bank_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

# Strong references to fire-and-forget tasks to prevent garbage collection.
# See https://docs.python.org/3/library/asyncio-task.html#creating-tasks
_background_tasks: set[asyncio.Task] = set()
_background_tasks: set[asyncio.Task[None]] = set()

_GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS = frozenset({
'disable_consolidation',
Expand Down Expand Up @@ -565,7 +565,7 @@ def _get_api_client(self) -> vertexai.AsyncClient:
return vertexai.Client(project=self._project, location=self._location).aio


def _log_ingest_task_error(task: asyncio.Task) -> None:
def _log_ingest_task_error(task: asyncio.Task[None]) -> None:
"""Logs errors from fire-and-forget ingest_events tasks."""
if task.cancelled():
return
Expand Down
23 changes: 13 additions & 10 deletions tests/unittests/a2a/executor/test_a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,17 @@ async def mock_run_async(**kwargs):

@pytest.mark.asyncio
async def test_prepare_session_new_session(self):
"""Test session preparation when session doesn't exist."""
"""Test session preparation delegates to runner._get_or_create_session."""
run_args = AgentRunRequest(
user_id="test-user",
session_id=None,
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)

# Mock session service
self.mock_runner.session_service.get_session = AsyncMock(return_value=None)
mock_session = Mock()
mock_session.id = "new-session-id"
self.mock_runner.session_service.create_session = AsyncMock(
self.mock_runner._get_or_create_session = AsyncMock(
return_value=mock_session
)

Expand All @@ -245,10 +243,13 @@ async def test_prepare_session_new_session(self):
self.mock_context, run_args, self.mock_runner
)

# Verify session was created
# Verify session was returned and run_request updated
assert result == mock_session
assert run_args.session_id is not None
self.mock_runner.session_service.create_session.assert_called_once()
assert run_args.session_id == "new-session-id"
self.mock_runner._get_or_create_session.assert_called_once_with(
user_id="test-user",
session_id=None,
)

@pytest.mark.asyncio
async def test_prepare_session_existing_session(self):
Expand All @@ -260,10 +261,9 @@ async def test_prepare_session_existing_session(self):
run_config=Mock(spec=RunConfig),
)

# Mock session service
mock_session = Mock()
mock_session.id = "existing-session"
self.mock_runner.session_service.get_session = AsyncMock(
self.mock_runner._get_or_create_session = AsyncMock(
return_value=mock_session
)

Expand All @@ -274,7 +274,10 @@ async def test_prepare_session_existing_session(self):

# Verify existing session was returned
assert result == mock_session
self.mock_runner.session_service.create_session.assert_not_called()
self.mock_runner._get_or_create_session.assert_called_once_with(
user_id="test-user",
session_id="existing-session",
)

def test_constructor_with_callable_runner(self):
"""Test constructor with callable runner."""
Expand Down
29 changes: 11 additions & 18 deletions tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ async def test_execute_success_new_task(self):
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
# Mock session service
# Mock _get_or_create_session
mock_session = Mock()
mock_session.id = "test-session"
self.mock_runner.session_service.get_session = AsyncMock(
self.mock_runner._get_or_create_session = AsyncMock(
return_value=mock_session
)

Expand Down Expand Up @@ -200,10 +200,10 @@ async def test_execute_existing_task(self):
run_config=Mock(spec=RunConfig),
)

# Mock session service
# Mock _get_or_create_session
mock_session = Mock()
mock_session.id = "test-session"
self.mock_runner.session_service.get_session = AsyncMock(
self.mock_runner._get_or_create_session = AsyncMock(
return_value=mock_session
)

Expand Down Expand Up @@ -616,7 +616,7 @@ async def test_execute_missing_user_input(self, mock_handle_user_input):
)
mock_handle_user_input.return_value = missing_event

self.mock_runner.session_service.get_session = AsyncMock(
self.mock_runner._get_or_create_session = AsyncMock(
return_value=Mock(id="test-session")
)
self.mock_request_converter.return_value = AgentRunRequest(
Expand All @@ -638,12 +638,10 @@ async def test_execute_missing_user_input(self, mock_handle_user_input):

@pytest.mark.asyncio
async def test_resolve_session_creates_new_session(self):
"""Test that _resolve_session creates a new session if it doesn't exist."""
self.mock_runner.session_service.get_session = AsyncMock(return_value=None)

"""Test that _resolve_session delegates to runner._get_or_create_session."""
new_session = Mock()
new_session.id = "new-session-id"
self.mock_runner.session_service.create_session = AsyncMock(
self.mock_runner._get_or_create_session = AsyncMock(
return_value=new_session
)

Expand All @@ -656,17 +654,12 @@ async def test_resolve_session_creates_new_session(self):

await self.executor._resolve_session(run_request, self.mock_runner)

self.mock_runner.session_service.get_session.assert_called_once_with(
app_name=self.mock_runner.app_name,
self.mock_runner._get_or_create_session.assert_called_once_with(
user_id="test-user",
session_id="old-session-id",
config=GetSessionConfig(num_recent_events=0, after_timestamp=None),
)
self.mock_runner.session_service.create_session.assert_called_once_with(
app_name=self.mock_runner.app_name,
user_id="test-user",
state={},
session_id="old-session-id",
get_session_config=GetSessionConfig(
num_recent_events=0, after_timestamp=None
),
)
assert run_request.session_id == "new-session-id"

Expand Down
1 change: 1 addition & 0 deletions tests/unittests/a2a/integration/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, run_async_fn):
app_name="FakeApp",
agent=agent,
session_service=session_service,
auto_create_session=True,
)
self.run_async_fn = run_async_fn

Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/a2a/utils/test_agent_to_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ async def test_create_runner_function_creates_runner_correctly(
session_service=mock_runner_class.call_args[1]["session_service"],
memory_service=mock_runner_class.call_args[1]["memory_service"],
credential_service=mock_runner_class.call_args[1]["credential_service"],
auto_create_session=True,
)

# Verify the services are of the correct types
Expand Down Expand Up @@ -391,6 +392,7 @@ async def test_create_runner_function_with_agent_without_name(
session_service=mock_runner_class.call_args[1]["session_service"],
memory_service=mock_runner_class.call_args[1]["memory_service"],
credential_service=mock_runner_class.call_args[1]["credential_service"],
auto_create_session=True,
)

@patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")
Expand Down