diff --git a/.github/workflows/integration-testing.yml b/.github/workflows/integration-testing.yml index a4964567..40d57dc5 100644 --- a/.github/workflows/integration-testing.yml +++ b/.github/workflows/integration-testing.yml @@ -146,7 +146,7 @@ jobs: - group: gateway path: tests_integ/gateway timeout: 15 - extra-deps: "" + extra-deps: "mcp-proxy-for-aws" ignore: "" - group: identity path: tests_integ/identity/test_identity_client.py diff --git a/pyproject.toml b/pyproject.toml index 89ef6c81..474747c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,13 +156,15 @@ dev = [ "strands-agents-evals>=0.1.0", "a2a-sdk[http-server]>=0.3", "ag-ui-protocol>=0.1.10", + "mcp-proxy-for-aws>=0.1.0", ] [project.optional-dependencies] a2a = ["a2a-sdk[http-server]>=0.3"] ag-ui = ["ag-ui-protocol>=0.1.10"] strands-agents = [ - "strands-agents>=1.20.0" + "strands-agents>=1.20.0", + "mcp>=1.23.0,<2.0.0", ] strands-agents-evals = [ "strands-agents-evals>=0.1.0" diff --git a/src/bedrock_agentcore/gateway/integrations/__init__.py b/src/bedrock_agentcore/gateway/integrations/__init__.py new file mode 100644 index 00000000..2da3d743 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/__init__.py @@ -0,0 +1 @@ +"""Gateway integrations.""" diff --git a/src/bedrock_agentcore/gateway/integrations/strands/__init__.py b/src/bedrock_agentcore/gateway/integrations/strands/__init__.py new file mode 100644 index 00000000..d00cd5fc --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/__init__.py @@ -0,0 +1 @@ +"""Strands Agents integrations for AgentCore Gateway.""" diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/__init__.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/__init__.py new file mode 100644 index 00000000..9a6da0f7 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/__init__.py @@ -0,0 +1,5 @@ +"""Gateway Strands plugins.""" + +from .agentcore_tool_search import AgentCoreToolSearchPlugin + +__all__ = ["AgentCoreToolSearchPlugin"] diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/README.md b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/README.md new file mode 100644 index 00000000..5d5ed741 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/README.md @@ -0,0 +1,138 @@ +# Strands AgentCore Tool Search Plugin + +A semantic tool discovery plugin for [Strands Agents](https://github.com/strands-agents/sdk-python) that uses the [Amazon Bedrock AgentCore Gateway](https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/gateway-using-mcp-semantic-search.html) `x_amz_bedrock_agentcore_search` tool. This enables agents to dynamically load only the relevant tools for each invocation by deriving user intent from conversation history, even when hundreds of tools are registered on the gateway. + +## Features + +- **Semantic tool discovery** — uses AgentCore Gateway's built-in search to find relevant tools +- **Intent-based loading** — derives user intent via LLM before searching +- **No list_tools call** — tools are built directly from search results +- **Pluggable intent provider** — swap the default intent provider with your own +- **Agent model reuse** — by default, the intent classifier uses the same model as the parent agent + +## Installation + +```bash +pip install 'bedrock-agentcore[strands-agents]' +``` + +## Usage + +```python +from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client +from strands import Agent +from strands.tools.mcp import MCPClient +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin + +mcp_client = MCPClient(lambda: aws_iam_streamablehttp_client( + endpoint="https://.gateway.bedrock-agentcore..amazonaws.com/mcp", + aws_region="us-east-1", + aws_service="bedrock-agentcore", +)) + +mcp_client.start() + +agent = Agent(plugins=[AgentCoreToolSearchPlugin(mcp_client=mcp_client)]) + +agent("Find me afternoon flights to New York") +``` + +Or using a context manager: + +```python +with mcp_client: + agent = Agent(plugins=[AgentCoreToolSearchPlugin(mcp_client=mcp_client)]) + agent("Find me afternoon flights to New York") +``` + +## How It Works + +![Tool Search Flow](images/agentcore_tool_search_plugin.png) + +On each agent invocation: + +1. **User query** — The user sends a query to Strands agent. +2. **Hook** — The agent triggers the `AgentCoreToolSearchPlugin` before model invocation +3. **Derive intent** — The `IntentProvider` sends the last N messages from conversation history to the configured LLM to produce a concise intent string +4. **Search gateway** — The intent is passed to AgentCore Gateway's `x_amz_bedrock_agentcore_search` tool to obtain most relevant tools. +5. **Invoke LLM** — The agent invokes the LLM with the user query along with the matched tools from registered MCP targets (Lambda, API Gateway, MCP Server) + +Previously loaded tools are cleared before each search, so the agent always has the most relevant tools available. + +## Intent Provider + +An `IntentProvider` is responsible for analyzing conversation messages and producing a concise intent string that drives tool search. The plugin calls `derive_intent(messages, model)` before each invocation to determine what tools to load. + +### StrandsIntentProvider + +`StrandsIntentProvider` uses a Strands Agent to classify the last few conversation messages into a concise intent string. By default it uses the parent agent's model. + +**Basic usage (uses the agent's model automatically):** + +```python +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin + +agent = Agent(plugins=[ + AgentCoreToolSearchPlugin(mcp_client=mcp_client) +]) +``` + +**With a custom model for intent classification:** + +```python +from strands.models.bedrock import BedrockModel +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import StrandsIntentProvider + +intent_model = BedrockModel(model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0") +agent = Agent(plugins=[ + AgentCoreToolSearchPlugin( + mcp_client=mcp_client, + intent_provider=StrandsIntentProvider(model=intent_model), + ) +]) +``` + +**With a custom system prompt:** + +```python +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import StrandsIntentProvider + +agent = Agent(plugins=[ + AgentCoreToolSearchPlugin( + mcp_client=mcp_client, + intent_provider=StrandsIntentProvider( + system_prompt="Classify the user's intent in one sentence. Focus on the action, not details." + ), + ) +]) +``` + +### Custom Intent Provider + +You can provide your own intent derivation strategy by subclassing `IntentProvider`: + +```python +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import IntentProvider + +class MyIntentProvider(IntentProvider): + def derive_intent(self, messages: list[dict], model=None) -> str: + # custom logic to derive intent + return "intent string" + +agent = Agent(plugins=[ + AgentCoreToolSearchPlugin( + mcp_client=mcp_client, + intent_provider=MyIntentProvider(), + ) +]) +``` + +## Prerequisites + +- An AgentCore Gateway with **[semantic search](https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/gateway-using-mcp-semantic-search.html) enabled** +- Tools registered on the gateway with descriptions +- AWS credentials with access to the gateway + +For more details, see the [AgentCore Gateway Documentation](https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/gateway-building.html). diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/__init__.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/__init__.py new file mode 100644 index 00000000..a59f7b27 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/__init__.py @@ -0,0 +1,6 @@ +"""AgentCore Tool Search plugin for Strands Agents.""" + +from .intent_providers import IntentProvider, StrandsIntentProvider +from .plugin import AgentCoreToolSearchPlugin + +__all__ = ["AgentCoreToolSearchPlugin", "IntentProvider", "StrandsIntentProvider"] diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/images/agentcore_tool_search_plugin.png b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/images/agentcore_tool_search_plugin.png new file mode 100644 index 00000000..4e29dc45 Binary files /dev/null and b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/images/agentcore_tool_search_plugin.png differ diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/__init__.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/__init__.py new file mode 100644 index 00000000..d30209ac --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/__init__.py @@ -0,0 +1,6 @@ +"""Intent provider interfaces and implementations.""" + +from .intent_provider import IntentProvider +from .strands_intent_provider import StrandsIntentProvider + +__all__ = ["StrandsIntentProvider", "IntentProvider"] diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/intent_provider.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/intent_provider.py new file mode 100644 index 00000000..864ece94 --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/intent_provider.py @@ -0,0 +1,26 @@ +"""Intent provider abstract interface.""" + +from abc import ABC, abstractmethod + + +class IntentProvider(ABC): + """Abstract interface for deriving user intent from conversation messages. + + Subclasses must implement the `derive_intent` method to analyze conversation + messages and return a concise intent string. + """ + + @abstractmethod + def derive_intent(self, messages: list[dict], model=None) -> str: + """Analyze conversation messages and return a concise intent string. + + Args: + messages: List of conversation message dicts in Strands format. + model: Optional model instance from the parent agent. Implementations + can use this for LLM-based intent derivation. + + Returns: + A plain text string describing the user's intent. + Returns empty string if intent cannot be determined. + """ + ... diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/strands_intent_provider.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/strands_intent_provider.py new file mode 100644 index 00000000..1833257f --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/intent_providers/strands_intent_provider.py @@ -0,0 +1,73 @@ +"""Strands Agent-based intent provider implementation.""" + +import logging + +from strands import Agent + +from .intent_provider import IntentProvider + +logger = logging.getLogger(__name__) + +INTENT_SYSTEM_PROMPT = ( + "You are an intent classifier. Given the recent conversation messages, " + "produce a concise one-sentence description of what the user is trying to accomplish. " + "Focus on the type of task, not the specific details. " + "Reply with ONLY the intent description, nothing else." +) + + +class StrandsIntentProvider(IntentProvider): + """LLM-based intent provider that uses a Strands Agent to classify the last N messages.""" + + def __init__(self, message_window: int = 5, model=None, system_prompt: str = INTENT_SYSTEM_PROMPT): + """Initialize StrandsIntentProvider. + + Args: + message_window: Number of recent messages to consider. + model: Optional explicit model for intent classification. + system_prompt: System prompt for the intent classifier. Defaults to INTENT_SYSTEM_PROMPT. + """ + self._message_window = message_window + self._explicit_model = model + self._system_prompt = system_prompt + + def derive_intent(self, messages: list[dict], model=None) -> str: + """Derive intent using an LLM. Falls back to agent's model if no explicit model set.""" + try: + recent_messages = messages[-self._message_window :] if messages else [] + if not recent_messages: + return "" + + kwargs = {"system_prompt": self._system_prompt, "tools": []} + # Priority: explicit model > agent's model > Strands default + resolved_model = self._explicit_model or model + if resolved_model: + kwargs["model"] = resolved_model + + intent_agent = Agent(**kwargs) + response = intent_agent(self._format_messages_for_prompt(recent_messages)) + return str(response).strip() + except Exception as e: + logger.error("Failed to derive intent: %s", e) + return "" + + def _format_messages_for_prompt(self, messages: list[dict]) -> str: + """Format user messages into a text prompt for the intent LLM. + + Only includes user-role messages to avoid leaking PII or sensitive data + from tool results or assistant responses. + """ + parts = [] + for msg in messages: + role = msg.get("role", "") + if role != "user": + continue + content = msg.get("content", []) + text = "" + if isinstance(content, list): + text = " ".join( + block.get("text", "") for block in content if isinstance(block, dict) and "text" in block + ) + if text.strip(): + parts.append(text.strip()) + return "\n".join(parts) diff --git a/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/plugin.py b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/plugin.py new file mode 100644 index 00000000..28465a0f --- /dev/null +++ b/src/bedrock_agentcore/gateway/integrations/strands/plugins/agentcore_tool_search/plugin.py @@ -0,0 +1,123 @@ +"""AgentCore tool search plugin for Strands Agents.""" + +import json +import logging + +from mcp.types import Tool as MCPTool +from strands.hooks import BeforeInvocationEvent +from strands.plugins import Plugin, hook +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool + +from .intent_providers import IntentProvider, StrandsIntentProvider + +logger = logging.getLogger(__name__) + + +class AgentCoreToolSearchPlugin(Plugin): + """Plugin that dynamically loads tools from AgentCore Gateway based on semantic intent. + + Args: + mcp_client: MCPClient connected to an AgentCore Gateway. + intent_provider: Strategy for deriving intent. Defaults to StrandsIntentProvider. + """ + + name = "agentcore-tool-search-plugin" + + def __init__( + self, + mcp_client: MCPClient, + intent_provider: IntentProvider | None = None, + ): + """Initialize the plugin. + + Args: + mcp_client: MCPClient connected to an AgentCore Gateway. + intent_provider: Strategy for deriving intent. Defaults to StrandsIntentProvider. + """ + super().__init__() + self._intent_provider = intent_provider or StrandsIntentProvider() + self._mcp_client = mcp_client + self._loaded_tool_names: set[str] = set() + + @property + def tools(self): + """Return empty list; tools are loaded dynamically via the hook.""" + return [] + + @hook + def on_before_invocation(self, event: BeforeInvocationEvent) -> None: + """Derive intent, search gateway, and load matching tools.""" + messages = event.messages or [] + + # Pass the agent's model to the intent provider + intent = self._intent_provider.derive_intent(messages, model=event.agent.model) + logger.info("Derived intent: %s", intent) + + # Clear all previously loaded conditional tools + for name in list(self._loaded_tool_names): + event.agent.tool_registry.registry.pop(name, None) + self._loaded_tool_names.clear() + + if not intent: + return + + try: + result = self._mcp_client.call_tool_sync( + tool_use_id="intent-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": intent}, + ) + agent_tools = self._build_tools_from_search_result(result) + except Exception as e: + logger.error("AgentCore Gateway search failed: %s", e) + return + + for agent_tool in agent_tools: + try: + # Skip if a non-dynamic tool with this name already exists + if ( + agent_tool.tool_name in event.agent.tool_registry.registry + and agent_tool.tool_name not in self._loaded_tool_names + ): + logger.debug("Skipping tool %s: already registered as a static tool", agent_tool.tool_name) + continue + event.agent.tool_registry.register_tool(agent_tool) + self._loaded_tool_names.add(agent_tool.tool_name) + except Exception as e: + logger.error("Failed to register tool %s: %s", agent_tool.tool_name, e) + + logger.info("Loaded tools: %s", self._loaded_tool_names) + + def _build_tools_from_search_result(self, result) -> list[MCPAgentTool]: + """Build MCPAgentTool objects from the gateway search response.""" + tools = [] + if not result or not isinstance(result, dict): + return tools + + tool_defs = [] + structured = result.get("structuredContent") + if isinstance(structured, dict) and "tools" in structured: + tool_defs = structured["tools"] + else: + for block in result.get("content", []): + if isinstance(block, dict) and "text" in block: + try: + data = json.loads(block["text"]) + if isinstance(data, dict) and "tools" in data: + tool_defs = data["tools"] + break + except (json.JSONDecodeError, TypeError): + continue + + for tool_def in tool_defs: + if not isinstance(tool_def, dict) or "name" not in tool_def: + continue + mcp_tool = MCPTool( + name=tool_def["name"], + description=tool_def.get("description", ""), + inputSchema=tool_def.get("inputSchema", {"type": "object", "properties": {}}), + ) + tools.append(MCPAgentTool(mcp_tool=mcp_tool, mcp_client=self._mcp_client)) + + return tools diff --git a/tests/bedrock_agentcore/gateway/__init__.py b/tests/bedrock_agentcore/gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bedrock_agentcore/gateway/integrations/__init__.py b/tests/bedrock_agentcore/gateway/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bedrock_agentcore/gateway/integrations/strands/__init__.py b/tests/bedrock_agentcore/gateway/integrations/strands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bedrock_agentcore/gateway/integrations/strands/test_agentcore_tool_search_plugin.py b/tests/bedrock_agentcore/gateway/integrations/strands/test_agentcore_tool_search_plugin.py new file mode 100644 index 00000000..da34bf82 --- /dev/null +++ b/tests/bedrock_agentcore/gateway/integrations/strands/test_agentcore_tool_search_plugin.py @@ -0,0 +1,240 @@ +"""Tests for AgentCoreToolSearchPlugin.""" + +import json +from unittest.mock import Mock + +import pytest + +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import ( + IntentProvider, + StrandsIntentProvider, +) +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.plugin import ( + AgentCoreToolSearchPlugin, +) + + +class FakeIntentProvider(IntentProvider): + """Test intent provider that returns a fixed intent string.""" + + def __init__(self, intent: str = "test intent"): + self._intent = intent + + def derive_intent(self, messages: list[dict], model=None) -> str: + return self._intent + + +@pytest.fixture +def mock_mcp_client(): + """Create a mock MCPClient.""" + client = Mock() + client.call_tool_sync.return_value = {"content": []} + return client + + +@pytest.fixture +def fixed_intent_provider(): + """Create a fixed intent provider.""" + return FakeIntentProvider("get weather") + + +@pytest.fixture +def plugin(mock_mcp_client, fixed_intent_provider): + """Create an AgentCoreToolSearchPlugin with mocked dependencies.""" + return AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=fixed_intent_provider) + + +@pytest.fixture +def mock_event(): + """Create a mock BeforeInvocationEvent.""" + event = Mock() + event.messages = [{"role": "user", "content": [{"text": "hello"}]}] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + return event + + +class TestAgentCoreToolSearchPluginInit: + """Test AgentCoreToolSearchPlugin initialization.""" + + def test_init_with_custom_intent_provider(self, mock_mcp_client): + """Test initialization with a custom intent provider.""" + provider = FakeIntentProvider("custom") + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + assert plugin._intent_provider is provider + assert plugin._mcp_client is mock_mcp_client + + def test_init_default_intent_provider(self, mock_mcp_client): + """Test initialization uses StrandsIntentProvider when none provided.""" + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client) + assert isinstance(plugin._intent_provider, StrandsIntentProvider) + + def test_plugin_name(self, plugin): + """Test plugin has correct name.""" + assert plugin.name == "agentcore-tool-search-plugin" + + def test_tools_property_returns_empty(self, plugin): + """Test tools property returns empty list.""" + assert plugin.tools == [] + + +class TestOnBeforeInvocation: + """Test on_before_invocation hook behavior.""" + + def test_empty_intent_skips_search(self, mock_mcp_client, mock_event): + """Test that empty intent does not call gateway search.""" + provider = FakeIntentProvider("") + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + + plugin.on_before_invocation(mock_event) + + mock_mcp_client.call_tool_sync.assert_not_called() + + def test_calls_gateway_search_with_intent(self, plugin, mock_mcp_client, mock_event): + """Test that derived intent is passed to gateway search.""" + plugin.on_before_invocation(mock_event) + + mock_mcp_client.call_tool_sync.assert_called_once_with( + tool_use_id="intent-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": "get weather"}, + ) + + def test_passes_agent_model_to_intent_provider(self, mock_mcp_client, mock_event): + """Test that the agent's model is passed to derive_intent.""" + provider = Mock(spec=IntentProvider) + provider.derive_intent.return_value = "" + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + mock_event.agent.model = Mock(name="test-model") + + plugin.on_before_invocation(mock_event) + + provider.derive_intent.assert_called_once_with(mock_event.messages, model=mock_event.agent.model) + + def test_registers_tools_from_structured_content(self, plugin, mock_mcp_client, mock_event): + """Test tools are registered from structuredContent response.""" + mock_mcp_client.call_tool_sync.return_value = { + "structuredContent": { + "tools": [ + { + "name": "weather_tool", + "description": "Get weather", + "inputSchema": {"type": "object", "properties": {"city": {"type": "string"}}}, + } + ] + } + } + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_called_once() + registered_tool = mock_event.agent.tool_registry.register_tool.call_args[0][0] + assert registered_tool.tool_name == "weather_tool" + assert "weather_tool" in plugin._loaded_tool_names + + def test_registers_tools_from_text_content(self, plugin, mock_mcp_client, mock_event): + """Test tools are registered from JSON text content response.""" + tools_json = json.dumps( + { + "tools": [ + { + "name": "calc_tool", + "description": "Calculator", + "inputSchema": {"type": "object", "properties": {}}, + }, + ] + } + ) + mock_mcp_client.call_tool_sync.return_value = {"content": [{"text": tools_json}]} + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_called_once() + registered_tool = mock_event.agent.tool_registry.register_tool.call_args[0][0] + assert registered_tool.tool_name == "calc_tool" + + def test_clears_previously_loaded_tools(self, plugin, mock_mcp_client, mock_event): + """Test previously loaded tools are removed from registry.""" + mock_mcp_client.call_tool_sync.return_value = {"content": []} + plugin._loaded_tool_names = {"old_tool_1", "old_tool_2"} + mock_event.agent.tool_registry.registry = { + "old_tool_1": Mock(), + "old_tool_2": Mock(), + "permanent_tool": Mock(), + } + + plugin.on_before_invocation(mock_event) + + assert "old_tool_1" not in mock_event.agent.tool_registry.registry + assert "old_tool_2" not in mock_event.agent.tool_registry.registry + assert "permanent_tool" in mock_event.agent.tool_registry.registry + assert len(plugin._loaded_tool_names) == 0 + + def test_gateway_search_failure_logs_and_returns(self, plugin, mock_mcp_client, mock_event): + """Test gateway search failure is handled gracefully.""" + mock_mcp_client.call_tool_sync.side_effect = RuntimeError("connection failed") + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_not_called() + + def test_skips_invalid_tool_defs(self, plugin, mock_mcp_client, mock_event): + """Test malformed tool definitions are skipped.""" + mock_mcp_client.call_tool_sync.return_value = { + "structuredContent": { + "tools": [ + {"description": "no name field"}, + "not a dict", + {"name": "valid_tool", "description": "ok", "inputSchema": {"type": "object", "properties": {}}}, + ] + } + } + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_called_once() + registered_tool = mock_event.agent.tool_registry.register_tool.call_args[0][0] + assert registered_tool.tool_name == "valid_tool" + + def test_register_tool_failure_continues(self, plugin, mock_mcp_client, mock_event): + """Test that failure to register one tool doesn't block others.""" + mock_mcp_client.call_tool_sync.return_value = { + "structuredContent": { + "tools": [ + {"name": "tool_a", "description": "A", "inputSchema": {"type": "object", "properties": {}}}, + {"name": "tool_b", "description": "B", "inputSchema": {"type": "object", "properties": {}}}, + ] + } + } + mock_event.agent.tool_registry.register_tool.side_effect = [RuntimeError("fail"), None] + + plugin.on_before_invocation(mock_event) + + assert mock_event.agent.tool_registry.register_tool.call_count == 2 + assert "tool_a" not in plugin._loaded_tool_names + assert "tool_b" in plugin._loaded_tool_names + + def test_none_result_loads_no_tools(self, plugin, mock_mcp_client, mock_event): + """Test None result from gateway loads no tools.""" + mock_mcp_client.call_tool_sync.return_value = None + + plugin.on_before_invocation(mock_event) + + mock_event.agent.tool_registry.register_tool.assert_not_called() + + def test_empty_messages_with_intent(self, mock_mcp_client): + """Test plugin works with empty messages list.""" + provider = FakeIntentProvider("") + plugin = AgentCoreToolSearchPlugin(mcp_client=mock_mcp_client, intent_provider=provider) + event = Mock() + event.messages = [] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + + plugin.on_before_invocation(event) + + mock_mcp_client.call_tool_sync.assert_not_called() diff --git a/tests/bedrock_agentcore/gateway/integrations/strands/test_intent_providers.py b/tests/bedrock_agentcore/gateway/integrations/strands/test_intent_providers.py new file mode 100644 index 00000000..e8a67ba6 --- /dev/null +++ b/tests/bedrock_agentcore/gateway/integrations/strands/test_intent_providers.py @@ -0,0 +1,232 @@ +"""Tests for IntentProvider and StrandsIntentProvider.""" + +from unittest.mock import Mock, patch + +import pytest + +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import ( + IntentProvider, + StrandsIntentProvider, +) + + +class TestIntentProviderInterface: + """Test IntentProvider abstract interface.""" + + def test_cannot_instantiate_abstract_class(self): + """Test that IntentProvider cannot be instantiated directly.""" + with pytest.raises(TypeError): + IntentProvider() + + def test_subclass_must_implement_derive_intent(self): + """Test that subclass without derive_intent raises TypeError.""" + + class IncompleteProvider(IntentProvider): + pass + + with pytest.raises(TypeError): + IncompleteProvider() + + def test_subclass_with_derive_intent_works(self): + """Test that a proper subclass can be instantiated.""" + + class ValidProvider(IntentProvider): + def derive_intent(self, messages: list[dict], model=None) -> str: + return "test" + + provider = ValidProvider() + assert provider.derive_intent([]) == "test" + + +class TestStrandsIntentProvider: + """Test StrandsIntentProvider class.""" + + def test_init_default_message_window(self): + """Test default message window is 5.""" + provider = StrandsIntentProvider() + assert provider._message_window == 5 + + def test_init_custom_message_window(self): + """Test custom message window.""" + provider = StrandsIntentProvider(message_window=3) + assert provider._message_window == 3 + + def test_init_with_explicit_model(self): + """Test initialization with explicit model.""" + model = Mock() + provider = StrandsIntentProvider(model=model) + assert provider._explicit_model is model + + def test_empty_messages_returns_empty_string(self): + """Test empty messages returns empty string without calling LLM.""" + provider = StrandsIntentProvider() + result = provider.derive_intent([]) + assert result == "" + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.strands_intent_provider.Agent" + ) + def test_derive_intent_calls_agent(self, mock_agent_class): + """Test derive_intent creates an Agent and calls it.""" + mock_agent = Mock() + mock_agent.return_value = "user wants weather info" + mock_agent_class.return_value = mock_agent + + provider = StrandsIntentProvider(message_window=2) + messages = [ + {"role": "user", "content": [{"text": "What is the weather?"}]}, + ] + + result = provider.derive_intent(messages) + + assert result == "user wants weather info" + mock_agent_class.assert_called_once() + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.strands_intent_provider.Agent" + ) + def test_derive_intent_uses_explicit_model(self, mock_agent_class): + """Test derive_intent uses explicit model over agent model.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + explicit_model = Mock(name="explicit-model") + provider = StrandsIntentProvider(model=explicit_model) + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + provider.derive_intent(messages, model=Mock(name="agent-model")) + + # Explicit model takes priority + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs["model"] is explicit_model + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.strands_intent_provider.Agent" + ) + def test_derive_intent_uses_agent_model_when_no_explicit(self, mock_agent_class): + """Test derive_intent falls back to agent model when no explicit model.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + agent_model = Mock(name="agent-model") + provider = StrandsIntentProvider() + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + provider.derive_intent(messages, model=agent_model) + + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs["model"] is agent_model + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.strands_intent_provider.Agent" + ) + def test_derive_intent_no_model_kwarg_when_none(self, mock_agent_class): + """Test derive_intent omits model kwarg when no model available.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + provider = StrandsIntentProvider() + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + provider.derive_intent(messages, model=None) + + call_kwargs = mock_agent_class.call_args[1] + assert "model" not in call_kwargs + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.strands_intent_provider.Agent" + ) + def test_derive_intent_respects_message_window(self, mock_agent_class): + """Test only last N messages are used.""" + mock_agent = Mock() + mock_agent.return_value = "intent" + mock_agent_class.return_value = mock_agent + + provider = StrandsIntentProvider(message_window=3) + messages = [ + {"role": "user", "content": [{"text": "first"}]}, + {"role": "user", "content": [{"text": "second"}]}, + {"role": "user", "content": [{"text": "third"}]}, + {"role": "user", "content": [{"text": "fourth"}]}, + {"role": "user", "content": [{"text": "fifth"}]}, + ] + + provider.derive_intent(messages) + + # Window=3 takes last 3 messages; only user messages are formatted + call_args = mock_agent.call_args[0][0] + assert "first" not in call_args + assert "second" not in call_args + assert "third" in call_args + assert "fourth" in call_args + assert "fifth" in call_args + + @patch( + "bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers.strands_intent_provider.Agent" + ) + def test_derive_intent_handles_exception(self, mock_agent_class): + """Test derive_intent returns empty string on exception.""" + mock_agent_class.side_effect = RuntimeError("LLM unavailable") + + provider = StrandsIntentProvider() + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + result = provider.derive_intent(messages) + + assert result == "" + + def test_format_messages_for_prompt(self): + """Test message formatting only includes user messages.""" + provider = StrandsIntentProvider() + messages = [ + {"role": "user", "content": [{"text": "Hello"}, {"text": "world"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + {"role": "user", "content": [{"text": "What is the weather?"}]}, + ] + + result = provider._format_messages_for_prompt(messages) + + assert "Hello world" in result + assert "What is the weather?" in result + assert "Hi there" not in result + + def test_format_messages_handles_missing_role(self): + """Test formatting skips messages without user role.""" + provider = StrandsIntentProvider() + messages = [{"content": [{"text": "no role"}]}] + + result = provider._format_messages_for_prompt(messages) + + assert result == "" + + def test_format_messages_handles_non_text_blocks(self): + """Test formatting skips non-text content blocks.""" + provider = StrandsIntentProvider() + messages = [ + {"role": "user", "content": [{"image": "data"}, {"text": "only this"}]}, + ] + + result = provider._format_messages_for_prompt(messages) + + assert "only this" in result + assert "data" not in result + + def test_format_messages_excludes_tool_results(self): + """Test formatting excludes assistant tool results to avoid PII leakage.""" + provider = StrandsIntentProvider() + messages = [ + {"role": "user", "content": [{"text": "Check my account"}]}, + {"role": "assistant", "content": [{"toolUse": {"name": "get_account", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"content": [{"text": "SSN: 123-45-6789"}]}}]}, + {"role": "user", "content": [{"text": "Now send an email"}]}, + ] + + result = provider._format_messages_for_prompt(messages) + + assert "Check my account" in result + assert "Now send an email" in result + assert "SSN" not in result + assert "get_account" not in result diff --git a/tests_integ/gateway/integrations/lambda_function/lambda_function.py b/tests_integ/gateway/integrations/lambda_function/lambda_function.py new file mode 100644 index 00000000..2001f9b7 --- /dev/null +++ b/tests_integ/gateway/integrations/lambda_function/lambda_function.py @@ -0,0 +1,117 @@ +"""MCP-compatible Lambda handler for AgentCore Gateway integration tests. + +This Lambda implements the MCP JSON-RPC protocol over HTTP, responding to: +- initialize: Returns server capabilities +- tools/list: Returns available tool definitions +- tools/call: Executes a tool and returns results + +Deploy with Python 3.10+ runtime, handler: lambda_function.lambda_handler +""" + +import json +import logging + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +TOOLS = [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + }, + "required": ["city"], + }, + }, + { + "name": "send_email", + "description": "Send an email to a recipient", + "inputSchema": { + "type": "object", + "properties": { + "to": {"type": "string", "description": "Recipient email"}, + "subject": {"type": "string", "description": "Email subject"}, + "body": {"type": "string", "description": "Email body"}, + }, + "required": ["to", "subject", "body"], + }, + }, +] + + +def lambda_handler(event, context): + """Handle MCP JSON-RPC requests from AgentCore Gateway.""" + logger.info("Received event: %s", json.dumps(event)) + + body = event.get("body", "{}") + if isinstance(body, str): + body = json.loads(body) + + method = body.get("method", "") + request_id = body.get("id") + params = body.get("params", {}) + + if method == "initialize": + result = { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {"listChanged": False}, + }, + "serverInfo": { + "name": "integ-test-mcp-server", + "version": "1.0.0", + }, + } + elif method == "notifications/initialized": + # Client acknowledgment, no response needed + return {"statusCode": 200, "body": ""} + elif method == "tools/list": + result = {"tools": TOOLS} + elif method == "tools/call": + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + result = _handle_tool_call(tool_name, arguments) + else: + return { + "statusCode": 200, + "body": json.dumps( + { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32601, + "message": f"Method not found: {method}", + }, + } + ), + } + + response = { + "jsonrpc": "2.0", + "id": request_id, + "result": result, + } + + return { + "statusCode": 200, + "body": json.dumps(response), + } + + +def _handle_tool_call(tool_name, arguments): + """Execute a tool and return MCP-formatted result.""" + if tool_name == "get_weather": + city = arguments.get("city", "unknown") + return {"content": [{"type": "text", "text": f"Weather in {city}: 72°F, sunny with light clouds."}]} + elif tool_name == "send_email": + to = arguments.get("to", "") + subject = arguments.get("subject", "") + return {"content": [{"type": "text", "text": f"Email sent to {to} with subject: {subject}"}]} + else: + return { + "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], + "isError": True, + } diff --git a/tests_integ/gateway/integrations/test_agentcore_tool_search_plugin.py b/tests_integ/gateway/integrations/test_agentcore_tool_search_plugin.py new file mode 100644 index 00000000..a10a0c54 --- /dev/null +++ b/tests_integ/gateway/integrations/test_agentcore_tool_search_plugin.py @@ -0,0 +1,489 @@ +"""Integration tests for AgentCoreToolSearchPlugin. + +If GATEWAY_ROLE_ARN and GATEWAY_LAMBDA_ARN are set, uses those directly. +Otherwise, automatically provisions the IAM role and Lambda function, +and tears them down after the test run. + +Environment variables (all optional): + BEDROCK_TEST_REGION: AWS region (default: us-west-2) + GATEWAY_ROLE_ARN: IAM role ARN with AgentCore gateway trust policy + GATEWAY_LAMBDA_ARN: Lambda ARN for the gateway target +""" + +import io +import json +import logging +import os +import time +import zipfile + +import boto3 +import pytest + +from bedrock_agentcore.gateway.client import GatewayClient +from bedrock_agentcore.gateway.integrations.strands.plugins import AgentCoreToolSearchPlugin +from bedrock_agentcore.gateway.integrations.strands.plugins.agentcore_tool_search.intent_providers import ( + IntentProvider, + StrandsIntentProvider, +) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# Infrastructure constants +_ROLE_NAME = "integ-test-gateway-role" +_LAMBDA_NAME = "integ-test-lambda" +_LAMBDA_HANDLER = "lambda_function.lambda_handler" +_LAMBDA_RUNTIME = "python3.10" + +_TRUST_POLICY = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "bedrock-agentcore.amazonaws.com"}, + "Action": "sts:AssumeRole", + }, + { + "Effect": "Allow", + "Principal": {"Service": "lambda.amazonaws.com"}, + "Action": "sts:AssumeRole", + }, + ], +} + +_LAMBDA_INVOKE_POLICY_TEMPLATE = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "lambda:InvokeFunction", + "Resource": None, # filled in after Lambda creation + } + ], +} + + +def _get_lambda_zip() -> bytes: + """Package lambda_function.py into a zip archive.""" + lambda_path = os.path.join(os.path.dirname(__file__), "lambda_function", "lambda_function.py") + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: + zf.write(lambda_path, "lambda_function.py") + return buf.getvalue() + + +def _ensure_role(iam_client) -> str: + """Create the gateway IAM role if it doesn't exist, return its ARN.""" + try: + response = iam_client.get_role(RoleName=_ROLE_NAME) + return response["Role"]["Arn"] + except iam_client.exceptions.NoSuchEntityException: + pass + + response = iam_client.create_role( + RoleName=_ROLE_NAME, + AssumeRolePolicyDocument=json.dumps(_TRUST_POLICY), + Description="Integration test role for AgentCore gateway tests", + ) + role_arn = response["Role"]["Arn"] + + iam_client.attach_role_policy( + RoleName=_ROLE_NAME, + PolicyArn="arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole", + ) + # Wait for IAM propagation + time.sleep(10) + logger.info("Created role: %s", role_arn) + return role_arn + + +def _attach_lambda_invoke_policy(iam_client, lambda_arn: str): + """Attach a scoped lambda:InvokeFunction policy to the gateway role.""" + policy = _LAMBDA_INVOKE_POLICY_TEMPLATE.copy() + policy["Statement"] = [{"Effect": "Allow", "Action": "lambda:InvokeFunction", "Resource": lambda_arn}] + iam_client.put_role_policy( + RoleName=_ROLE_NAME, + PolicyName="lambda-invoke", + PolicyDocument=json.dumps(policy), + ) + + +def _ensure_lambda(lambda_client, role_arn: str) -> str: + """Create or update the test Lambda, return its ARN.""" + zip_bytes = _get_lambda_zip() + try: + response = lambda_client.get_function(FunctionName=_LAMBDA_NAME) + lambda_client.update_function_code(FunctionName=_LAMBDA_NAME, ZipFile=zip_bytes) + return response["Configuration"]["FunctionArn"] + except lambda_client.exceptions.ResourceNotFoundException: + pass + + response = lambda_client.create_function( + FunctionName=_LAMBDA_NAME, + Runtime=_LAMBDA_RUNTIME, + Role=role_arn, + Handler=_LAMBDA_HANDLER, + Code={"ZipFile": zip_bytes}, + Timeout=30, + Description="MCP test Lambda for AgentCore gateway integration tests", + ) + waiter = lambda_client.get_waiter("function_active_v2") + waiter.wait(FunctionName=_LAMBDA_NAME) + logger.info("Created Lambda: %s", response["FunctionArn"]) + return response["FunctionArn"] + + +class FixedIntentProvider(IntentProvider): + """Intent provider that returns a fixed string for deterministic testing.""" + + def __init__(self, intent: str): + self._intent = intent + + def derive_intent(self, messages: list[dict], model=None) -> str: + return self._intent + + +@pytest.mark.integration +class TestAgentCoreToolSearchPluginIntegration: + """Integration tests for AgentCoreToolSearchPlugin with a live gateway. + + Creates a gateway with a Lambda target exposing test tools, then verifies + the plugin can search and load those tools. + """ + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + cls.role_arn = os.environ.get("GATEWAY_ROLE_ARN") + cls.lambda_arn = os.environ.get("GATEWAY_LAMBDA_ARN") + cls._provisioned_infra = False + + if not cls.role_arn or not cls.lambda_arn: + # Auto-provision infrastructure + session = boto3.Session(region_name=cls.region) + iam_client = session.client("iam") + lambda_client = session.client("lambda", region_name=cls.region) + cls.role_arn = _ensure_role(iam_client) + cls.lambda_arn = _ensure_lambda(lambda_client, cls.role_arn) + _attach_lambda_invoke_policy(iam_client, cls.lambda_arn) + cls._provisioned_infra = True + logger.info("Auto-provisioned infrastructure: role=%s, lambda=%s", cls.role_arn, cls.lambda_arn) + + cls.gw_client = GatewayClient(region_name=cls.region) + cls.test_prefix = f"sdk-integ-plugin-{int(time.time())}" + cls.gateway_id = None + cls.target_id = None + + # Create gateway with semantic search enabled + gw = cls.gw_client.create_gateway_and_wait( + name=f"{cls.test_prefix}-gw", + roleArn=cls.role_arn, + authorizerType="NONE", + protocolType="MCP", + protocolConfiguration={ + "mcp": { + "searchType": "SEMANTIC", + }, + }, + ) + cls.gateway_id = gw["gatewayId"] + logger.info("Created gateway: %s", cls.gateway_id) + + # Create target with test tools + target = cls.gw_client.create_gateway_target_and_wait( + gatewayIdentifier=cls.gateway_id, + name=f"{cls.test_prefix}-target", + targetConfiguration={ + "mcp": { + "lambda": { + "lambdaArn": cls.lambda_arn, + "toolSchema": { + "inlinePayload": [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + }, + "required": ["city"], + }, + }, + { + "name": "send_email", + "description": "Send an email to a recipient", + "inputSchema": { + "type": "object", + "properties": { + "to": {"type": "string"}, + "subject": {"type": "string"}, + "body": {"type": "string"}, + }, + "required": ["to", "subject", "body"], + }, + }, + ] + }, + } + }, + }, + credentialProviderConfigurations=[ + {"credentialProviderType": "GATEWAY_IAM_ROLE"}, + ], + ) + cls.target_id = target["targetId"] + logger.info("Created target: %s", cls.target_id) + + # Wait for target search indexing to complete (can take up to 60s) + time.sleep(60) + + @classmethod + def teardown_class(cls): + if cls.gateway_id: + if cls.target_id: + try: + cls.gw_client.delete_gateway_target_and_wait( + gatewayIdentifier=cls.gateway_id, + targetId=cls.target_id, + ) + except Exception as e: + logger.warning("Failed to delete target %s: %s", cls.target_id, e) + try: + cls.gw_client.delete_gateway_and_wait( + gatewayIdentifier=cls.gateway_id, + ) + except Exception as e: + logger.warning("Failed to delete gateway %s: %s", cls.gateway_id, e) + + # Clean up auto-provisioned infrastructure + if cls._provisioned_infra: + session = boto3.Session(region_name=cls.region) + lambda_client = session.client("lambda", region_name=cls.region) + iam_client = session.client("iam") + try: + lambda_client.delete_function(FunctionName=_LAMBDA_NAME) + logger.info("Deleted Lambda: %s", _LAMBDA_NAME) + except Exception as e: + logger.warning("Failed to delete Lambda: %s", e) + try: + iam_client.delete_role_policy(RoleName=_ROLE_NAME, PolicyName="lambda-invoke") + except Exception: + pass + try: + iam_client.detach_role_policy( + RoleName=_ROLE_NAME, + PolicyArn="arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole", + ) + except Exception: + pass + try: + iam_client.delete_role(RoleName=_ROLE_NAME) + logger.info("Deleted role: %s", _ROLE_NAME) + except Exception as e: + logger.warning("Failed to delete role: %s", e) + + def _make_mcp_client(self): + """Create an MCPClient connected to the test gateway via Streamable HTTP with IAM auth.""" + from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client + from strands.tools.mcp import MCPClient + + endpoint = f"https://{self.gateway_id}.gateway.bedrock-agentcore.{self.region}.amazonaws.com/mcp" + return MCPClient( + lambda: aws_iam_streamablehttp_client( + endpoint=endpoint, + aws_region=self.region, + aws_service="bedrock-agentcore", + ) + ) + + @pytest.mark.order(1) + def test_plugin_with_default_intent_provider(self): + """Plugin initializes correctly with StrandsIntentProvider.""" + mcp_client = self._make_mcp_client() + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client) + assert isinstance(plugin._intent_provider, StrandsIntentProvider) + assert plugin.name == "agentcore-tool-search-plugin" + + @pytest.mark.order(2) + def test_plugin_with_custom_intent_provider(self): + """Plugin accepts a custom IntentProvider.""" + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("weather query") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + assert plugin._intent_provider is provider + + @pytest.mark.order(3) + def test_gateway_search_returns_results(self): + """Calling x_amz_bedrock_agentcore_search on the gateway returns tool definitions.""" + mcp_client = self._make_mcp_client() + + with mcp_client: + result = mcp_client.call_tool_sync( + tool_use_id="test-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": "get weather information"}, + ) + + assert result is not None + logger.info("Search result keys: %s", result.keys() if isinstance(result, dict) else type(result)) + + @pytest.mark.order(4) + def test_plugin_loads_tools_via_hook(self): + """Plugin loads matching tools into the agent via the before_invocation hook.""" + from strands import Agent + + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("get weather information") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + + with mcp_client: + # First verify the search endpoint returns tools + result = mcp_client.call_tool_sync( + tool_use_id="debug-search", + name="x_amz_bedrock_agentcore_search", + arguments={"query": "get weather information"}, + ) + logger.info("Raw search result: %s", result) + + agent = Agent( + system_prompt="You are a helpful assistant. Use available tools to help the user.", + tools=[], + plugins=[plugin], + ) + # Trigger an invocation so the hook fires + agent("What is the weather in Seattle?") + + logger.info("Loaded tool names: %s", plugin._loaded_tool_names) + # The gateway should have returned the get_weather tool + assert len(plugin._loaded_tool_names) > 0, ( + f"Expected tools to be loaded but got none. Raw search result was: {result}" + ) + + @pytest.mark.order(5) + def test_empty_intent_loads_no_tools(self): + """Plugin does not search gateway when intent is empty.""" + from unittest.mock import Mock + + from strands.hooks import BeforeInvocationEvent + + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + + # Simulate a before_invocation event + event = Mock(spec=BeforeInvocationEvent) + event.messages = [{"role": "user", "content": [{"text": "hello"}]}] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + + plugin.on_before_invocation(event) + + assert len(plugin._loaded_tool_names) == 0 + + @pytest.mark.order(6) + def test_tools_cleared_between_invocations(self): + """Previously loaded tools are cleared before each new search.""" + from unittest.mock import Mock + + from strands.hooks import BeforeInvocationEvent + + mcp_client = self._make_mcp_client() + provider = FixedIntentProvider("get weather information") + plugin = AgentCoreToolSearchPlugin(mcp_client=mcp_client, intent_provider=provider) + + with mcp_client: + # First: simulate invocation with a real intent + event = Mock(spec=BeforeInvocationEvent) + event.messages = [{"role": "user", "content": [{"text": "weather"}]}] + event.agent = Mock() + event.agent.model = None + event.agent.tool_registry = Mock() + event.agent.tool_registry.registry = {} + + plugin.on_before_invocation(event) + first_tools = set(plugin._loaded_tool_names) + logger.info("First invocation tools: %s", first_tools) + assert len(first_tools) > 0 + + # Second: switch to empty intent — tools should be cleared + provider._intent = "" + event.agent.tool_registry.registry = {name: Mock() for name in first_tools} + + plugin.on_before_invocation(event) + second_tools = set(plugin._loaded_tool_names) + logger.info("Second invocation tools: %s", second_tools) + + assert len(second_tools) == 0 + # Verify old tools were removed from registry + for name in first_tools: + assert name not in event.agent.tool_registry.registry + + +@pytest.mark.integration +class TestStrandsIntentProviderIntegration: + """Integration tests for StrandsIntentProvider with a real LLM.""" + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + + def test_derive_intent_from_messages(self): + """StrandsIntentProvider produces a non-empty intent string from messages.""" + provider = StrandsIntentProvider(message_window=3) + messages = [ + {"role": "user", "content": [{"text": "What's the weather like in Seattle today?"}]}, + {"role": "assistant", "content": [{"text": "Let me check the weather for you."}]}, + {"role": "user", "content": [{"text": "Also check tomorrow's forecast."}]}, + ] + + intent = provider.derive_intent(messages) + + logger.info("Derived intent: %s", intent) + assert isinstance(intent, str) + assert len(intent) > 0 + + def test_derive_intent_empty_messages(self): + """StrandsIntentProvider returns empty string for empty messages.""" + provider = StrandsIntentProvider() + intent = provider.derive_intent([]) + assert intent == "" + + def test_derive_intent_with_custom_model(self): + """StrandsIntentProvider works with an explicitly provided model.""" + from strands.models.bedrock import BedrockModel + + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + region_name=self.region, + ) + provider = StrandsIntentProvider(model=model) + messages = [ + {"role": "user", "content": [{"text": "I need to send an email to my team about the project update."}]}, + ] + + intent = provider.derive_intent(messages) + + logger.info("Derived intent with custom model: %s", intent) + assert isinstance(intent, str) + assert len(intent) > 0 + + def test_derive_intent_respects_message_window(self): + """StrandsIntentProvider only considers the last N messages.""" + provider = StrandsIntentProvider(message_window=2) + messages = [ + {"role": "user", "content": [{"text": "Tell me about dogs."}]}, + {"role": "assistant", "content": [{"text": "Dogs are great pets."}]}, + {"role": "user", "content": [{"text": "Now tell me about the stock market."}]}, + {"role": "assistant", "content": [{"text": "The stock market is complex."}]}, + {"role": "user", "content": [{"text": "What are the best investment strategies?"}]}, + ] + + intent = provider.derive_intent(messages) + + logger.info("Derived intent (window=2): %s", intent) + assert isinstance(intent, str) + assert len(intent) > 0 diff --git a/uv.lock b/uv.lock index 3987847a..34556c6c 100644 --- a/uv.lock +++ b/uv.lock @@ -295,6 +295,7 @@ simulation = [ { name = "strands-agents-evals" }, ] strands-agents = [ + { name = "mcp" }, { name = "strands-agents" }, ] strands-agents-evals = [ @@ -328,6 +329,7 @@ requires-dist = [ { name = "boto3", specifier = ">=1.43.0" }, { name = "botocore", specifier = ">=1.43.0" }, { name = "jinja2", marker = "extra == 'simulation'", specifier = ">=3.1.0" }, + { name = "mcp", marker = "extra == 'strands-agents'", specifier = ">=1.23.0,<2.0.0" }, { name = "pydantic", specifier = ">=2.0.0,<2.41.3" }, { name = "starlette", specifier = ">=0.46.2" }, { name = "strands-agents", marker = "extra == 'strands-agents'", specifier = ">=1.20.0" },