Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion strands-py/src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .._async import run_async
from ..event_loop._retry import ModelRetryStrategy
from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle
from ..experimental.checkpoint import Checkpoint, CheckpointPosition
from ..tools._tool_helpers import generate_missing_tool_result_content
from ..types._snapshot import (
SNAPSHOT_SCHEMA_VERSION,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
tool_executor: ToolExecutor | None = None,
retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY,
concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW,
checkpointing: bool = False,
Comment thread
JackYPCOnline marked this conversation as resolved.
Comment thread
JackYPCOnline marked this conversation as resolved.
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -214,6 +216,14 @@ def __init__(
Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations.
Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided
only for advanced use cases where the caller understands the risks.
checkpointing: When True, the event loop pauses at cycle boundaries
(after_model, after_tools) and returns ``stop_reason="checkpoint"``
with a populated ``checkpoint`` field. Resume by passing the
checkpoint back as ``{"checkpointResume": {"checkpoint": ...}}``
(or a one-element list of the same). The SDK does not capture
conversation state in the checkpoint; pair with a SessionManager
for cross-process state continuity. Defaults to False.
See :mod:`strands.experimental.checkpoint`.

Raises:
ValueError: If agent id contains path separators.
Expand Down Expand Up @@ -304,6 +314,12 @@ def __init__(

self._interrupt_state = _InterruptState()

# Checkpointing: pause at cycle boundaries when enabled.
self._checkpointing: bool = checkpointing
self._checkpoint: Checkpoint | None = None
self._checkpoint_cycle_index: int = 0
self._checkpoint_resume_position: CheckpointPosition | None = None

# Runtime state for model providers (e.g., server-side response ids)
self._model_state: dict[str, Any] = {}

Expand Down Expand Up @@ -374,7 +390,7 @@ def cancel(self) -> None:
This method is thread-safe and can be called from any context
(e.g., another thread, web request handler, background task).

The agent will stop gracefully at the next checkpoint:
The agent will stop gracefully at the next cancellation-safe point:
- During model response streaming
- Before tool execution

Expand Down Expand Up @@ -1006,10 +1022,48 @@ async def _execute_event_loop_cycle(
if structured_output_context:
structured_output_context.cleanup(self.tool_registry)

def _try_consume_checkpoint_resume(self, prompt: Any) -> bool:
    """Consume a ``checkpointResume`` prompt block, returning True if found.

    The block is accepted either as a bare dict or as a one-element list
    holding that dict. Both forms go through the same validation, so a
    resume block can never silently carry (and drop) other content. On
    success the decoded checkpoint is stored on ``self._checkpoint``.

    Args:
        prompt: The prompt handed to the agent. A prompt that is neither a
            dict nor a list containing a ``checkpointResume`` key is left
            untouched.

    Returns:
        True if a resume block was found and stored, False if the prompt
        carries no ``checkpointResume`` block.

    Raises:
        TypeError: If the resume block is mixed with other keys or content
            blocks, or if the list form holds more than one block.
        ValueError: If the agent was created with ``checkpointing=False``.
        KeyError: If the ``checkpointResume`` payload is not a dict or has
            no ``checkpoint`` key.

    Errors raised by ``Checkpoint.from_dict`` (e.g. ``CheckpointException``
    on schema mismatch) propagate unchanged.
    """
    marker = "checkpointResume"

    # Normalise both accepted shapes to a list of blocks so that the bare-dict
    # form gets exactly the same shape validation as the list form.
    if isinstance(prompt, dict):
        if marker not in prompt:
            return False
        blocks = [prompt]
    elif isinstance(prompt, list):
        if not any(isinstance(block, dict) and marker in block for block in prompt):
            return False
        blocks = prompt
    else:
        return False

    # Any key other than the marker means the caller mixed resume data with
    # regular content, which would otherwise be discarded without notice.
    invalid_keys = [key for block in blocks if isinstance(block, dict) for key in block if key != marker]
    if invalid_keys:
        raise TypeError(
            f"content_types=<{invalid_keys}> | checkpointResume cannot be mixed with other content types"
        )
    if len(blocks) != 1:
        raise TypeError(f"block_count=<{len(blocks)}> | only one checkpointResume block permitted per prompt")
    resume_block: dict[str, Any] = blocks[0]

    if not self._checkpointing:
        raise ValueError(
            "Received checkpointResume block but agent was created with checkpointing=False. "
            "Pass checkpointing=True when constructing the Agent."
        )

    payload = resume_block[marker]
    if not isinstance(payload, dict) or "checkpoint" not in payload:
        raise KeyError("checkpoint | missing required key in checkpointResume block")

    self._checkpoint = Checkpoint.from_dict(payload["checkpoint"])
    return True

async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
if self._interrupt_state.activated:
return []

if self._try_consume_checkpoint_resume(prompt):
return []

messages: Messages | None = None
if prompt is not None:
# Check if the latest message is toolUse
Expand Down
18 changes: 16 additions & 2 deletions strands-py/src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pydantic import BaseModel

from ..experimental.checkpoint import Checkpoint
from ..interrupt import Interrupt
from ..telemetry.metrics import EventLoopMetrics
from ..types.content import Message
Expand All @@ -26,6 +27,9 @@ class AgentResult:
state: Additional state information from the event loop.
interrupts: List of interrupts if raised by user.
structured_output: Parsed structured output when structured_output_model was specified.
checkpoint: Checkpoint captured when the agent paused for durable execution.
Populated only when stop_reason == "checkpoint". See
strands.experimental.checkpoint for usage.
"""

stop_reason: StopReason
Expand All @@ -34,6 +38,7 @@ class AgentResult:
state: Any
interrupts: Sequence[Interrupt] | None = None
structured_output: BaseModel | None = None
checkpoint: Checkpoint | None = None

@property
def context_size(self) -> int | None:
Expand Down Expand Up @@ -94,15 +99,23 @@ def from_dict(cls, data: dict[str, Any]) -> "AgentResult":
Returns:
AgentResult instance
Raises:
TypeError: If the data format is invalid@
TypeError: If the data format is invalid
"""
if data.get("type") != "agent_result":
raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}")

message = cast(Message, data.get("message"))
stop_reason = cast(StopReason, data.get("stop_reason"))
checkpoint_data = data.get("checkpoint")
checkpoint = Checkpoint.from_dict(checkpoint_data) if checkpoint_data else None

return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={})
return cls(
message=message,
stop_reason=stop_reason,
metrics=EventLoopMetrics(),
state={},
checkpoint=checkpoint,
)

def to_dict(self) -> dict[str, Any]:
"""Convert this AgentResult to JSON-serializable dictionary.
Expand All @@ -114,4 +127,5 @@ def to_dict(self) -> dict[str, Any]:
"type": "agent_result",
"message": self.message,
"stop_reason": self.stop_reason,
"checkpoint": self.checkpoint.to_dict() if self.checkpoint else None,
}
85 changes: 83 additions & 2 deletions strands-py/src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from opentelemetry import trace as trace_api

from ..experimental.checkpoint import Checkpoint, CheckpointPosition
from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..telemetry.metrics import Trace
from ..telemetry.tracer import Tracer, get_tracer
Expand Down Expand Up @@ -117,6 +118,27 @@ async def _estimate_input_tokens(agent: "Agent") -> int:
)


def _build_checkpoint_stop_event(
    agent: "Agent",
    position: CheckpointPosition,
    cycle_index: int,
    message: Message,
    request_state: Any,
) -> EventLoopStopEvent:
    """Create the stop event that pauses the event loop at a checkpoint boundary.

    Args:
        agent: Agent whose ``event_loop_metrics`` are attached to the stop event.
        position: Boundary at which the loop is pausing. Callers in this module
            pass ``"after_model"`` or ``"after_tools"``.
        cycle_index: Index of the cycle recorded in the checkpoint.
        message: Message carried on the stop event.
        request_state: Request state carried on the stop event.

    Returns:
        An ``EventLoopStopEvent`` with stop reason ``"checkpoint"`` and the newly
        built ``Checkpoint`` passed as its ``checkpoint`` keyword.
    """
    # Build the checkpoint first so the stop event is created from a fully formed marker.
    paused_at = Checkpoint(position=position, cycle_index=cycle_index)
    metrics = agent.event_loop_metrics
    return EventLoopStopEvent("checkpoint", message, metrics, request_state, checkpoint=paused_at)


async def event_loop_cycle(
agent: "Agent",
invocation_state: dict[str, Any],
Expand Down Expand Up @@ -145,12 +167,16 @@ async def event_loop_cycle(
structured_output_context: Optional context for structured output management.

Yields:
Model and tool stream events. The last event is a tuple containing:
Model and tool stream events. The final ``EventLoopStopEvent`` payload
(``event["stop"]``) is a 7-tuple:

- StopReason: Reason the model stopped generating (e.g., "tool_use")
- StopReason: Reason the model stopped generating (e.g., "tool_use", "checkpoint")
- Message: The generated message from the model
- EventLoopMetrics: Updated metrics for the event loop
- Any: Updated request state
- Sequence[Interrupt] | None: Interrupts raised during the cycle, if any
- BaseModel | None: Structured output result, if any
- Checkpoint | None: Checkpoint captured when stop_reason == "checkpoint"

Raises:
EventLoopException: If an error occurs during execution
Expand All @@ -164,6 +190,18 @@ async def event_loop_cycle(
# Initialize state and get cycle trace
if "request_state" not in invocation_state:
invocation_state["request_state"] = {}

# Consume the resume marker (one-shot).
resume_context = agent._checkpoint
if resume_context is not None:
agent._checkpoint = None
# after_tools means that cycle finished; resume increments cycle_index.
next_cycle = (
resume_context.cycle_index + 1 if resume_context.position == "after_tools" else resume_context.cycle_index
)
agent._checkpoint_cycle_index = next_cycle
agent._checkpoint_resume_position = resume_context.position

attributes = {"event_loop_cycle_id": str(invocation_state.get("event_loop_cycle_id"))}
cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes)
invocation_state["event_loop_cycle_trace"] = cycle_trace
Expand Down Expand Up @@ -223,6 +261,24 @@ async def event_loop_cycle(
)

if stop_reason == "tool_use":
# Emit after_model checkpoint, unless we just resumed from one.
if agent._checkpointing and not agent._cancel_signal.is_set():
resume_position = agent._checkpoint_resume_position
agent._checkpoint_resume_position = None
if resume_position != "after_model":
cycle_index = agent._checkpoint_cycle_index
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=message)
yield _build_checkpoint_stop_event(
agent=agent,
position="after_model",
cycle_index=cycle_index,
message=message,
request_state=invocation_state["request_state"],
)
return

# Handle tool execution
tool_events = _handle_tool_execution(
stop_reason,
Expand Down Expand Up @@ -640,6 +696,31 @@ async def _handle_tool_execution(
)
return

# Emit after_tools checkpoint. Only fires on tool_use cycles: a model that
# returns end_turn first never reaches this branch.
if agent._checkpointing and not agent._cancel_signal.is_set():
cycle_index = agent._checkpoint_cycle_index
agent._checkpoint_cycle_index = cycle_index + 1
yield _build_checkpoint_stop_event(
agent=agent,
position="after_tools",
cycle_index=cycle_index,
message=message,
request_state=invocation_state["request_state"],
)
return

# If checkpointing is on and cancel suppressed the checkpoint above, emit
# "cancelled" now to avoid an extra model call.
if agent._checkpointing and agent._cancel_signal.is_set():
yield EventLoopStopEvent(
"cancelled",
message,
agent.event_loop_metrics,
invocation_state["request_state"],
)
return

events = recurse_event_loop(
agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context
)
Expand Down
Loading
Loading