From 3e83c4c8989b93274f0754f22eff25b07cf33ce3 Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Tue, 17 Mar 2026 22:30:33 +0530 Subject: [PATCH 01/14] feat: add per-invocation idempotency support via idempotency_token --- src/strands/agent/agent.py | 119 ++++++++++++ tests/strands/agent/test_agent.py | 302 ++++++++++++++++++++++++++++++ 2 files changed, 421 insertions(+) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a6..47b5582137 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,10 +9,12 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ +import asyncio import logging import threading import warnings from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping +from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Any, @@ -95,6 +97,20 @@ class _DefaultRetryStrategySentinel: _DEFAULT_AGENT_ID = "default" +@dataclass +class _InflightInvocation: + """Tracks an inflight invocation for idempotency deduplication. + + When a caller provides an `idempotency_token`, the agent registers the invocation in a dict keyed by the token. + If a duplicate call arrives with the same token while the original is still running, the duplicate waits on the + `done` event and receives the same result or error. + """ + + done: threading.Event = field(default_factory=threading.Event) + result: AgentResult | None = None + error: BaseException | None = None + + class Agent(AgentBase): """Core Agent implementation. @@ -285,6 +301,14 @@ def __init__( self._invocation_lock = threading.Lock() self._concurrent_invocation_mode = concurrent_invocation_mode + # Tracks the single inflight invocation for idempotency duplicate detection. + # In THROW mode only one invocation can be inflight at a time, so a single + # variable suffices. Uses threading primitives (not asyncio) because run_async() + # creates separate threads with separate event loops. + self._inflight_idempotency_token: Any = None + self._inflight_invocation: _InflightInvocation | None = None + self._inflight_invocations_lock = threading.Lock() + # In the future, we'll have a RetryStrategy base class but until # that API is determined we only allow ModelRetryStrategy if ( @@ -422,6 +446,7 @@ def __call__( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -441,6 +466,10 @@ def __call__( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Can be any hashable object (string, + UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -458,6 +487,7 @@ def __call__( invocation_state=invocation_state, structured_output_model=structured_output_model, structured_output_prompt=structured_output_prompt, + idempotency_token=idempotency_token, **kwargs, ) ) @@ -469,6 +499,7 @@ async def invoke_async( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -488,6 +519,10 @@ async def invoke_async( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Can be any hashable object (string, + UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -503,6 +538,7 @@ async def invoke_async( invocation_state=invocation_state, structured_output_model=structured_output_model, structured_output_prompt=structured_output_prompt, + idempotency_token=idempotency_token, **kwargs, ) async for event in events: @@ -685,6 +721,67 @@ def __del__(self) -> None: if hasattr(self, "tool_registry"): self.tool_registry.cleanup() + def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocation | None, Any]: + """Check if this invocation is a duplicate of an inflight one, or register it as new. + + Only active in THROW mode. In UNSAFE_REENTRANT mode or when no token is provided, + this is a no-op that returns (None, None). + + Args: + idempotency_token: Caller-provided token for duplicate detection. + + Returns: + A tuple of (waiting_on, registered_token): + - If duplicate: (inflight_invocation_to_wait_on, None) + - If new request: (None, the_registered_token) + - If no token or wrong mode: (None, None) + """ + if idempotency_token is None or self._concurrent_invocation_mode != ConcurrentInvocationMode.THROW: + return None, None + + with self._inflight_invocations_lock: + if self._inflight_idempotency_token == idempotency_token: + return self._inflight_invocation, None + else: + self._inflight_invocation = _InflightInvocation() + self._inflight_idempotency_token = idempotency_token + return None, idempotency_token + + def _complete_idempotent_invocation( + self, + registered_token: Any, + result: AgentResult | None = None, + error: BaseException | None = None, + ) -> None: + """Signal waiting duplicates and clean up idempotency state. + + Safe to call even when registered_token is None (no-op in that case). + + Args: + registered_token: The token that was registered by _check_idempotency, or None. + result: The AgentResult to pass to waiting duplicates (success path). + error: The exception to pass to waiting duplicates (error path). + """ + if registered_token is None: + return + + with self._inflight_invocations_lock: + inflight = self._inflight_invocation + self._inflight_idempotency_token = None + self._inflight_invocation = None + + if inflight is None: + return + + if error is None and result is None: + error = asyncio.CancelledError("Primary invocation was cancelled before completion.") + + if error is not None: + inflight.error = error + else: + inflight.result = result + inflight.done.set() + async def stream_async( self, prompt: AgentInput = None, @@ -692,6 +789,7 @@ async def stream_async( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -711,6 +809,10 @@ async def stream_async( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Can be any hashable object (string, + UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -733,16 +835,30 @@ async def stream_async( yield event["data"] ``` """ + waiting_on, registered_token = self._check_idempotency(idempotency_token) + + if waiting_on is not None: + logger.debug("idempotency_token=<%s> | duplicate request detected, waiting for original", idempotency_token) + await asyncio.to_thread(waiting_on.done.wait) + if waiting_on.error is not None: + raise waiting_on.error + if waiting_on.result is not None: + yield AgentResultEvent(result=waiting_on.result).as_dict() + return + # Conditionally acquire lock based on concurrent_invocation_mode # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: lock_acquired = self._invocation_lock.acquire(blocking=False) if not lock_acquired: + self._complete_idempotent_invocation(registered_token) raise ConcurrencyException( "Agent is already processing a request. Concurrent invocations are not supported." ) + result: AgentResult | None = None + try: self._interrupt_state.resume(prompt) @@ -787,12 +903,15 @@ async def stream_async( except Exception as e: self._end_agent_trace_span(error=e) + self._complete_idempotent_invocation(registered_token, error=e) raise finally: # Clear cancel signal to allow agent reuse after cancellation self._cancel_signal.clear() + self._complete_idempotent_invocation(registered_token, result=result) + if self._invocation_lock.locked(): self._invocation_lock.release() diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafba..6ed43d392b 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2699,3 +2699,305 @@ def hook_callback(event: BeforeModelCallEvent): agent("test") assert len(hook_called) == 1 + + +class SyncEventFailingModel: + """A mock model that signals when streaming starts, then raises an error. + + Used for testing idempotency behavior when the original invocation fails. + """ + + def __init__(self): + self.started_event = threading.Event() + self.proceed_event = threading.Event() + + async def stream(self, *args, **kwargs): + self.started_event.set() + self.proceed_event.wait() + raise RuntimeError("Simulated model failure") + yield # noqa: RET503 - makes this an async generator + + +class IdempotencyTestAgent(Agent): + """Agent subclass that signals when a duplicate idempotency token is detected. + + Pairs with SyncEventMockedModel to provide deterministic two-thread synchronization: + the model pauses Thread 1 inside stream(), and this class signals when Thread 2 + has reached _check_idempotency and been identified as a duplicate. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.duplicate_detected = threading.Event() + + def _check_idempotency(self, idempotency_token): + result = super()._check_idempotency(idempotency_token) + if result[0] is not None: + self.duplicate_detected.set() + return result + + +def test_idempotency_duplicate_waits_and_returns_same_result(): + """Test that a duplicate call with the same idempotency_token waits and returns the same result.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test", idempotency_token="abc-123") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + assert str(results[0]) == str(results[1]) + + +def test_idempotency_original_fails_duplicate_gets_same_error(): + """Test that when the original invocation fails, the duplicate receives the same exception.""" + model = SyncEventFailingModel() + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + errors = [] + lock = threading.Lock() + + def invoke(): + try: + agent("test", idempotency_token="abc-123") + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 2, f"Expected 2 errors, got {len(errors)}" + assert all(isinstance(e, RuntimeError) for e in errors) + assert all("Simulated model failure" in str(e) for e in errors) + + +def test_idempotency_different_token_raises_concurrency_exception(): + """Test that a different idempotency_token while another is inflight raises ConcurrencyException.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke_abc(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + def invoke_def(): + try: + result = agent("test", idempotency_token="def") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke_abc) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke_def) + t2.start() + t2.join(timeout=1.0) + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert isinstance(errors[0], ConcurrencyException) + + +def test_idempotency_no_token_falls_back_to_throw(): + """Test that a call without idempotency_token still gets ConcurrencyException in THROW mode.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke_with_token(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + def invoke_without_token(): + try: + result = agent("test") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke_with_token) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke_without_token) + t2.start() + t2.join(timeout=1.0) + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert isinstance(errors[0], ConcurrencyException) + + +def test_idempotency_ignored_in_unsafe_reentrant(): + """Test that idempotency_token has no effect in UNSAFE_REENTRANT mode.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + + +def test_idempotency_cleanup_after_completion(): + """Test that after completion, the same token is treated as a fresh request.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + result1 = agent("test", idempotency_token="abc") + assert str(result1).strip() == "response1" + + result2 = agent("test", idempotency_token="abc") + assert str(result2).strip() == "response2" + + assert str(result1) != str(result2) + + +def test_idempotency_with_prompt_as_token(): + """Test that the prompt itself can be used as the idempotency_token.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + prompt = "What's the weather?" + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent(prompt, idempotency_token=prompt) + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + assert str(results[0]) == str(results[1]) From 7a0d1b95927e3b787177ef3965c544a316b0fd7f Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Tue, 17 Mar 2026 22:36:13 +0530 Subject: [PATCH 02/14] feat: added per-invocation idempotency support via idempotency_token --- src/strands/agent/agent.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 47b5582137..0ed61776a8 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -101,9 +101,10 @@ class _DefaultRetryStrategySentinel: class _InflightInvocation: """Tracks an inflight invocation for idempotency deduplication. - When a caller provides an `idempotency_token`, the agent registers the invocation in a dict keyed by the token. - If a duplicate call arrives with the same token while the original is still running, the duplicate waits on the - `done` event and receives the same result or error. + When a caller provides an `idempotency_token`, the agent registers this invocation + (in THROW mode only one can be inflight at a time). If a duplicate call arrives + with the same token while the original is still running, the duplicate waits on + the `done` event and receives the same result or error. """ done: threading.Event = field(default_factory=threading.Event) @@ -756,6 +757,8 @@ def _complete_idempotent_invocation( """Signal waiting duplicates and clean up idempotency state. Safe to call even when registered_token is None (no-op in that case). + If both result and error are None (e.g. original caller disconnected), + sets asyncio.CancelledError so duplicates receive a clear error. Args: registered_token: The token that was registered by _check_idempotency, or None. From a20d6a0d9d6f0eae4461d2739adeaba156c3d55f Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Tue, 17 Mar 2026 23:21:14 +0530 Subject: [PATCH 03/14] feat: add per-invocation idempotency support via idempotency_token --- src/strands/agent/agent.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 0ed61776a8..dd2163c0f6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -776,13 +776,12 @@ def _complete_idempotent_invocation( if inflight is None: return - if error is None and result is None: - error = asyncio.CancelledError("Primary invocation was cancelled before completion.") - if error is not None: inflight.error = error - else: + elif result is not None: inflight.result = result + else: + inflight.error = asyncio.CancelledError("Primary invocation was cancelled before completion.") inflight.done.set() async def stream_async( From 4806849e84930c614cf0129f601b9b29380072bb Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Thu, 19 Mar 2026 19:01:45 +0530 Subject: [PATCH 04/14] feat: Updated deadlock fix and added test cases --- src/strands/agent/agent.py | 6 +++ tests/strands/agent/test_agent.py | 68 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index dd2163c0f6..ef9b8f6df9 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -743,6 +743,10 @@ def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocatio with self._inflight_invocations_lock: if self._inflight_idempotency_token == idempotency_token: return self._inflight_invocation, None + elif self._inflight_idempotency_token is not None: + # A different token is already inflight; don't overwrite it. + # Fall through to the _invocation_lock check which will raise ConcurrencyException. + return None, None else: self._inflight_invocation = _InflightInvocation() self._inflight_idempotency_token = idempotency_token @@ -769,6 +773,8 @@ def _complete_idempotent_invocation( return with self._inflight_invocations_lock: + if self._inflight_idempotency_token != registered_token: + return # Another invocation owns the slot; don't touch it. inflight = self._inflight_invocation self._inflight_idempotency_token = None self._inflight_invocation = None diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 6ed43d392b..e98ca47b0f 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -3001,3 +3001,71 @@ def invoke(): assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" assert len(results) == 2, f"Expected 2 results, got {len(results)}" assert str(results[0]) == str(results[1]) + + +def test_idempotency_no_deadlock_on_competing_token(): + """A 3rd thread with a different token must not prevent a waiting duplicate from waking up. + + T1 runs with token "abc" → T2 (same token) waits as duplicate → T3 arrives with token "def" + and gets ConcurrencyException. T1 then completes and T2 must receive the result, not hang. + """ + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke_abc(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(("abc", result)) + except Exception as e: + with lock: + errors.append(e) + + def invoke_def(): + try: + result = agent("test", idempotency_token="def") + with lock: + results.append(("def", result)) + except Exception as e: + with lock: + errors.append(e) + + # T1 starts and pauses inside the model + t1 = threading.Thread(target=invoke_abc) + t1.start() + model.started_event.wait() + + # T2 detects duplicate and waits + t2 = threading.Thread(target=invoke_abc) + t2.start() + agent.duplicate_detected.wait() + + # T3 arrives with a different token — must get ConcurrencyException, not corrupt T1's state + t3 = threading.Thread(target=invoke_def) + t3.start() + t3.join(timeout=2.0) + assert not t3.is_alive(), "T3 should have returned quickly with ConcurrencyException" + + # Unblock T1; T2 must wake up (not hang) + model.proceed_event.set() + t1.join(timeout=5.0) + t2.join(timeout=5.0) + + assert not t1.is_alive(), "T1 hung — possible deadlock" + assert not t2.is_alive(), "T2 hung — deadlock: waiting duplicate never woke up" + + abc_results = [r for name, r in results if name == "abc"] + assert len(abc_results) == 2, f"Expected T1 and T2 both to succeed, got results={results} errors={errors}" + assert str(abc_results[0]) == str(abc_results[1]) + + concurrency_errors = [e for e in errors if isinstance(e, ConcurrencyException)] + assert len(concurrency_errors) == 1, f"Expected exactly 1 ConcurrencyException for T3, got {errors}" From c9cb849595dc1296d86e63d938a0a6748646789a Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Thu, 19 Mar 2026 22:29:42 +0530 Subject: [PATCH 05/14] feat: Idempotency with deadlock protection --- src/strands/agent/agent.py | 28 +++--- src/strands/types/exceptions.py | 12 +++ tests/strands/agent/test_agent.py | 140 ++++++++++++++++++++++-------- 3 files changed, 132 insertions(+), 48 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ef9b8f6df9..a51d123868 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,7 +62,7 @@ from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock -from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException +from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException, IdempotencyAbortedError from ..types.traces import AttributeValue from .agent_result import AgentResult from .base import AgentBase @@ -469,8 +469,9 @@ def __call__( structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). idempotency_token: Optional token for duplicate request detection. If provided in THROW mode and another invocation with the same token is already inflight, the caller waits for the - original to complete and receives the same result. Can be any hashable object (string, - UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -522,8 +523,9 @@ async def invoke_async( structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). idempotency_token: Optional token for duplicate request detection. If provided in THROW mode and another invocation with the same token is already inflight, the caller waits for the - original to complete and receives the same result. Can be any hashable object (string, - UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -761,8 +763,8 @@ def _complete_idempotent_invocation( """Signal waiting duplicates and clean up idempotency state. Safe to call even when registered_token is None (no-op in that case). - If both result and error are None (e.g. original caller disconnected), - sets asyncio.CancelledError so duplicates receive a clear error. + If both result and error are None (e.g. primary lost a lock race or was cancelled), + sets IdempotencyAbortedError so duplicates receive a clear error. Args: registered_token: The token that was registered by _check_idempotency, or None. @@ -787,7 +789,7 @@ def _complete_idempotent_invocation( elif result is not None: inflight.result = result else: - inflight.error = asyncio.CancelledError("Primary invocation was cancelled before completion.") + inflight.error = IdempotencyAbortedError("Primary invocation was aborted before producing a result.") inflight.done.set() async def stream_async( @@ -819,8 +821,9 @@ async def stream_async( structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). idempotency_token: Optional token for duplicate request detection. If provided in THROW mode and another invocation with the same token is already inflight, the caller waits for the - original to complete and receives the same result. Can be any hashable object (string, - UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -860,10 +863,11 @@ async def stream_async( if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: lock_acquired = self._invocation_lock.acquire(blocking=False) if not lock_acquired: - self._complete_idempotent_invocation(registered_token) - raise ConcurrencyException( + exc = ConcurrencyException( "Agent is already processing a request. Concurrent invocations are not supported." ) + self._complete_idempotent_invocation(registered_token, error=exc) + raise exc result: AgentResult | None = None diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1d1983abd4..c9f0678339 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -105,3 +105,15 @@ class ConcurrencyException(Exception): """ pass + + +class IdempotencyAbortedError(Exception): + """Exception raised to duplicate invocations when the primary invocation was aborted. + + When a caller provides an idempotency_token and another invocation with the same token + is already in-flight, the duplicate waits for the primary to complete. If the primary + is aborted before producing a result (e.g. it lost a lock race or was cancelled), + this exception is raised to all waiting duplicates. + """ + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index e98ca47b0f..df6e5a23b2 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -216,6 +216,42 @@ async def stream( yield event +class SyncEventFailingModel: + """A mock model that signals when streaming starts, then raises an error. + + Used for testing idempotency behavior when the original invocation fails. + """ + + def __init__(self): + self.started_event = threading.Event() + self.proceed_event = threading.Event() + + async def stream(self, *args, **kwargs): + self.started_event.set() + self.proceed_event.wait() + raise RuntimeError("Simulated model failure") + yield # noqa: RET503 - makes this an async generator + + +class IdempotencyTestAgent(Agent): + """Agent subclass that signals when a duplicate idempotency token is detected. + + Pairs with SyncEventMockedModel to provide deterministic two-thread synchronization: + the model pauses Thread 1 inside stream(), and this class signals when Thread 2 + has reached _check_idempotency and been identified as a duplicate. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.duplicate_detected = threading.Event() + + def _check_idempotency(self, idempotency_token): + result = super()._check_idempotency(idempotency_token) + if result[0] is not None: + self.duplicate_detected.set() + return result + + def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): _ = tool_registry @@ -2701,42 +2737,6 @@ def hook_callback(event: BeforeModelCallEvent): assert len(hook_called) == 1 -class SyncEventFailingModel: - """A mock model that signals when streaming starts, then raises an error. - - Used for testing idempotency behavior when the original invocation fails. - """ - - def __init__(self): - self.started_event = threading.Event() - self.proceed_event = threading.Event() - - async def stream(self, *args, **kwargs): - self.started_event.set() - self.proceed_event.wait() - raise RuntimeError("Simulated model failure") - yield # noqa: RET503 - makes this an async generator - - -class IdempotencyTestAgent(Agent): - """Agent subclass that signals when a duplicate idempotency token is detected. - - Pairs with SyncEventMockedModel to provide deterministic two-thread synchronization: - the model pauses Thread 1 inside stream(), and this class signals when Thread 2 - has reached _check_idempotency and been identified as a duplicate. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.duplicate_detected = threading.Event() - - def _check_idempotency(self, idempotency_token): - result = super()._check_idempotency(idempotency_token) - if result[0] is not None: - self.duplicate_detected.set() - return result - - def test_idempotency_duplicate_waits_and_returns_same_result(): """Test that a duplicate call with the same idempotency_token waits and returns the same result.""" model = SyncEventMockedModel( @@ -3069,3 +3069,71 @@ def invoke_def(): concurrency_errors = [e for e in errors if isinstance(e, ConcurrencyException)] assert len(concurrency_errors) == 1, f"Expected exactly 1 ConcurrencyException for T3, got {errors}" + + +def test_idempotency_multiple_duplicates_all_wake_up(): + """Test that multiple duplicates waiting on the same token all receive the result.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + # T1 is the primary + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + # T2 and T3 are both duplicates waiting on the same token + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + agent.duplicate_detected.clear() + + t3 = threading.Thread(target=invoke) + t3.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + t3.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 3, f"Expected 3 results (T1, T2, T3), got {len(results)}" + assert str(results[0]) == str(results[1]) == str(results[2]) + + +def test_idempotency_cleanup_after_failure(): + """Test that after a failure, the same token is treated as a fresh request.""" + fail_model = SyncEventFailingModel() + agent = Agent(model=fail_model, concurrent_invocation_mode="throw") + + # First call fails + with pytest.raises(RuntimeError, match="Simulated model failure"): + fail_model.proceed_event.set() + agent("test", idempotency_token="abc") + + # Second call with the same token should run fresh, not be treated as a duplicate + success_model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "recovered"}]}, + ] + ) + agent.model = success_model + result = agent("test", idempotency_token="abc") + assert str(result).strip() == "recovered" From 137834f0a5deb791b46a146db508fa8712b19059 Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Thu, 19 Mar 2026 23:52:13 +0530 Subject: [PATCH 06/14] feat: updated comments in tests --- tests/strands/agent/test_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index df6e5a23b2..94b484e842 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -3049,7 +3049,7 @@ def invoke_def(): t2.start() agent.duplicate_detected.wait() - # T3 arrives with a different token — must get ConcurrencyException, not corrupt T1's state + # T3 arrives with a different token - must get ConcurrencyException, not corrupt T1's state t3 = threading.Thread(target=invoke_def) t3.start() t3.join(timeout=2.0) @@ -3060,8 +3060,8 @@ def invoke_def(): t1.join(timeout=5.0) t2.join(timeout=5.0) - assert not t1.is_alive(), "T1 hung — possible deadlock" - assert not t2.is_alive(), "T2 hung — deadlock: waiting duplicate never woke up" + assert not t1.is_alive(), "T1 hung - possible deadlock" + assert not t2.is_alive(), "T2 hung - deadlock: waiting duplicate never woke up" abc_results = [r for name, r in results if name == "abc"] assert len(abc_results) == 2, f"Expected T1 and T2 both to succeed, got results={results} errors={errors}" From ad40cec344ea48f169dceb826642ada5c6128929 Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Tue, 17 Mar 2026 22:30:33 +0530 Subject: [PATCH 07/14] feat: add per-invocation idempotency support via idempotency_token --- src/strands/agent/agent.py | 119 ++++++++++++ tests/strands/agent/test_agent.py | 302 ++++++++++++++++++++++++++++++ 2 files changed, 421 insertions(+) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133dec..b599b878e3 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,10 +9,12 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ +import asyncio import logging import threading import warnings from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping +from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Any, @@ -98,6 +100,20 @@ class _DefaultRetryStrategySentinel: _DEFAULT_AGENT_ID = "default" +@dataclass +class _InflightInvocation: + """Tracks an inflight invocation for idempotency deduplication. + + When a caller provides an `idempotency_token`, the agent registers the invocation in a dict keyed by the token. + If a duplicate call arrives with the same token while the original is still running, the duplicate waits on the + `done` event and receives the same result or error. + """ + + done: threading.Event = field(default_factory=threading.Event) + result: AgentResult | None = None + error: BaseException | None = None + + class Agent(AgentBase): """Core Agent implementation. @@ -304,6 +320,14 @@ def __init__( self._invocation_lock = threading.Lock() self._concurrent_invocation_mode = concurrent_invocation_mode + # Tracks the single inflight invocation for idempotency duplicate detection. + # In THROW mode only one invocation can be inflight at a time, so a single + # variable suffices. Uses threading primitives (not asyncio) because run_async() + # creates separate threads with separate event loops. + self._inflight_idempotency_token: Any = None + self._inflight_invocation: _InflightInvocation | None = None + self._inflight_invocations_lock = threading.Lock() + # In the future, we'll have a RetryStrategy base class but until # that API is determined we only allow ModelRetryStrategy if ( @@ -444,6 +468,7 @@ def __call__( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -463,6 +488,10 @@ def __call__( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Can be any hashable object (string, + UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -480,6 +509,7 @@ def __call__( invocation_state=invocation_state, structured_output_model=structured_output_model, structured_output_prompt=structured_output_prompt, + idempotency_token=idempotency_token, **kwargs, ) ) @@ -491,6 +521,7 @@ async def invoke_async( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -510,6 +541,10 @@ async def invoke_async( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Can be any hashable object (string, + UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -525,6 +560,7 @@ async def invoke_async( invocation_state=invocation_state, structured_output_model=structured_output_model, structured_output_prompt=structured_output_prompt, + idempotency_token=idempotency_token, **kwargs, ) async for event in events: @@ -741,6 +777,67 @@ def __del__(self) -> None: if hasattr(self, "tool_registry"): self.tool_registry.cleanup() + def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocation | None, Any]: + """Check if this invocation is a duplicate of an inflight one, or register it as new. + + Only active in THROW mode. In UNSAFE_REENTRANT mode or when no token is provided, + this is a no-op that returns (None, None). + + Args: + idempotency_token: Caller-provided token for duplicate detection. + + Returns: + A tuple of (waiting_on, registered_token): + - If duplicate: (inflight_invocation_to_wait_on, None) + - If new request: (None, the_registered_token) + - If no token or wrong mode: (None, None) + """ + if idempotency_token is None or self._concurrent_invocation_mode != ConcurrentInvocationMode.THROW: + return None, None + + with self._inflight_invocations_lock: + if self._inflight_idempotency_token == idempotency_token: + return self._inflight_invocation, None + else: + self._inflight_invocation = _InflightInvocation() + self._inflight_idempotency_token = idempotency_token + return None, idempotency_token + + def _complete_idempotent_invocation( + self, + registered_token: Any, + result: AgentResult | None = None, + error: BaseException | None = None, + ) -> None: + """Signal waiting duplicates and clean up idempotency state. + + Safe to call even when registered_token is None (no-op in that case). + + Args: + registered_token: The token that was registered by _check_idempotency, or None. + result: The AgentResult to pass to waiting duplicates (success path). + error: The exception to pass to waiting duplicates (error path). + """ + if registered_token is None: + return + + with self._inflight_invocations_lock: + inflight = self._inflight_invocation + self._inflight_idempotency_token = None + self._inflight_invocation = None + + if inflight is None: + return + + if error is None and result is None: + error = asyncio.CancelledError("Primary invocation was cancelled before completion.") + + if error is not None: + inflight.error = error + else: + inflight.result = result + inflight.done.set() + async def stream_async( self, prompt: AgentInput = None, @@ -748,6 +845,7 @@ async def stream_async( invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, structured_output_prompt: str | None = None, + idempotency_token: Any = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -767,6 +865,10 @@ async def stream_async( invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). + idempotency_token: Optional token for duplicate request detection. If provided in THROW mode + and another invocation with the same token is already inflight, the caller waits for the + original to complete and receives the same result. Can be any hashable object (string, + UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -789,16 +891,30 @@ async def stream_async( yield event["data"] ``` """ + waiting_on, registered_token = self._check_idempotency(idempotency_token) + + if waiting_on is not None: + logger.debug("idempotency_token=<%s> | duplicate request detected, waiting for original", idempotency_token) + await asyncio.to_thread(waiting_on.done.wait) + if waiting_on.error is not None: + raise waiting_on.error + if waiting_on.result is not None: + yield AgentResultEvent(result=waiting_on.result).as_dict() + return + # Conditionally acquire lock based on concurrent_invocation_mode # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: lock_acquired = self._invocation_lock.acquire(blocking=False) if not lock_acquired: + self._complete_idempotent_invocation(registered_token) raise ConcurrencyException( "Agent is already processing a request. Concurrent invocations are not supported." ) + result: AgentResult | None = None + try: self._interrupt_state.resume(prompt) @@ -843,12 +959,15 @@ async def stream_async( except Exception as e: self._end_agent_trace_span(error=e) + self._complete_idempotent_invocation(registered_token, error=e) raise finally: # Clear cancel signal to allow agent reuse after cancellation self._cancel_signal.clear() + self._complete_idempotent_invocation(registered_token, result=result) + if self._invocation_lock.locked(): self._invocation_lock.release() diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5a3cce11c2..7f041c665f 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2756,3 +2756,305 @@ def test_as_tool_defaults_description_when_agent_has_none(): tool = agent.as_tool() assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" + + +class SyncEventFailingModel: + """A mock model that signals when streaming starts, then raises an error. + + Used for testing idempotency behavior when the original invocation fails. + """ + + def __init__(self): + self.started_event = threading.Event() + self.proceed_event = threading.Event() + + async def stream(self, *args, **kwargs): + self.started_event.set() + self.proceed_event.wait() + raise RuntimeError("Simulated model failure") + yield # noqa: RET503 - makes this an async generator + + +class IdempotencyTestAgent(Agent): + """Agent subclass that signals when a duplicate idempotency token is detected. + + Pairs with SyncEventMockedModel to provide deterministic two-thread synchronization: + the model pauses Thread 1 inside stream(), and this class signals when Thread 2 + has reached _check_idempotency and been identified as a duplicate. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.duplicate_detected = threading.Event() + + def _check_idempotency(self, idempotency_token): + result = super()._check_idempotency(idempotency_token) + if result[0] is not None: + self.duplicate_detected.set() + return result + + +def test_idempotency_duplicate_waits_and_returns_same_result(): + """Test that a duplicate call with the same idempotency_token waits and returns the same result.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test", idempotency_token="abc-123") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + assert str(results[0]) == str(results[1]) + + +def test_idempotency_original_fails_duplicate_gets_same_error(): + """Test that when the original invocation fails, the duplicate receives the same exception.""" + model = SyncEventFailingModel() + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + errors = [] + lock = threading.Lock() + + def invoke(): + try: + agent("test", idempotency_token="abc-123") + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 2, f"Expected 2 errors, got {len(errors)}" + assert all(isinstance(e, RuntimeError) for e in errors) + assert all("Simulated model failure" in str(e) for e in errors) + + +def test_idempotency_different_token_raises_concurrency_exception(): + """Test that a different idempotency_token while another is inflight raises ConcurrencyException.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke_abc(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + def invoke_def(): + try: + result = agent("test", idempotency_token="def") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke_abc) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke_def) + t2.start() + t2.join(timeout=1.0) + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert isinstance(errors[0], ConcurrencyException) + + +def test_idempotency_no_token_falls_back_to_throw(): + """Test that a call without idempotency_token still gets ConcurrencyException in THROW mode.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke_with_token(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + def invoke_without_token(): + try: + result = agent("test") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke_with_token) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke_without_token) + t2.start() + t2.join(timeout=1.0) + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert isinstance(errors[0], ConcurrencyException) + + +def test_idempotency_ignored_in_unsafe_reentrant(): + """Test that idempotency_token has no effect in UNSAFE_REENTRANT mode.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + + +def test_idempotency_cleanup_after_completion(): + """Test that after completion, the same token is treated as a fresh request.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="throw") + + result1 = agent("test", idempotency_token="abc") + assert str(result1).strip() == "response1" + + result2 = agent("test", idempotency_token="abc") + assert str(result2).strip() == "response2" + + assert str(result1) != str(result2) + + +def test_idempotency_with_prompt_as_token(): + """Test that the prompt itself can be used as the idempotency_token.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + prompt = "What's the weather?" + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent(prompt, idempotency_token=prompt) + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + assert str(results[0]) == str(results[1]) From abb8c0ab5d92ee0587cd62f027da0cab1b7cbb90 Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Tue, 17 Mar 2026 22:36:13 +0530 Subject: [PATCH 08/14] feat: added per-invocation idempotency support via idempotency_token --- src/strands/agent/agent.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index b599b878e3..642aacb2bb 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -104,9 +104,10 @@ class _DefaultRetryStrategySentinel: class _InflightInvocation: """Tracks an inflight invocation for idempotency deduplication. - When a caller provides an `idempotency_token`, the agent registers the invocation in a dict keyed by the token. - If a duplicate call arrives with the same token while the original is still running, the duplicate waits on the - `done` event and receives the same result or error. + When a caller provides an `idempotency_token`, the agent registers this invocation + (in THROW mode only one can be inflight at a time). If a duplicate call arrives + with the same token while the original is still running, the duplicate waits on + the `done` event and receives the same result or error. """ done: threading.Event = field(default_factory=threading.Event) @@ -812,6 +813,8 @@ def _complete_idempotent_invocation( """Signal waiting duplicates and clean up idempotency state. Safe to call even when registered_token is None (no-op in that case). + If both result and error are None (e.g. original caller disconnected), + sets asyncio.CancelledError so duplicates receive a clear error. Args: registered_token: The token that was registered by _check_idempotency, or None. From 8a848a3f798663fa6f15ef086e2dcc332454f092 Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Tue, 17 Mar 2026 23:21:14 +0530 Subject: [PATCH 09/14] feat: add per-invocation idempotency support via idempotency_token --- src/strands/agent/agent.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 642aacb2bb..d61ea3d209 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -832,13 +832,12 @@ def _complete_idempotent_invocation( if inflight is None: return - if error is None and result is None: - error = asyncio.CancelledError("Primary invocation was cancelled before completion.") - if error is not None: inflight.error = error - else: + elif result is not None: inflight.result = result + else: + inflight.error = asyncio.CancelledError("Primary invocation was cancelled before completion.") inflight.done.set() async def stream_async( From 17734e53c565833a2b6d6ce6e66181229de4ec79 Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Thu, 19 Mar 2026 19:01:45 +0530 Subject: [PATCH 10/14] feat: Updated deadlock fix and added test cases --- src/strands/agent/agent.py | 6 +++ tests/strands/agent/test_agent.py | 68 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index d61ea3d209..5963438d66 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -799,6 +799,10 @@ def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocatio with self._inflight_invocations_lock: if self._inflight_idempotency_token == idempotency_token: return self._inflight_invocation, None + elif self._inflight_idempotency_token is not None: + # A different token is already inflight; don't overwrite it. + # Fall through to the _invocation_lock check which will raise ConcurrencyException. + return None, None else: self._inflight_invocation = _InflightInvocation() self._inflight_idempotency_token = idempotency_token @@ -825,6 +829,8 @@ def _complete_idempotent_invocation( return with self._inflight_invocations_lock: + if self._inflight_idempotency_token != registered_token: + return # Another invocation owns the slot; don't touch it. inflight = self._inflight_invocation self._inflight_idempotency_token = None self._inflight_invocation = None diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 7f041c665f..05777b4cfd 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -3058,3 +3058,71 @@ def invoke(): assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" assert len(results) == 2, f"Expected 2 results, got {len(results)}" assert str(results[0]) == str(results[1]) + + +def test_idempotency_no_deadlock_on_competing_token(): + """A 3rd thread with a different token must not prevent a waiting duplicate from waking up. + + T1 runs with token "abc" → T2 (same token) waits as duplicate → T3 arrives with token "def" + and gets ConcurrencyException. T1 then completes and T2 must receive the result, not hang. + """ + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke_abc(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(("abc", result)) + except Exception as e: + with lock: + errors.append(e) + + def invoke_def(): + try: + result = agent("test", idempotency_token="def") + with lock: + results.append(("def", result)) + except Exception as e: + with lock: + errors.append(e) + + # T1 starts and pauses inside the model + t1 = threading.Thread(target=invoke_abc) + t1.start() + model.started_event.wait() + + # T2 detects duplicate and waits + t2 = threading.Thread(target=invoke_abc) + t2.start() + agent.duplicate_detected.wait() + + # T3 arrives with a different token — must get ConcurrencyException, not corrupt T1's state + t3 = threading.Thread(target=invoke_def) + t3.start() + t3.join(timeout=2.0) + assert not t3.is_alive(), "T3 should have returned quickly with ConcurrencyException" + + # Unblock T1; T2 must wake up (not hang) + model.proceed_event.set() + t1.join(timeout=5.0) + t2.join(timeout=5.0) + + assert not t1.is_alive(), "T1 hung — possible deadlock" + assert not t2.is_alive(), "T2 hung — deadlock: waiting duplicate never woke up" + + abc_results = [r for name, r in results if name == "abc"] + assert len(abc_results) == 2, f"Expected T1 and T2 both to succeed, got results={results} errors={errors}" + assert str(abc_results[0]) == str(abc_results[1]) + + concurrency_errors = [e for e in errors if isinstance(e, ConcurrencyException)] + assert len(concurrency_errors) == 1, f"Expected exactly 1 ConcurrencyException for T3, got {errors}" From 4a7796fc859918d2ce820bba046c45c68d95ce8d Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Thu, 19 Mar 2026 22:29:42 +0530 Subject: [PATCH 11/14] feat: Idempotency with deadlock protection --- src/strands/agent/agent.py | 28 ++++---- src/strands/types/exceptions.py | 12 ++++ tests/strands/agent/test_agent.py | 105 +++++++++++++++++++++++++++++- 3 files changed, 132 insertions(+), 13 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 5963438d66..3cef149e43 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,7 +62,7 @@ from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock -from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException +from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException, IdempotencyAbortedError from ..types.tools import AgentTool from ..types.traces import AttributeValue from ._agent_as_tool import _AgentAsTool @@ -491,8 +491,9 @@ def __call__( structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). idempotency_token: Optional token for duplicate request detection. If provided in THROW mode and another invocation with the same token is already inflight, the caller waits for the - original to complete and receives the same result. Can be any hashable object (string, - UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -544,8 +545,9 @@ async def invoke_async( structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). idempotency_token: Optional token for duplicate request detection. If provided in THROW mode and another invocation with the same token is already inflight, the caller waits for the - original to complete and receives the same result. Can be any hashable object (string, - UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -817,8 +819,8 @@ def _complete_idempotent_invocation( """Signal waiting duplicates and clean up idempotency state. Safe to call even when registered_token is None (no-op in that case). - If both result and error are None (e.g. original caller disconnected), - sets asyncio.CancelledError so duplicates receive a clear error. + If both result and error are None (e.g. primary lost a lock race or was cancelled), + sets IdempotencyAbortedError so duplicates receive a clear error. Args: registered_token: The token that was registered by _check_idempotency, or None. @@ -843,7 +845,7 @@ def _complete_idempotent_invocation( elif result is not None: inflight.result = result else: - inflight.error = asyncio.CancelledError("Primary invocation was cancelled before completion.") + inflight.error = IdempotencyAbortedError("Primary invocation was aborted before producing a result.") inflight.done.set() async def stream_async( @@ -875,8 +877,9 @@ async def stream_async( structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). idempotency_token: Optional token for duplicate request detection. If provided in THROW mode and another invocation with the same token is already inflight, the caller waits for the - original to complete and receives the same result. Can be any hashable object (string, - UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. + original to complete and receives the same result. Duplicate callers receive only the + final AgentResult; intermediate streaming events are not replayed. Can be any hashable + object (string, UUID, or even the prompt itself). Ignored in UNSAFE_REENTRANT mode. **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -916,10 +919,11 @@ async def stream_async( if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: lock_acquired = self._invocation_lock.acquire(blocking=False) if not lock_acquired: - self._complete_idempotent_invocation(registered_token) - raise ConcurrencyException( + exc = ConcurrencyException( "Agent is already processing a request. Concurrent invocations are not supported." ) + self._complete_idempotent_invocation(registered_token, error=exc) + raise exc result: AgentResult | None = None diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1d1983abd4..c9f0678339 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -105,3 +105,15 @@ class ConcurrencyException(Exception): """ pass + + +class IdempotencyAbortedError(Exception): + """Exception raised to duplicate invocations when the primary invocation was aborted. + + When a caller provides an idempotency_token and another invocation with the same token + is already in-flight, the duplicate waits for the primary to complete. If the primary + is aborted before producing a result (e.g. it lost a lock race or was cancelled), + this exception is raised to all waiting duplicates. + """ + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 05777b4cfd..22c2b857c0 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -58,7 +58,6 @@ async def stream(*args, **kwargs): mock = unittest.mock.Mock(spec=getattr(request, "param", None)) mock.configure_mock(mock_stream=unittest.mock.MagicMock()) mock.stream.side_effect = stream - mock.stateful = False return mock @@ -218,6 +217,42 @@ async def stream( yield event +class SyncEventFailingModel: + """A mock model that signals when streaming starts, then raises an error. + + Used for testing idempotency behavior when the original invocation fails. + """ + + def __init__(self): + self.started_event = threading.Event() + self.proceed_event = threading.Event() + + async def stream(self, *args, **kwargs): + self.started_event.set() + self.proceed_event.wait() + raise RuntimeError("Simulated model failure") + yield # noqa: RET503 - makes this an async generator + + +class IdempotencyTestAgent(Agent): + """Agent subclass that signals when a duplicate idempotency token is detected. + + Pairs with SyncEventMockedModel to provide deterministic two-thread synchronization: + the model pauses Thread 1 inside stream(), and this class signals when Thread 2 + has reached _check_idempotency and been identified as a duplicate. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.duplicate_detected = threading.Event() + + def _check_idempotency(self, idempotency_token): + result = super()._check_idempotency(idempotency_token) + if result[0] is not None: + self.duplicate_detected.set() + return result + + def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): _ = tool_registry @@ -3126,3 +3161,71 @@ def invoke_def(): concurrency_errors = [e for e in errors if isinstance(e, ConcurrencyException)] assert len(concurrency_errors) == 1, f"Expected exactly 1 ConcurrencyException for T3, got {errors}" + + +def test_idempotency_multiple_duplicates_all_wake_up(): + """Test that multiple duplicates waiting on the same token all receive the result.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + ) + agent = IdempotencyTestAgent(model=model, concurrent_invocation_mode="throw") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test", idempotency_token="abc") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + # T1 is the primary + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() + + # T2 and T3 are both duplicates waiting on the same token + t2 = threading.Thread(target=invoke) + t2.start() + agent.duplicate_detected.wait() + agent.duplicate_detected.clear() + + t3 = threading.Thread(target=invoke) + t3.start() + agent.duplicate_detected.wait() + + model.proceed_event.set() + t1.join() + t2.join() + t3.join() + + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 3, f"Expected 3 results (T1, T2, T3), got {len(results)}" + assert str(results[0]) == str(results[1]) == str(results[2]) + + +def test_idempotency_cleanup_after_failure(): + """Test that after a failure, the same token is treated as a fresh request.""" + fail_model = SyncEventFailingModel() + agent = Agent(model=fail_model, concurrent_invocation_mode="throw") + + # First call fails + with pytest.raises(RuntimeError, match="Simulated model failure"): + fail_model.proceed_event.set() + agent("test", idempotency_token="abc") + + # Second call with the same token should run fresh, not be treated as a duplicate + success_model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "recovered"}]}, + ] + ) + agent.model = success_model + result = agent("test", idempotency_token="abc") + assert str(result).strip() == "recovered" From 90293bf6a36db1aaf31a8d36c698da64189a6706 Mon Sep 17 00:00:00 2001 From: BV-Venky Date: Thu, 19 Mar 2026 23:52:13 +0530 Subject: [PATCH 12/14] feat: updated comments in tests --- tests/strands/agent/test_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 22c2b857c0..824b5387e6 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -3141,7 +3141,7 @@ def invoke_def(): t2.start() agent.duplicate_detected.wait() - # T3 arrives with a different token — must get ConcurrencyException, not corrupt T1's state + # T3 arrives with a different token - must get ConcurrencyException, not corrupt T1's state t3 = threading.Thread(target=invoke_def) t3.start() t3.join(timeout=2.0) @@ -3152,8 +3152,8 @@ def invoke_def(): t1.join(timeout=5.0) t2.join(timeout=5.0) - assert not t1.is_alive(), "T1 hung — possible deadlock" - assert not t2.is_alive(), "T2 hung — deadlock: waiting duplicate never woke up" + assert not t1.is_alive(), "T1 hung - possible deadlock" + assert not t2.is_alive(), "T2 hung - deadlock: waiting duplicate never woke up" abc_results = [r for name, r in results if name == "abc"] assert len(abc_results) == 2, f"Expected T1 and T2 both to succeed, got results={results} errors={errors}" From ff2ed0c85ebb7e7d7bc78ab3bab1ebfd0d34c7f0 Mon Sep 17 00:00:00 2001 From: Venkatesh Bhukya Date: Mon, 13 Apr 2026 21:53:41 +0530 Subject: [PATCH 13/14] feat: rebase with main --- tests/strands/agent/test_agent.py | 39 +++---------------------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 824b5387e6..02bea3bcc7 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -58,6 +58,7 @@ async def stream(*args, **kwargs): mock = unittest.mock.Mock(spec=getattr(request, "param", None)) mock.configure_mock(mock_stream=unittest.mock.MagicMock()) mock.stream.side_effect = stream + mock.stateful = False return mock @@ -223,6 +224,8 @@ class SyncEventFailingModel: Used for testing idempotency behavior when the original invocation fails. """ + stateful = False + def __init__(self): self.started_event = threading.Event() self.proceed_event = threading.Event() @@ -2793,42 +2796,6 @@ def test_as_tool_defaults_description_when_agent_has_none(): assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" -class SyncEventFailingModel: - """A mock model that signals when streaming starts, then raises an error. - - Used for testing idempotency behavior when the original invocation fails. - """ - - def __init__(self): - self.started_event = threading.Event() - self.proceed_event = threading.Event() - - async def stream(self, *args, **kwargs): - self.started_event.set() - self.proceed_event.wait() - raise RuntimeError("Simulated model failure") - yield # noqa: RET503 - makes this an async generator - - -class IdempotencyTestAgent(Agent): - """Agent subclass that signals when a duplicate idempotency token is detected. - - Pairs with SyncEventMockedModel to provide deterministic two-thread synchronization: - the model pauses Thread 1 inside stream(), and this class signals when Thread 2 - has reached _check_idempotency and been identified as a duplicate. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.duplicate_detected = threading.Event() - - def _check_idempotency(self, idempotency_token): - result = super()._check_idempotency(idempotency_token) - if result[0] is not None: - self.duplicate_detected.set() - return result - - def test_idempotency_duplicate_waits_and_returns_same_result(): """Test that a duplicate call with the same idempotency_token waits and returns the same result.""" model = SyncEventMockedModel( From 11f006bf2252d7b626068d0ba3247b2e7344dc5c Mon Sep 17 00:00:00 2001 From: Venkatesh Bhukya Date: Fri, 29 May 2026 14:27:29 +0530 Subject: [PATCH 14/14] refactor(agent): extract concurrency and idempotency into _ConcurrencyController --- strands-py/src/strands/agent/_concurrency.py | 172 ++++++++++++ strands-py/src/strands/agent/agent.py | 148 ++-------- strands-py/src/strands/tools/_caller.py | 4 +- strands-py/tests/strands/agent/test_agent.py | 29 +- .../tests/strands/agent/test_concurrency.py | 255 ++++++++++++++++++ 5 files changed, 472 insertions(+), 136 deletions(-) create mode 100644 strands-py/src/strands/agent/_concurrency.py create mode 100644 strands-py/tests/strands/agent/test_concurrency.py diff --git a/strands-py/src/strands/agent/_concurrency.py b/strands-py/src/strands/agent/_concurrency.py new file mode 100644 index 0000000000..a6dc25cc44 --- /dev/null +++ b/strands-py/src/strands/agent/_concurrency.py @@ -0,0 +1,172 @@ +"""Concurrency and idempotency control for Agent invocations. + +Encapsulates the per-Agent state that guards against concurrent invocations and +deduplicates retried requests via caller-supplied idempotency tokens. Designed to +be used by ``Agent.stream_async`` as a single delegate, keeping the orchestration +in ``agent.py`` and the synchronization primitives + bookkeeping here. +""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from ..types.agent import ConcurrentInvocationMode +from ..types.exceptions import IdempotencyAbortedError + +if TYPE_CHECKING: + from .agent_result import AgentResult + + +@dataclass +class _InflightInvocation: + """Tracks an inflight invocation for idempotency deduplication. + + Duplicate callers wait on ``done`` and then read ``result`` or ``error``. + """ + + done: threading.Event = field(default_factory=threading.Event) + result: "AgentResult | None" = None + error: BaseException | None = None + + +@dataclass +class _BeginResult: + """Outcome of ``_ConcurrencyController.begin``. + + Exactly one of the following is the actionable signal for the caller: + + - ``waiting_on`` is set: this call is a duplicate of an inflight token. Wait on + ``waiting_on.done`` and then yield the cached result or raise the cached error. + - ``lock_acquired`` is False: a different invocation owns the lock. Raise + ``ConcurrencyException``. + - Otherwise: proceed with the invocation. Pass ``registered_token`` back to + ``complete()`` in the success and error paths so waiters get unblocked. + """ + + waiting_on: _InflightInvocation | None + registered_token: Any + lock_acquired: bool + + +class _ConcurrencyController: + """Owns the invocation lock and the inflight idempotency-token registry. + + In THROW mode only one invocation can be inflight at a time, so a single + inflight slot suffices. Uses ``threading`` primitives (not asyncio) because + ``Agent.run_async()`` may spawn separate event loops on separate threads. + """ + + def __init__(self, mode: ConcurrentInvocationMode) -> None: + self._mode = mode + self._invocation_lock = threading.Lock() + self._inflight_token: Any = None + self._inflight: _InflightInvocation | None = None + self._inflight_lock = threading.Lock() + + @property + def mode(self) -> ConcurrentInvocationMode: + """Return the configured concurrency mode.""" + return self._mode + + def begin(self, idempotency_token: Any) -> _BeginResult: + """Attempt to start a new invocation. + + Combines idempotency-check + lock-acquire into a single call. The returned + ``_BeginResult`` tells the caller which of three paths to take. + + Args: + idempotency_token: Caller-provided dedup token, or None. + + Returns: + See ``_BeginResult``. If ``waiting_on`` is set, the lock is *not* held + and ``registered_token`` is None. + """ + waiting_on, registered_token = self._check_idempotency(idempotency_token) + if waiting_on is not None: + return _BeginResult(waiting_on=waiting_on, registered_token=None, lock_acquired=False) + + lock_acquired = True + if self._mode == ConcurrentInvocationMode.THROW: + lock_acquired = self._invocation_lock.acquire(blocking=False) + + return _BeginResult(waiting_on=None, registered_token=registered_token, lock_acquired=lock_acquired) + + def complete( + self, + registered_token: Any, + *, + result: "AgentResult | None" = None, + error: BaseException | None = None, + ) -> None: + """Signal waiting duplicates and clear the inflight slot. + + Safe to call multiple times for the same ``registered_token`` (subsequent + calls no-op once the slot has been cleared). Safe to call with + ``registered_token=None`` (no-op). + + If both ``result`` and ``error`` are None, waiters receive + ``IdempotencyAbortedError``. + """ + if registered_token is None: + return + + with self._inflight_lock: + if self._inflight_token != registered_token: + # Another invocation owns the slot (or it was already cleared). + return + inflight = self._inflight + self._inflight_token = None + self._inflight = None + + if inflight is None: + return + + if error is not None: + inflight.error = error + elif result is not None: + inflight.result = result + else: + inflight.error = IdempotencyAbortedError("Primary invocation was aborted before producing a result.") + inflight.done.set() + + def try_acquire_lock(self) -> bool: + """Non-blockingly acquire the invocation lock. + + Exposed for direct tool callers that bypass the full idempotency flow but + still need to serialize against an inflight invocation. + + Returns: + True if the lock was acquired, False otherwise. + """ + return self._invocation_lock.acquire(blocking=False) + + def release_lock(self) -> None: + """Release the invocation lock if it is held. Safe to call unconditionally.""" + if self._invocation_lock.locked(): + self._invocation_lock.release() + + def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocation | None, Any]: + """Register a new inflight token, identify a duplicate, or no-op. + + Returns: + ``(waiting_on, registered_token)``: + - duplicate: ``(inflight_invocation, None)`` + - new request: ``(None, idempotency_token)`` + - different token already inflight, no token provided, or + UNSAFE_REENTRANT mode: ``(None, None)`` + """ + if idempotency_token is None or self._mode != ConcurrentInvocationMode.THROW: + return None, None + + with self._inflight_lock: + if self._inflight_token == idempotency_token: + return self._inflight, None + if self._inflight_token is not None: + # A different token is inflight; don't overwrite. Caller will hit the + # lock-acquire path and surface ConcurrencyException. + return None, None + self._inflight = _InflightInvocation() + self._inflight_token = idempotency_token + return None, idempotency_token diff --git a/strands-py/src/strands/agent/agent.py b/strands-py/src/strands/agent/agent.py index ec470a1307..680e7ada83 100644 --- a/strands-py/src/strands/agent/agent.py +++ b/strands-py/src/strands/agent/agent.py @@ -15,7 +15,6 @@ import threading import warnings from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping -from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, Any, @@ -70,10 +69,11 @@ from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock -from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException, IdempotencyAbortedError +from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.tools import AgentTool from ..types.traces import AttributeValue from ._agent_as_tool import _AgentAsTool +from ._concurrency import _ConcurrencyController from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -108,21 +108,6 @@ class _DefaultRetryStrategySentinel: _DEFAULT_AGENT_ID = "default" -@dataclass -class _InflightInvocation: - """Tracks an inflight invocation for idempotency deduplication. - - When a caller provides an `idempotency_token`, the agent registers this invocation - (in THROW mode only one can be inflight at a time). If a duplicate call arrives - with the same token while the original is still running, the duplicate waits on - the `done` event and receives the same result or error. - """ - - done: threading.Event = field(default_factory=threading.Event) - result: AgentResult | None = None - error: BaseException | None = None - - class Agent(AgentBase): """Core Agent implementation. @@ -324,19 +309,7 @@ def __init__( # Runtime state for model providers (e.g., server-side response ids) self._model_state: dict[str, Any] = {} - # Initialize lock for guarding concurrent invocations - # Using threading.Lock instead of asyncio.Lock because run_async() creates - # separate event loops in different threads, so asyncio.Lock wouldn't work - self._invocation_lock = threading.Lock() - self._concurrent_invocation_mode = concurrent_invocation_mode - - # Tracks the single inflight invocation for idempotency duplicate detection. - # In THROW mode only one invocation can be inflight at a time, so a single - # variable suffices. Uses threading primitives (not asyncio) because run_async() - # creates separate threads with separate event loops. - self._inflight_idempotency_token: Any = None - self._inflight_invocation: _InflightInvocation | None = None - self._inflight_invocations_lock = threading.Lock() + self._concurrency = _ConcurrencyController(concurrent_invocation_mode) # In the future, we'll have a RetryStrategy base class but until # that API is determined we only allow ModelRetryStrategy @@ -490,6 +463,14 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) + @property + def concurrent_invocation_mode(self) -> ConcurrentInvocationMode: + """The concurrency posture this agent was configured with. + + Mirrors the ``concurrent_invocation_mode`` constructor argument. + """ + return self._concurrency.mode + def __call__( self, prompt: AgentInput = None, @@ -808,74 +789,6 @@ def __del__(self) -> None: if hasattr(self, "tool_registry"): self.tool_registry.cleanup() - def _check_idempotency(self, idempotency_token: Any) -> tuple[_InflightInvocation | None, Any]: - """Check if this invocation is a duplicate of an inflight one, or register it as new. - - Only active in THROW mode. In UNSAFE_REENTRANT mode or when no token is provided, - this is a no-op that returns (None, None). - - Args: - idempotency_token: Caller-provided token for duplicate detection. - - Returns: - A tuple of (waiting_on, registered_token): - - If duplicate: (inflight_invocation_to_wait_on, None) - - If new request: (None, the_registered_token) - - If no token or wrong mode: (None, None) - """ - if idempotency_token is None or self._concurrent_invocation_mode != ConcurrentInvocationMode.THROW: - return None, None - - with self._inflight_invocations_lock: - if self._inflight_idempotency_token == idempotency_token: - return self._inflight_invocation, None - elif self._inflight_idempotency_token is not None: - # A different token is already inflight; don't overwrite it. - # Fall through to the _invocation_lock check which will raise ConcurrencyException. - return None, None - else: - self._inflight_invocation = _InflightInvocation() - self._inflight_idempotency_token = idempotency_token - return None, idempotency_token - - def _complete_idempotent_invocation( - self, - registered_token: Any, - result: AgentResult | None = None, - error: BaseException | None = None, - ) -> None: - """Signal waiting duplicates and clean up idempotency state. - - Safe to call even when registered_token is None (no-op in that case). - If both result and error are None (e.g. primary lost a lock race or was cancelled), - sets IdempotencyAbortedError so duplicates receive a clear error. - - Args: - registered_token: The token that was registered by _check_idempotency, or None. - result: The AgentResult to pass to waiting duplicates (success path). - error: The exception to pass to waiting duplicates (error path). - """ - if registered_token is None: - return - - with self._inflight_invocations_lock: - if self._inflight_idempotency_token != registered_token: - return # Another invocation owns the slot; don't touch it. - inflight = self._inflight_invocation - self._inflight_idempotency_token = None - self._inflight_invocation = None - - if inflight is None: - return - - if error is not None: - inflight.error = error - elif result is not None: - inflight.result = result - else: - inflight.error = IdempotencyAbortedError("Primary invocation was aborted before producing a result.") - inflight.done.set() - async def stream_async( self, prompt: AgentInput = None, @@ -930,28 +843,23 @@ async def stream_async( yield event["data"] ``` """ - waiting_on, registered_token = self._check_idempotency(idempotency_token) + begin = self._concurrency.begin(idempotency_token) - if waiting_on is not None: + if begin.waiting_on is not None: logger.debug("idempotency_token=<%s> | duplicate request detected, waiting for original", idempotency_token) - await asyncio.to_thread(waiting_on.done.wait) - if waiting_on.error is not None: - raise waiting_on.error - if waiting_on.result is not None: - yield AgentResultEvent(result=waiting_on.result).as_dict() + await asyncio.to_thread(begin.waiting_on.done.wait) + if begin.waiting_on.error is not None: + raise begin.waiting_on.error + if begin.waiting_on.result is not None: + yield AgentResultEvent(result=begin.waiting_on.result).as_dict() return - # Conditionally acquire lock based on concurrent_invocation_mode - # Using threading.Lock instead of asyncio.Lock because run_async() creates - # separate event loops in different threads - if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: - lock_acquired = self._invocation_lock.acquire(blocking=False) - if not lock_acquired: - exc = ConcurrencyException( - "Agent is already processing a request. Concurrent invocations are not supported." - ) - self._complete_idempotent_invocation(registered_token, error=exc) - raise exc + if not begin.lock_acquired: + exc = ConcurrencyException( + "Agent is already processing a request. Concurrent invocations are not supported." + ) + self._concurrency.complete(begin.registered_token, error=exc) + raise exc result: AgentResult | None = None @@ -999,17 +907,15 @@ async def stream_async( except Exception as e: self._end_agent_trace_span(error=e) - self._complete_idempotent_invocation(registered_token, error=e) + self._concurrency.complete(begin.registered_token, error=e) raise finally: # Clear cancel signal to allow agent reuse after cancellation self._cancel_signal.clear() - self._complete_idempotent_invocation(registered_token, result=result) - - if self._invocation_lock.locked(): - self._invocation_lock.release() + self._concurrency.complete(begin.registered_token, result=result) + self._concurrency.release_lock() async def _run_loop( self, diff --git a/strands-py/src/strands/tools/_caller.py b/strands-py/src/strands/tools/_caller.py index 0b5408f351..5daa60a659 100644 --- a/strands-py/src/strands/tools/_caller.py +++ b/strands-py/src/strands/tools/_caller.py @@ -96,7 +96,7 @@ def caller( acquired_lock = ( should_lock and isinstance(self._agent, Agent) - and self._agent._invocation_lock.acquire_lock(blocking=False) + and self._agent._concurrency.try_acquire_lock() ) if should_lock and not acquired_lock: raise ConcurrencyException( @@ -141,7 +141,7 @@ async def acall() -> ToolResult: finally: if acquired_lock and isinstance(self._agent, Agent): - self._agent._invocation_lock.release() + self._agent._concurrency.release_lock() return caller diff --git a/strands-py/tests/strands/agent/test_agent.py b/strands-py/tests/strands/agent/test_agent.py index 67f58cd191..215360139e 100644 --- a/strands-py/tests/strands/agent/test_agent.py +++ b/strands-py/tests/strands/agent/test_agent.py @@ -242,18 +242,21 @@ class IdempotencyTestAgent(Agent): Pairs with SyncEventMockedModel to provide deterministic two-thread synchronization: the model pauses Thread 1 inside stream(), and this class signals when Thread 2 - has reached _check_idempotency and been identified as a duplicate. + has reached the concurrency controller and been identified as a duplicate. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.duplicate_detected = threading.Event() + original_begin = self._concurrency.begin - def _check_idempotency(self, idempotency_token): - result = super()._check_idempotency(idempotency_token) - if result[0] is not None: - self.duplicate_detected.set() - return result + def begin_with_signal(token): + result = original_begin(token) + if result.waiting_on is not None: + self.duplicate_detected.set() + return result + + self._concurrency.begin = begin_with_signal def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): @@ -2458,7 +2461,7 @@ def test_agent_concurrent_invocation_mode_default_is_throw(): agent = Agent(model=model) # Verify the default mode - assert agent._concurrent_invocation_mode == "throw" + assert agent.concurrent_invocation_mode == "throw" def test_agent_concurrent_invocation_mode_stores_value(): @@ -2466,10 +2469,10 @@ def test_agent_concurrent_invocation_mode_stores_value(): model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) agent_throw = Agent(model=model, concurrent_invocation_mode="throw") - assert agent_throw._concurrent_invocation_mode == "throw" + assert agent_throw.concurrent_invocation_mode == "throw" agent_reentrant = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") - assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" + assert agent_reentrant.concurrent_invocation_mode == "unsafe_reentrant" def test_agent_concurrent_invocation_mode_accepts_enum(): @@ -2479,12 +2482,12 @@ def test_agent_concurrent_invocation_mode_accepts_enum(): # Using enum values agent_throw = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.THROW) - assert agent_throw._concurrent_invocation_mode == "throw" - assert agent_throw._concurrent_invocation_mode == ConcurrentInvocationMode.THROW + assert agent_throw.concurrent_invocation_mode == "throw" + assert agent_throw.concurrent_invocation_mode == ConcurrentInvocationMode.THROW agent_reentrant = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.UNSAFE_REENTRANT) - assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" - assert agent_reentrant._concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT + assert agent_reentrant.concurrent_invocation_mode == "unsafe_reentrant" + assert agent_reentrant.concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT @pytest.mark.asyncio diff --git a/strands-py/tests/strands/agent/test_concurrency.py b/strands-py/tests/strands/agent/test_concurrency.py new file mode 100644 index 0000000000..0e06bd27da --- /dev/null +++ b/strands-py/tests/strands/agent/test_concurrency.py @@ -0,0 +1,255 @@ +"""Unit tests for _ConcurrencyController (idempotency + invocation locking).""" + +import threading +from unittest.mock import MagicMock + +import pytest + +from strands.agent._concurrency import _ConcurrencyController, _InflightInvocation +from strands.types.agent import ConcurrentInvocationMode +from strands.types.exceptions import IdempotencyAbortedError + + +@pytest.fixture +def controller(): + return _ConcurrencyController(ConcurrentInvocationMode.THROW) + + +@pytest.fixture +def reentrant_controller(): + return _ConcurrencyController(ConcurrentInvocationMode.UNSAFE_REENTRANT) + + +def test_mode_property(controller, reentrant_controller): + assert controller.mode == ConcurrentInvocationMode.THROW + assert reentrant_controller.mode == ConcurrentInvocationMode.UNSAFE_REENTRANT + + +def test_begin_first_call_acquires_lock_and_registers_token(controller): + result = controller.begin("abc") + + assert result.waiting_on is None + assert result.registered_token == "abc" + assert result.lock_acquired is True + + +def test_begin_without_token_acquires_lock_but_registers_nothing(controller): + result = controller.begin(None) + + assert result.waiting_on is None + assert result.registered_token is None + assert result.lock_acquired is True + + +def test_begin_duplicate_token_returns_waiting_on(controller): + controller.begin("abc") + second = controller.begin("abc") + + assert second.waiting_on is not None + assert second.lock_acquired is False + assert second.registered_token is None + assert isinstance(second.waiting_on, _InflightInvocation) + + +def test_begin_different_token_while_inflight_fails_lock(controller): + controller.begin("abc") + + second = controller.begin("def") + + assert second.waiting_on is None + assert second.registered_token is None + assert second.lock_acquired is False + + +def test_begin_no_token_while_inflight_fails_lock(controller): + controller.begin("abc") + + second = controller.begin(None) + + assert second.waiting_on is None + assert second.registered_token is None + assert second.lock_acquired is False + + +def test_complete_with_result_signals_waiters(controller): + first = controller.begin("abc") + dup = controller.begin("abc") + + mock_result = MagicMock() + controller.complete(first.registered_token, result=mock_result) + + assert dup.waiting_on.done.is_set() + assert dup.waiting_on.result is mock_result + assert dup.waiting_on.error is None + + +def test_complete_with_error_signals_waiters(controller): + first = controller.begin("abc") + dup = controller.begin("abc") + + err = RuntimeError("boom") + controller.complete(first.registered_token, error=err) + + assert dup.waiting_on.done.is_set() + assert dup.waiting_on.error is err + assert dup.waiting_on.result is None + + +def test_complete_with_neither_result_nor_error_sets_aborted(controller): + first = controller.begin("abc") + dup = controller.begin("abc") + + controller.complete(first.registered_token) + + assert dup.waiting_on.done.is_set() + assert isinstance(dup.waiting_on.error, IdempotencyAbortedError) + + +def test_complete_is_idempotent_on_double_call(controller): + """except + finally both call complete(); second must no-op.""" + first = controller.begin("abc") + dup = controller.begin("abc") + + err = RuntimeError("first") + controller.complete(first.registered_token, error=err) + # Second call (e.g. from finally with result=None) must not overwrite. + controller.complete(first.registered_token, result=None) + + assert dup.waiting_on.error is err # unchanged + + +def test_complete_with_none_token_is_noop(controller): + controller.begin("abc") + # Should not touch the inflight slot. + controller.complete(None, error=RuntimeError("x")) + + # Inflight slot still owns "abc". + second = controller.begin("def") + assert second.lock_acquired is False # "abc" is still inflight + + +def test_complete_after_cleared_is_noop(controller): + first = controller.begin("abc") + controller.complete(first.registered_token, result=MagicMock()) + + # Slot is clear; calling complete again on the same token is a safe no-op. + controller.complete(first.registered_token, error=RuntimeError("late")) + + +def test_release_lock_when_held(controller): + controller.begin("abc") + controller.release_lock() + + # Lock is now free; a new begin (after completing the inflight) should acquire it. + controller.complete("abc", result=MagicMock()) + result = controller.begin("def") + assert result.lock_acquired is True + + +def test_release_lock_when_not_held_is_noop(controller): + controller.release_lock() # should not raise + controller.release_lock() + + +def test_completion_clears_slot_so_next_begin_is_fresh(controller): + first = controller.begin("abc") + controller.complete(first.registered_token, result=MagicMock()) + controller.release_lock() + + second = controller.begin("abc") + assert second.waiting_on is None + assert second.registered_token == "abc" + assert second.lock_acquired is True + + +def test_unsafe_reentrant_ignores_idempotency_token(reentrant_controller): + first = reentrant_controller.begin("abc") + second = reentrant_controller.begin("abc") + + assert first.waiting_on is None + assert first.registered_token is None + assert first.lock_acquired is True + assert second.waiting_on is None + assert second.registered_token is None + assert second.lock_acquired is True + + +def test_unsafe_reentrant_complete_with_none_token_is_noop(reentrant_controller): + first = reentrant_controller.begin("abc") + # registered_token is None in UNSAFE_REENTRANT, so complete is a no-op. + reentrant_controller.complete(first.registered_token, result=MagicMock()) + + +def test_multiple_duplicates_all_wake_up(controller): + first = controller.begin("abc") + dup1 = controller.begin("abc") + dup2 = controller.begin("abc") + dup3 = controller.begin("abc") + + mock_result = MagicMock() + controller.complete(first.registered_token, result=mock_result) + + assert dup1.waiting_on.done.is_set() + assert dup2.waiting_on.done.is_set() + assert dup3.waiting_on.done.is_set() + # All three see the same _InflightInvocation instance. + assert dup1.waiting_on is dup2.waiting_on is dup3.waiting_on + assert dup1.waiting_on.result is mock_result + + +def test_duplicate_does_not_acquire_lock(controller): + """Verify the lock stays held by the primary while a duplicate is waiting.""" + first = controller.begin("abc") + assert first.lock_acquired is True + + dup = controller.begin("abc") + assert dup.lock_acquired is False # duplicate doesn't claim the lock + + # Now if a third party with a different token arrives, it should fail on lock. + other = controller.begin("xyz") + assert other.lock_acquired is False + + +def test_lock_acquire_fail_path_cleanup_via_complete(controller): + """Simulate the lock-acquire-fail cleanup pattern used by stream_async.""" + controller.begin("abc") # T1 owns the slot and the lock + + # T2 with same token would be a duplicate, but here we exercise a different + # token to mirror the "lock-fail with newly-registered token" scenario. + # Note: when token differs, registered_token is None (no cleanup needed). + second = controller.begin("def") + assert second.registered_token is None + # Calling complete with None is a safe no-op: + controller.complete(second.registered_token, error=RuntimeError("would-be ConcurrencyException")) + + +def test_concurrent_begin_only_one_primary_others_duplicates(controller): + """Stress test: many threads call begin with the same token concurrently. + + Exactly one must become the primary (lock_acquired=True, registered_token set); + all others must be duplicates (waiting_on set, registered_token=None). + """ + barrier = threading.Barrier(10) + results = [] + lock = threading.Lock() + + def call(): + barrier.wait() + r = controller.begin("abc") + with lock: + results.append(r) + + threads = [threading.Thread(target=call) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + primaries = [r for r in results if r.registered_token == "abc"] + duplicates = [r for r in results if r.waiting_on is not None] + + assert len(primaries) == 1 + assert len(duplicates) == 9 + assert primaries[0].lock_acquired is True + assert all(d.lock_acquired is False for d in duplicates) + assert all(d.registered_token is None for d in duplicates)