diff --git a/strands-py/src/strands/session/repository_session_manager.py b/strands-py/src/strands/session/repository_session_manager.py index c1032a85ea..94603ec195 100644 --- a/strands-py/src/strands/session/repository_session_manager.py +++ b/strands-py/src/strands/session/repository_session_manager.py @@ -281,14 +281,32 @@ def _fix_broken_tool_use(self, messages: list[Message]) -> list[Message]: content["toolUse"]["toolUseId"] for content in message["content"] if "toolUse" in content ] - # Check if there are more messages after the current toolUse message + next_message = messages[index + 1] + next_content = next_message["content"] + next_message_had_tool_results = any("toolResult" in content for content in next_content) tool_result_ids = [ - content["toolResult"]["toolUseId"] - for content in messages[index + 1]["content"] - if "toolResult" in content + content["toolResult"]["toolUseId"] for content in next_content if "toolResult" in content ] - missing_tool_use_ids = list(set(tool_use_ids) - set(tool_result_ids)) + if any(tool_result_id not in tool_use_ids for tool_result_id in tool_result_ids): + logger.warning( + "Session message history has orphaned toolResult blocks that do not match the preceding " + "toolUse. Removing them to maintain valid conversation structure." + ) + next_message["content"] = [ + content + for content in next_content + if "toolResult" not in content or content["toolResult"]["toolUseId"] in tool_use_ids + ] + tool_result_ids = [ + content["toolResult"]["toolUseId"] + for content in next_message["content"] + if "toolResult" in content + ] + + missing_tool_use_ids = [ + tool_use_id for tool_use_id in tool_use_ids if tool_use_id not in tool_result_ids + ] # If there are missing tool use ids, that means the messages history is broken if missing_tool_use_ids: logger.warning( @@ -300,7 +318,11 @@ def _fix_broken_tool_use(self, messages: list[Message]) -> list[Message]: if tool_result_ids: # If there were any toolResult ids, that means only some of the content blocks are missing - messages[index + 1]["content"].extend(missing_content_blocks) + next_message["content"].extend(missing_content_blocks) + elif next_message_had_tool_results and not next_message["content"]: + # The following message only had orphaned toolResults. Reuse it instead of leaving + # an empty user message behind. + next_message["content"] = missing_content_blocks else: # The message following the toolUse was not a toolResult, so lets insert it messages.insert(index + 1, {"role": "user", "content": missing_content_blocks}) diff --git a/strands-py/tests/strands/session/test_repository_session_manager.py b/strands-py/tests/strands/session/test_repository_session_manager.py index 1d50481132..90c2ea8314 100644 --- a/strands-py/tests/strands/session/test_repository_session_manager.py +++ b/strands-py/tests/strands/session/test_repository_session_manager.py @@ -370,6 +370,55 @@ def test_fix_broken_tool_use_extends_partial_tool_results(existing_session_manag assert missing_result["toolResult"]["content"][0]["text"] == "Tool was interrupted." +def test_fix_broken_tool_use_removes_extra_tool_results(session_manager): + messages = [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "wanted-123", "name": "test_tool", "input": {"input": "test"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "stale-456", "status": "error", "content": [{"text": "old"}]}}, + ], + }, + ] + + fixed_messages = session_manager._fix_broken_tool_use(messages) + + assert len(fixed_messages) == 2 + assert fixed_messages[1]["content"] == [ + {"toolResult": {"toolUseId": "wanted-123", "status": "error", "content": [{"text": "Tool was interrupted."}]}} + ] + + +def test_fix_broken_tool_use_removes_extra_partial_tool_results(session_manager): + messages = [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "wanted-123", "name": "test_tool", "input": {"input": "test"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "stale-456", "status": "error", "content": [{"text": "old"}]}}, + {"toolResult": {"toolUseId": "wanted-123", "status": "success", "content": [{"text": "ok"}]}}, + ], + }, + ] + + fixed_messages = session_manager._fix_broken_tool_use(messages) + + assert len(fixed_messages) == 2 + assert fixed_messages[1]["content"] == [ + {"toolResult": {"toolUseId": "wanted-123", "status": "success", "content": [{"text": "ok"}]}}, + ] + + def test_fix_broken_tool_use_handles_multiple_orphaned_tools(existing_session_manager): """Test fixing multiple orphaned toolUse messages."""