Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions strands-py/src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..models.model import Model
from ..tools import InvalidToolUseNameException
from ..tools._validator import TOOL_INPUT_PARSE_ERROR_KEY
from ..tools.tools import validate_tool_use_name
from ..types._events import (
CitationStreamEvent,
Expand Down Expand Up @@ -282,12 +283,22 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:

if current_tool_use:
if "input" not in current_tool_use:
current_tool_use["input"] = ""

try:
current_tool_use["input"] = json.loads(current_tool_use["input"])
except ValueError:
current_tool_use["input"] = {}
else:
try:
current_tool_use["input"] = json.loads(current_tool_use["input"])
except ValueError as e:
logger.warning(
"tool_name=<%s>, tool_use_id=<%s> | failed to parse tool input JSON",
current_tool_use.get("name"),
current_tool_use.get("toolUseId"),
)
current_tool_use["input"] = {
TOOL_INPUT_PARSE_ERROR_KEY: (
f"Invalid JSON in tool input for '{current_tool_use.get('name', 'unknown')}': {e}. "
"Retry with a valid JSON object."
)
}

tool_use_id = current_tool_use["toolUseId"]
tool_use_name = current_tool_use["name"]
Expand Down
18 changes: 18 additions & 0 deletions strands-py/src/strands/tools/_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from ..types.content import Message
from ..types.tools import ToolResult, ToolUse

TOOL_INPUT_PARSE_ERROR_KEY = "__strands_tool_input_parse_error__"


def validate_and_prepare_tools(
message: Message,
Expand All @@ -28,6 +30,22 @@ def validate_and_prepare_tools(
# Avoid modifying original `tool_uses` variable during iteration
tool_uses_copy = tool_uses.copy()
for tool in tool_uses_copy:
parse_error = (
tool.get("input", {}).get(TOOL_INPUT_PARSE_ERROR_KEY) if isinstance(tool.get("input"), dict) else None
)
if parse_error:
tool_uses.remove(tool)
invalid_tool_use_ids.append(tool["toolUseId"])
tool_uses.append(tool)
tool_results.append(
{
"toolUseId": tool["toolUseId"],
"status": "error",
"content": [{"text": f"Error: {parse_error}"}],
}
)
continue

try:
validate_tool_use(tool)
except InvalidToolUseNameException as e:
Expand Down
20 changes: 20 additions & 0 deletions strands-py/tests/strands/event_loop/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import strands
import strands.event_loop
from strands.tools._validator import TOOL_INPUT_PARSE_ERROR_KEY
from strands.types._events import ModelStopReason, TypedEvent
from strands.types.content import Message, Messages
from strands.types.streaming import (
Expand Down Expand Up @@ -310,6 +311,25 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s
assert tru_callback_event == exp_callback_event


def test_handle_content_block_stop_marks_malformed_tool_input(caplog):
state = {
"content": [],
"current_tool_use": {"toolUseId": "123", "name": "search", "input": '{"query": "unterminated'},
"text": "",
"reasoningText": "",
"citationsContent": [],
"redactedContent": b"",
}

with caplog.at_level("WARNING", logger="strands.event_loop.streaming"):
updated_state = strands.event_loop.streaming.handle_content_block_stop(state)

tool_input = updated_state["content"][0]["toolUse"]["input"]
assert TOOL_INPUT_PARSE_ERROR_KEY in tool_input
assert "Invalid JSON in tool input for 'search'" in tool_input[TOOL_INPUT_PARSE_ERROR_KEY]
assert "tool_name=<search>, tool_use_id=<123> | failed to parse tool input JSON" in caplog.text


@pytest.mark.parametrize(
("state", "exp_updated_state"),
[
Expand Down
31 changes: 31 additions & 0 deletions strands-py/tests/strands/tools/test_validator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from strands.tools import _validator
from strands.tools._validator import TOOL_INPUT_PARSE_ERROR_KEY
from strands.types.content import Message


Expand Down Expand Up @@ -49,3 +50,33 @@ def test_validate_and_prepare_tools():
assert tru_tool_uses == exp_tool_uses
assert tru_tool_results == exp_tool_results
assert tru_invalid_tool_use_ids == exp_invalid_tool_use_ids


def test_validate_and_prepare_tools_turns_malformed_input_into_tool_result():
message: Message = {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "t1",
"name": "search",
"input": {TOOL_INPUT_PARSE_ERROR_KEY: "Invalid JSON in tool input for 'search'"},
}
}
],
}

tool_uses = []
tool_results = []
invalid_tool_use_ids = []

_validator.validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)

assert invalid_tool_use_ids == ["t1"]
assert tool_results == [
{
"toolUseId": "t1",
"status": "error",
"content": [{"text": "Error: Invalid JSON in tool input for 'search'"}],
}
]