diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 76dd2ddab4..917477db01 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -31,11 +31,6 @@ from .auth_tool import AuthConfig from .auth_tool import AuthToolArguments -# Prefix used by toolset auth credential IDs. -# Auth requests with this prefix are for toolset authentication (before tool -# listing) and don't require resuming a function call. -TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' - async def _store_auth_and_collect_resume_targets( events: list[Event], @@ -50,7 +45,7 @@ async def _store_auth_and_collect_resume_targets( ``AuthToolArguments`` args, merges ``credential_key`` into the corresponding auth response, stores credentials via ``AuthHandler``, and returns the set of original function call IDs that should be - re-executed (excluding toolset auth). + re-executed. Args: events: Session events to scan. @@ -96,8 +91,7 @@ async def _store_auth_and_collect_resume_targets( state=state ) - # Step 3: Collect original function call IDs to resume, skipping - # toolset auth entries which don't map to a resumable function call. + # Step 3: Collect original function call IDs to resume. tools_to_resume: set[str] = set() for fc_id in auth_fc_ids: requested_auth_config = requested_auth_config_by_id.get(fc_id) @@ -115,10 +109,6 @@ async def _store_auth_and_collect_resume_targets( and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME ): args = AuthToolArguments.model_validate(function_call.args) - if args.function_call_id.startswith( - TOOLSET_AUTH_CREDENTIAL_ID_PREFIX - ): - continue tools_to_resume.add(args.function_call_id) return tools_to_resume diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 7a98315a54..1aca2d7410 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -36,8 +36,6 @@ from ...agents.live_request_queue import LiveRequestQueue from ...agents.readonly_context import ReadonlyContext from ...agents.run_config import StreamingMode -from ...auth.auth_handler import AuthHandler -from ...auth.auth_tool import AuthConfig from ...auth.credential_manager import CredentialManager from ...events.event import Event from ...models.base_llm_connection import BaseLlmConnection @@ -51,10 +49,6 @@ from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing from .audio_cache_manager import AudioCacheManager -from .functions import build_auth_request_event - -# Prefix used by toolset auth credential IDs -TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' if TYPE_CHECKING: from ...agents.llm_agent import LlmAgent @@ -115,24 +109,24 @@ def _finalize_model_response_event( async def _resolve_toolset_auth( invocation_context: InvocationContext, agent: LlmAgent, -) -> AsyncGenerator[Event, None]: +) -> None: """Resolves authentication for toolsets before tool listing. For each toolset with auth configured via get_auth_config(): - If credential is available, populate auth_config.exchanged_auth_credential - - If credential is not available, yield auth request event and interrupt + - If credential is not available, log and continue — auth will be handled + on demand by ToolAuthHandler when a tool is actually invoked. + + This avoids triggering OAuth redirects on every agent invocation, + including messages that don't require any tool calls. Args: invocation_context: The invocation context. agent: The LLM agent. - - Yields: - Auth request events if any toolset needs authentication. """ if not agent.tools: return - pending_auth_requests: dict[str, AuthConfig] = {} callback_context = CallbackContext(invocation_context) for tool_union in agent.tools: @@ -161,30 +155,11 @@ async def _resolve_toolset_auth( # Populate in-place for toolset to use in get_tools() auth_config.exchanged_auth_credential = credential else: - # Need auth - will interrupt - toolset_id = ( - f'{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}{type(tool_union).__name__}' + logger.debug( + 'No credential found for toolset %s; deferring auth to tool' + ' invocation.', + type(tool_union).__name__, ) - pending_auth_requests[toolset_id] = auth_config - - if not pending_auth_requests: - return - - # Build auth requests dict with generated auth requests - auth_requests = { - credential_id: AuthHandler(auth_config).generate_auth_request() - for credential_id, auth_config in pending_auth_requests.items() - } - - # Yield event with auth requests using the shared helper - yield build_auth_request_event( - invocation_context, - auth_requests, - author=agent.name, - ) - - # Interrupt invocation - invocation_context.end_invocation = True async def _handle_before_model_callback( @@ -916,14 +891,7 @@ async def _preprocess_async( # Resolve toolset authentication before tool listing. # This ensures credentials are ready before get_tools() is called. - async with Aclosing( - self._resolve_toolset_auth(invocation_context, agent) - ) as agen: - async for event in agen: - yield event - - if invocation_context.end_invocation: - return + await _resolve_toolset_auth(invocation_context, agent) # Run processors for tools. await _process_agent_tools(invocation_context, llm_request) @@ -1273,17 +1241,6 @@ def _finalize_model_response_event( llm_request, llm_response, model_response_event ) - async def _resolve_toolset_auth( - self, - invocation_context: InvocationContext, - agent: LlmAgent, - ) -> AsyncGenerator[Event, None]: - async with Aclosing( - _resolve_toolset_auth(invocation_context, agent) - ) as agen: - async for event in agen: - yield event - async def _handle_before_model_callback( self, invocation_context: InvocationContext, diff --git a/tests/unittests/auth/test_toolset_auth.py b/tests/unittests/auth/test_toolset_auth.py index b5efc42551..99d05413f9 100644 --- a/tests/unittests/auth/test_toolset_auth.py +++ b/tests/unittests/auth/test_toolset_auth.py @@ -16,7 +16,6 @@ from typing import Optional from unittest.mock import AsyncMock -from unittest.mock import MagicMock from unittest.mock import Mock from unittest.mock import patch @@ -28,12 +27,8 @@ from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import OAuth2Auth -from google.adk.auth.auth_preprocessor import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX from google.adk.auth.auth_tool import AuthConfig -from google.adk.auth.auth_tool import AuthToolArguments from google.adk.flows.llm_flows.base_llm_flow import _resolve_toolset_auth -from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow -from google.adk.flows.llm_flows.base_llm_flow import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX as FLOW_PREFIX from google.adk.flows.llm_flows.functions import build_auth_request_event from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from google.adk.tools.base_tool import BaseTool @@ -85,15 +80,6 @@ def create_oauth2_auth_config() -> AuthConfig: ) -class TestToolsetAuthPrefixConstant: - """Test that prefix constants are consistent.""" - - def test_prefix_constants_match(self): - """Ensure auth_preprocessor and base_llm_flow use the same prefix.""" - assert TOOLSET_AUTH_CREDENTIAL_ID_PREFIX == FLOW_PREFIX - assert TOOLSET_AUTH_CREDENTIAL_ID_PREFIX == "_adk_toolset_auth_" - - class TestResolveToolsetAuth: """Tests for _resolve_toolset_auth method in BaseLlmFlow.""" @@ -121,19 +107,12 @@ def mock_agent(self): return agent @pytest.mark.asyncio - async def test_no_tools_returns_no_events( - self, mock_invocation_context, mock_agent - ): - """Test that no events are yielded when agent has no tools.""" + async def test_no_tools_completes(self, mock_invocation_context, mock_agent): + """Test that resolve completes without side effects when agent has no tools.""" mock_agent.tools = [] - events = [] - async for event in _resolve_toolset_auth( - mock_invocation_context, mock_agent - ): - events.append(event) + await _resolve_toolset_auth(mock_invocation_context, mock_agent) - assert len(events) == 0 assert mock_invocation_context.end_invocation is False @pytest.mark.asyncio @@ -144,13 +123,8 @@ async def test_toolset_without_auth_config_skipped( toolset = MockToolset(auth_config=None) mock_agent.tools = [toolset] - events = [] - async for event in _resolve_toolset_auth( - mock_invocation_context, mock_agent - ): - events.append(event) + await _resolve_toolset_auth(mock_invocation_context, mock_agent) - assert len(events) == 0 assert mock_invocation_context.end_invocation is False @pytest.mark.asyncio @@ -162,7 +136,6 @@ async def test_toolset_with_credential_available_populates_config( toolset = MockToolset(auth_config=auth_config) mock_agent.tools = [toolset] - # Mock CredentialManager to return a credential mock_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, oauth2=OAuth2Auth(access_token="test-token"), @@ -175,23 +148,21 @@ async def test_toolset_with_credential_available_populates_config( mock_manager.get_auth_credential = AsyncMock(return_value=mock_credential) MockCredentialManager.return_value = mock_manager - events = [] - async for event in _resolve_toolset_auth( - mock_invocation_context, mock_agent - ): - events.append(event) + await _resolve_toolset_auth(mock_invocation_context, mock_agent) - # No auth request events - credential was available - assert len(events) == 0 assert mock_invocation_context.end_invocation is False - # Credential should be populated in auth_config assert auth_config.exchanged_auth_credential == mock_credential @pytest.mark.asyncio - async def test_toolset_without_credential_yields_auth_event( + async def test_toolset_without_credential_defers_auth( self, mock_invocation_context, mock_agent ): - """Test that auth request event is yielded when credential not available.""" + """Test that auth is deferred when credential is not available. + + When no credential is found, _resolve_toolset_auth should not interrupt + the invocation. Auth will be handled on demand by ToolAuthHandler when + a tool is actually invoked. + """ auth_config = create_oauth2_auth_config() toolset = MockToolset(auth_config=auth_config) mock_agent.tools = [toolset] @@ -203,37 +174,16 @@ async def test_toolset_without_credential_yields_auth_event( mock_manager.get_auth_credential = AsyncMock(return_value=None) MockCredentialManager.return_value = mock_manager - events = [] - async for event in _resolve_toolset_auth( - mock_invocation_context, mock_agent - ): - events.append(event) - - # Should yield one auth request event - assert len(events) == 1 - assert mock_invocation_context.end_invocation is True + await _resolve_toolset_auth(mock_invocation_context, mock_agent) - # Check event structure - event = events[0] - assert event.invocation_id == "test-invocation-id" - assert event.author == "test-agent" - assert event.content is not None - assert len(event.content.parts) == 1 - - # Check function call - fc = event.content.parts[0].function_call - assert fc.name == REQUEST_EUC_FUNCTION_CALL_NAME - # The args use camelCase aliases from the pydantic model - assert fc.args["functionCallId"].startswith( - TOOLSET_AUTH_CREDENTIAL_ID_PREFIX - ) - assert "MockToolset" in fc.args["functionCallId"] + assert mock_invocation_context.end_invocation is False + assert auth_config.exchanged_auth_credential is None @pytest.mark.asyncio - async def test_multiple_toolsets_needing_auth( + async def test_multiple_toolsets_without_credentials_defers_auth( self, mock_invocation_context, mock_agent ): - """Test that multiple toolsets needing auth yield multiple function calls.""" + """Test that multiple toolsets without credentials do not interrupt.""" auth_config1 = create_oauth2_auth_config() auth_config2 = create_oauth2_auth_config() toolset1 = MockToolset(auth_config=auth_config1) @@ -247,40 +197,51 @@ async def test_multiple_toolsets_needing_auth( mock_manager.get_auth_credential = AsyncMock(return_value=None) MockCredentialManager.return_value = mock_manager - events = [] - async for event in _resolve_toolset_auth( - mock_invocation_context, mock_agent - ): - events.append(event) + await _resolve_toolset_auth(mock_invocation_context, mock_agent) + + assert mock_invocation_context.end_invocation is False - # Should yield one event with multiple function calls - # But since both toolsets have same class name, they'll have same ID - # and only one will be in pending_auth_requests (dict overwrites) - assert len(events) == 1 - assert mock_invocation_context.end_invocation is True + @pytest.mark.asyncio + async def test_mixed_toolsets_populates_available_credentials( + self, mock_invocation_context, mock_agent + ): + """Test that credentials are populated when available, without interrupt. + When one toolset has credentials and another does not, the available + credential should be populated while the missing one is deferred. + """ + auth_config_with_cred = create_oauth2_auth_config() + auth_config_without_cred = create_oauth2_auth_config() + toolset_with_cred = MockToolset(auth_config=auth_config_with_cred) + toolset_without_cred = MockToolset(auth_config=auth_config_without_cred) + mock_agent.tools = [toolset_with_cred, toolset_without_cred] -class TestAuthPreprocessorToolsetAuthSkip: - """Tests for auth preprocessor skipping toolset auth.""" + mock_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(access_token="test-token"), + ) - def test_toolset_auth_prefix_skipped(self): - """Test that function calls with toolset auth prefix are skipped.""" - from google.adk.auth.auth_preprocessor import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX + call_count = 0 - # Verify the prefix is correct - assert TOOLSET_AUTH_CREDENTIAL_ID_PREFIX == "_adk_toolset_auth_" + async def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_credential + return None - # Test that a function_call_id starting with this prefix would be skipped - toolset_function_call_id = f"{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}McpToolset" - assert toolset_function_call_id.startswith( - TOOLSET_AUTH_CREDENTIAL_ID_PREFIX - ) + with patch( + "google.adk.flows.llm_flows.base_llm_flow.CredentialManager" + ) as MockCredentialManager: + mock_manager = AsyncMock() + mock_manager.get_auth_credential = AsyncMock(side_effect=side_effect) + MockCredentialManager.return_value = mock_manager - # Regular tool auth function_call_id should NOT start with prefix - regular_function_call_id = "call_123" - assert not regular_function_call_id.startswith( - TOOLSET_AUTH_CREDENTIAL_ID_PREFIX - ) + await _resolve_toolset_auth(mock_invocation_context, mock_agent) + + assert mock_invocation_context.end_invocation is False + assert auth_config_with_cred.exchanged_auth_credential == mock_credential + assert auth_config_without_cred.exchanged_auth_credential is None class TestCallbackContextGetAuthResponse: