diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index aef4424a49..0ad0c90a87 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -37,6 +37,7 @@ from ..features import FeatureName from ..features import is_feature_enabled from ..utils.variant_utils import GoogleLLMVariant +from ._gemini_schema_util import _sanitize_schema_formats_for_gemini _py_type_2_schema_type = { 'str': types.Type.STRING, @@ -366,7 +367,7 @@ def from_function_with_options( ) parameters_json_schema[name] = types.Schema.model_validate( - json_schema_dict + _sanitize_schema_formats_for_gemini(json_schema_dict) ) if param.default is not inspect.Parameter.empty: if param.default is not None: diff --git a/tests/unittests/tools/test_from_function_with_options.py b/tests/unittests/tools/test_from_function_with_options.py index 537094da39..6c20b62724 100644 --- a/tests/unittests/tools/test_from_function_with_options.py +++ b/tests/unittests/tools/test_from_function_with_options.py @@ -361,3 +361,29 @@ def complex_tool( ), }, ) + + +def test_sanitized_in_json_schema_fallback(): + """Test schema is sanitzed for complex union type.""" + + def complex_tool( + query: str, + mode: str = 'default', + tags: dict[str, str] | None = None, + ) -> str: + return query + + declaration = _automatic_function_calling_util.from_function_with_options( + complex_tool, GoogleLLMVariant.GEMINI_API + ) + + assert declaration.parameters.properties['tags'] == types.Schema( + any_of=[ + types.Schema( + # should not contain `additional_properties={'type': 'string'} from pydantic + type=types.Type.OBJECT, + ), + types.Schema(type=types.Type.NULL), + ], + nullable=True, + )