diff --git a/.chronus/changes/python-addTypedDict-2026-3-21-17-47-3.md b/.chronus/changes/python-addTypedDict-2026-3-21-17-47-3.md new file mode 100644 index 00000000000..25dfde1e67f --- /dev/null +++ b/.chronus/changes/python-addTypedDict-2026-3-21-17-47-3.md @@ -0,0 +1,8 @@ +--- +# Change versionKind to one of: internal, fix, dependencies, feature, deprecation, breaking +changeKind: fix +packages: + - "@typespec/http-client-python" +--- + +[python] add `typeddict` `models-mode` for Python HTTP client emitter to generated `TypedDict`s for input models diff --git a/packages/http-client-python/emitter/src/types.ts b/packages/http-client-python/emitter/src/types.ts index c8dbcf9b584..89577ff0bdf 100644 --- a/packages/http-client-python/emitter/src/types.ts +++ b/packages/http-client-python/emitter/src/types.ts @@ -290,7 +290,7 @@ function emitModel(context: PythonSdkContext, type: SdkModelType): Record>, properties: new Array>(), snakeCaseName: camelToSnakeCase(type.name), - base: "dpg", + base: (context.emitContext.options as any)["models-mode"] === "typeddict" ? "typeddict" : "dpg", internal: type.access === "internal", crossLanguageDefinitionId: type.crossLanguageDefinitionId, usage: type.usage, diff --git a/packages/http-client-python/eng/scripts/ci/regenerate-common.ts b/packages/http-client-python/eng/scripts/ci/regenerate-common.ts index 8dd426152b4..9abd4a57a8f 100644 --- a/packages/http-client-python/eng/scripts/ci/regenerate-common.ts +++ b/packages/http-client-python/eng/scripts/ci/regenerate-common.ts @@ -105,9 +105,17 @@ export const BASE_AZURE_EMITTER_OPTIONS: Record< "package-name": "client-structure-twooperationgroup", namespace: "client.structure.twooperationgroup", }, - "client/naming": { - namespace: "client.naming.main", - }, + "client/naming": [ + { + namespace: "client.naming.main", + }, + { + "package-name": "client-naming-typeddict", + namespace: "client.naming.typeddict", + "models-mode": "typeddict", + "generate-test": "false", + }, + ], "client/overload": { namespace: "client.overload", }, @@ -203,14 +211,30 @@ export const BASE_EMITTER_OPTIONS: Record< "package-name": "typetest-model-nesteddiscriminator", namespace: "typetest.model.nesteddiscriminator", }, - "type/model/inheritance/not-discriminated": { - "package-name": "typetest-model-notdiscriminated", - namespace: "typetest.model.notdiscriminated", - }, - "type/model/inheritance/single-discriminator": { - "package-name": "typetest-model-singlediscriminator", - namespace: "typetest.model.singlediscriminator", - }, + "type/model/inheritance/not-discriminated": [ + { + "package-name": "typetest-model-notdiscriminated", + namespace: "typetest.model.notdiscriminated", + }, + { + "package-name": "typetest-model-notdiscriminated-typeddict", + namespace: "typetest.model.notdiscriminated.typeddict", + "models-mode": "typeddict", + "generate-test": "false", + }, + ], + "type/model/inheritance/single-discriminator": [ + { + "package-name": "typetest-model-singlediscriminator", + namespace: "typetest.model.singlediscriminator", + }, + { + "package-name": "typetest-model-singlediscriminator-typeddict", + namespace: "typetest.model.singlediscriminator.typeddict", + "models-mode": "typeddict", + "generate-test": "false", + }, + ], "type/model/inheritance/recursive": [ { "package-name": "typetest-model-recursive", diff --git a/packages/http-client-python/generator/pygen/codegen/models/__init__.py b/packages/http-client-python/generator/pygen/codegen/models/__init__.py index a1d9f9a4dbc..5848854ed86 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/__init__.py +++ b/packages/http-client-python/generator/pygen/codegen/models/__init__.py @@ -9,7 +9,7 @@ from .base_builder import BaseBuilder, ParameterListType from .code_model import CodeModel from .client import Client -from .model_type import ModelType, JSONModelType, DPGModelType, MsrestModelType +from .model_type import ModelType, JSONModelType, DPGModelType, MsrestModelType, TypedDictModelType from .dictionary_type import DictionaryType from .list_type import ListType from .combined_type import CombinedType @@ -171,6 +171,8 @@ def build_type(yaml_data: dict[str, Any], code_model: CodeModel) -> BaseType: model_type = JSONModelType elif yaml_data["base"] == "dpg": model_type = DPGModelType # type: ignore + elif yaml_data["base"] == "typeddict": + model_type = TypedDictModelType # type: ignore else: model_type = MsrestModelType # type: ignore response = model_type(yaml_data, code_model) diff --git a/packages/http-client-python/generator/pygen/codegen/models/code_model.py b/packages/http-client-python/generator/pygen/codegen/models/code_model.py index 81f5f20bf8b..6a28ddfcaad 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/code_model.py +++ b/packages/http-client-python/generator/pygen/codegen/models/code_model.py @@ -251,7 +251,7 @@ def need_utils_folder(self, async_mode: bool, client_namespace: str) -> bool: return ( self.need_utils_utils(async_mode, client_namespace) or self.need_utils_serialization - or self.options["models-mode"] == "dpg" + or self.options["models-mode"] in ("dpg", "typeddict") ) @property @@ -271,7 +271,7 @@ def need_utils_form_data(self, async_mode: bool, client_namespace: str) -> bool: (not async_mode) and self.is_top_namespace(client_namespace) and self.has_form_data - and self.options["models-mode"] == "dpg" + and self.options["models-mode"] in ("dpg", "typeddict") ) def need_utils_etag(self, client_namespace: str) -> bool: diff --git a/packages/http-client-python/generator/pygen/codegen/models/lro_operation.py b/packages/http-client-python/generator/pygen/codegen/models/lro_operation.py index b7673d4d865..1a3d37bb85b 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/lro_operation.py +++ b/packages/http-client-python/generator/pygen/codegen/models/lro_operation.py @@ -125,7 +125,7 @@ def imports(self, async_mode: bool, **kwargs: Any) -> FileImport: ImportType.SDKCORE, ) if ( - self.code_model.options["models-mode"] == "dpg" + self.code_model.options["models-mode"] in ("dpg", "typeddict") and self.lro_response and self.lro_response.type and self.lro_response.type.type == "model" diff --git a/packages/http-client-python/generator/pygen/codegen/models/model_type.py b/packages/http-client-python/generator/pygen/codegen/models/model_type.py index d0784f81efb..fda02ef13a7 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/model_type.py +++ b/packages/http-client-python/generator/pygen/codegen/models/model_type.py @@ -374,3 +374,27 @@ def imports(self, **kwargs: Any) -> FileImport: if self.flattened_property: file_import.add_submodule_import("typing", "Any", ImportType.STDLIB) return file_import + + +class TypedDictModelType(DPGModelType): + base = "typeddict" + + def type_annotation(self, **kwargs: Any) -> str: + if kwargs.pop("is_response", False): + return "JSON" + return super().type_annotation(**kwargs) + + def docstring_type(self, **kwargs: Any) -> str: + if kwargs.pop("is_response", False): + return "JSON" + return super().docstring_type(**kwargs) + + def docstring_text(self, **kwargs: Any) -> str: + if kwargs.pop("is_response", False): + return "JSON" + return super().docstring_text(**kwargs) + + def imports(self, **kwargs: Any) -> FileImport: + file_import = super().imports(**kwargs) + file_import.define_mutable_mapping_type() + return file_import diff --git a/packages/http-client-python/generator/pygen/codegen/models/operation.py b/packages/http-client-python/generator/pygen/codegen/models/operation.py index c5f15593893..60737410140 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/operation.py +++ b/packages/http-client-python/generator/pygen/codegen/models/operation.py @@ -51,7 +51,7 @@ def is_internal(target: Optional[BaseType]) -> bool: - return isinstance(target, ModelType) and target.base == "dpg" and target.internal + return isinstance(target, ModelType) and target.base in ("dpg", "typeddict") and target.internal class OperationBase( # pylint: disable=too-many-public-methods,too-many-instance-attributes @@ -176,7 +176,7 @@ def response_docstring_text(self, **kwargs) -> str: retval = self._response_docstring_helper("docstring_text", **kwargs) if not self.code_model.options["version-tolerant"]: retval += " or the result of cls(response)" - if self.code_model.options["models-mode"] == "dpg" and any( + if self.code_model.options["models-mode"] in ("dpg", "typeddict") and any( isinstance(r.type, ModelType) for r in self.responses ): r = next(r for r in self.responses if isinstance(r.type, ModelType)) @@ -209,7 +209,7 @@ def default_error_deserialization(self, serialize_namespace: str) -> Optional[st f"{exception_schema.type_annotation(skip_quote=True, serialize_namespace=serialize_namespace)}," f"{pylint_disable}" ) - return None if self.code_model.options["models-mode"] == "dpg" else "'object'," + return None if self.code_model.options["models-mode"] in ("dpg", "typeddict") else "'object'," @property def non_default_errors(self) -> list[Response]: @@ -421,7 +421,7 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements for overload in self.overloads: if overload.parameters.has_body: file_import.merge(overload.parameters.body_parameter.type.imports(**kwargs)) - if self.code_model.options["models-mode"] == "dpg": + if self.code_model.options["models-mode"] in ("dpg", "typeddict"): relative_path = self.code_model.get_relative_import_path( serialize_namespace, module_name="_utils.model_base" ) @@ -449,7 +449,7 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements file_import.add_import("json", ImportType.STDLIB) if self.enable_import_deserialize_xml: file_import.add_submodule_import(relative_path, "_deserialize_xml", ImportType.LOCAL) - if any( + if self.code_model.options["models-mode"] != "typeddict" and any( r.type and not isinstance(r.type, BinaryIteratorType) and not xml_serializable(str(r.default_content_type)) diff --git a/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py b/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py index f64363ed5fb..e6fb19fe3db 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py +++ b/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py @@ -181,12 +181,14 @@ def imports(self, async_mode: bool, **kwargs: Any) -> FileImport: "case_insensitive_dict", ImportType.SDKCORE, ) - if self.code_model.options["models-mode"] == "dpg": + if self.code_model.options["models-mode"] in ("dpg", "typeddict"): relative_path = self.code_model.get_relative_import_path( serialize_namespace, module_name="_utils.model_base" ) file_import.merge(self.item_type.imports(**kwargs)) - if self.default_error_deserialization(serialize_namespace) or self.need_deserialize: + if ( + self.default_error_deserialization(serialize_namespace) or self.need_deserialize + ) and self.code_model.options["models-mode"] != "typeddict": file_import.add_submodule_import(relative_path, "_deserialize", ImportType.LOCAL) if self.is_xml_paging: file_import.add_submodule_import("xml.etree", "ElementTree", ImportType.STDLIB, alias="ET") diff --git a/packages/http-client-python/generator/pygen/codegen/models/parameter.py b/packages/http-client-python/generator/pygen/codegen/models/parameter.py index 26eb0ba5c9a..34a765b849a 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/parameter.py +++ b/packages/http-client-python/generator/pygen/codegen/models/parameter.py @@ -336,7 +336,7 @@ def method_location( # pylint: disable=too-many-return-statements ) -> ParameterMethodLocation: if not self.in_method_signature: raise ValueError(f"Parameter '{self.client_name}' is not in the method.") - if self.code_model.options["models-mode"] == "dpg" and self.in_flattened_body: + if self.code_model.options["models-mode"] in ("dpg", "typeddict") and self.in_flattened_body: return ParameterMethodLocation.KEYWORD_ONLY if self.grouper: return ParameterMethodLocation.POSITIONAL diff --git a/packages/http-client-python/generator/pygen/codegen/models/request_builder_parameter.py b/packages/http-client-python/generator/pygen/codegen/models/request_builder_parameter.py index c73df6db5af..810fbf3b7e7 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/request_builder_parameter.py +++ b/packages/http-client-python/generator/pygen/codegen/models/request_builder_parameter.py @@ -26,7 +26,7 @@ def __init__(self, *args, **kwargs) -> None: if ( isinstance(self.type, (BinaryType, StringType)) or any("xml" in ct for ct in self.content_types) - or self.code_model.options["models-mode"] == "dpg" + or self.code_model.options["models-mode"] in ("dpg", "typeddict") ): self.client_name = "content" else: @@ -40,7 +40,9 @@ def type_annotation(self, **kwargs: Any) -> str: @property def in_method_signature(self) -> bool: return ( - super().in_method_signature and not self.is_partial_body and self.code_model.options["models-mode"] != "dpg" + super().in_method_signature + and not self.is_partial_body + and self.code_model.options["models-mode"] not in ("dpg", "typeddict") ) @property diff --git a/packages/http-client-python/generator/pygen/codegen/models/response.py b/packages/http-client-python/generator/pygen/codegen/models/response.py index d93d46bd897..40599eb47ae 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/response.py +++ b/packages/http-client-python/generator/pygen/codegen/models/response.py @@ -95,6 +95,7 @@ def serialization_type(self, **kwargs: Any) -> str: def type_annotation(self, **kwargs: Any) -> str: if self.type: kwargs["is_operation_file"] = True + kwargs["is_response"] = True type_annotation = self.type.type_annotation(**kwargs) if self.nullable: return f"Optional[{type_annotation}]" @@ -102,11 +103,13 @@ def type_annotation(self, **kwargs: Any) -> str: return "None" def docstring_text(self, **kwargs: Any) -> str: + kwargs["is_response"] = True if self.nullable and self.type: return f"{self.type.docstring_text(**kwargs)} or None" return self.type.docstring_text(**kwargs) if self.type else "None" def docstring_type(self, **kwargs: Any) -> str: + kwargs["is_response"] = True if self.nullable and self.type: return f"{self.type.docstring_type(**kwargs)} or None" return self.type.docstring_type(**kwargs) if self.type else "None" diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py b/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py index a95b1fd2f27..c769535dda4 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py @@ -22,10 +22,11 @@ ModelType, EnumType, ) +from ..models.primitive_types import DatetimeType, ByteArraySchema, BinaryType from .enum_serializer import EnumSerializer from .general_serializer import GeneralSerializer from .model_init_serializer import ModelInitSerializer -from .model_serializer import DpgModelSerializer, MsrestModelSerializer +from .model_serializer import DpgModelSerializer, MsrestModelSerializer, TypedDictModelSerializer from .operations_init_serializer import OperationsInitSerializer from .operation_groups_serializer import OperationGroupsSerializer from .request_builders_serializer import RequestBuildersSerializer @@ -118,8 +119,55 @@ def keep_version_file(self) -> bool: # If parsing the version fails, we assume the version file is not valid and overwrite. return False + @staticmethod + def _validate_typeddict_models(code_model: CodeModel) -> None: + """Validate that models are compatible with typeddict mode. + + Raises ValueError if any model uses unsupported features: + readonly properties, datetime types, bytes types, + or additional properties (extends Record). + """ + unsupported: list[str] = [] + for model in code_model.model_types: + if model.base != "typeddict": + continue + model_name = model.name + + for prop in model.properties: + # Readonly + if prop.readonly: + unsupported.append( + f"Model '{model_name}' has readonly property '{prop.client_name}', " + "which is not supported in typeddict mode." + ) + # Datetime + if isinstance(prop.type, DatetimeType): + unsupported.append( + f"Model '{model_name}' has datetime property '{prop.client_name}', " + "which is not supported in typeddict mode." + ) + # Bytes + if isinstance(prop.type, (ByteArraySchema, BinaryType)): + unsupported.append( + f"Model '{model_name}' has bytes property '{prop.client_name}', " + "which is not supported in typeddict mode." + ) + # Additional properties (extends Record) + if prop.client_name == "additional_properties": + unsupported.append( + f"Model '{model_name}' has additional properties (extends Record), " + "which is not supported in typeddict mode." + ) + + if unsupported: + raise ValueError("The following models are not compatible with typeddict mode:\n" + "\n".join(unsupported)) + # pylint: disable=too-many-branches def serialize(self) -> None: + # Validate typeddict mode constraints + if self.code_model.options.get("models-mode") == "typeddict": + self._validate_typeddict_models(self.code_model) + # remove existing folders when generate from tsp if self.code_model.is_tsp and self.code_model.options.get("clear-output-folder"): # remove generated_samples and generated_tests folder @@ -294,7 +342,13 @@ def _serialize_and_write_models_folder( ) -> None: # Write the models folder models_path = self.code_model.get_generation_dir(namespace) / "models" - serializer = DpgModelSerializer if self.code_model.options["models-mode"] == "dpg" else MsrestModelSerializer + models_mode = self.code_model.options["models-mode"] + if models_mode == "dpg": + serializer = DpgModelSerializer + elif models_mode == "typeddict": + serializer = TypedDictModelSerializer + else: + serializer = MsrestModelSerializer if self.code_model.has_non_json_models(models): self.write_file( models_path / Path(f"{self.code_model.models_filename}.py"), @@ -483,7 +537,7 @@ def _serialize_and_write_utils_folder(self, env: Environment, namespace: str): ) # write _model_base.py - if self.code_model.options["models-mode"] == "dpg": + if self.code_model.options["models-mode"] in ("dpg", "typeddict"): self.write_file( utils_folder_path / Path("model_base.py"), general_serializer.serialize_model_base_file(), diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py index e43dcd916eb..9c5b14440af 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py @@ -29,6 +29,7 @@ CombinedType, JSONModelType, DPGModelType, + TypedDictModelType, ParameterListType, ByteArraySchema, ) @@ -711,7 +712,7 @@ def _serialize_body_parameter(self, builder: OperationType) -> list[str]: f"_{body_kwarg_name} = self._serialize.body({body_param.client_name}, " f"'{serialization_type}'{is_xml_cmd}{serialization_ctxt_cmd})" ) - elif self.code_model.options["models-mode"] == "dpg": + elif self.code_model.options["models-mode"] in ("dpg", "typeddict"): if json_serializable(body_param.default_content_type): if hasattr(body_param.type, "encode") and body_param.type.encode: # type: ignore create_body_call = ( @@ -790,9 +791,10 @@ def _initialize_overloads(self, builder: OperationType, is_paging: bool = False) overload.request_builder.parameters.body_parameter.client_name for overload in builder.overloads ] all_dpg_model_overloads = False - if self.code_model.options["models-mode"] == "dpg" and builder.overloads: + if self.code_model.options["models-mode"] in ("dpg", "typeddict") and builder.overloads: all_dpg_model_overloads = all( - isinstance(o.parameters.body_parameter.type, DPGModelType) for o in builder.overloads + isinstance(o.parameters.body_parameter.type, (DPGModelType, TypedDictModelType)) + for o in builder.overloads ) if not all_dpg_model_overloads: for v in sorted(set(client_names), key=client_names.index): @@ -997,7 +999,13 @@ def response_deserialization( # pylint: disable=too-many-statements deserialize_code.append(f" '{serialization_type}',{pylint_disable}") deserialize_code.append(" pipeline_response.http_response") deserialize_code.append(")") - elif self.code_model.options["models-mode"] == "dpg": + elif self.code_model.options["models-mode"] == "typeddict": + if builder.has_stream_response: + deserialize_code.append("deserialized = response.content") + else: + response_attr = "json" if json_serializable(str(response.default_content_type)) else "text" + deserialize_code.append(f"deserialized = response.{response_attr}()") + elif self.code_model.options["models-mode"] in ("dpg", "typeddict"): if builder.has_stream_response: deserialize_code.append("deserialized = response.content") else: @@ -1071,7 +1079,7 @@ def handle_error_response( # pylint: disable=too-many-statements, too-many-bran type_annotation = e.type.type_annotation( # type: ignore is_operation_file=True, skip_quote=True, serialize_namespace=self.serialize_namespace ) - if self.code_model.options["models-mode"] == "dpg": + if self.code_model.options["models-mode"] in ("dpg", "typeddict"): if xml_serializable(str(e.default_content_type)): fn = "_failsafe_deserialize_xml" else: @@ -1113,7 +1121,7 @@ def handle_error_response( # pylint: disable=too-many-statements, too-many-bran type_annotation = e.type.type_annotation( # type: ignore is_operation_file=True, skip_quote=True, serialize_namespace=self.serialize_namespace ) - if self.code_model.options["models-mode"] == "dpg": + if self.code_model.options["models-mode"] in ("dpg", "typeddict"): if xml_serializable(str(e.default_content_type)): retval.append( " error = _failsafe_deserialize_xml(" @@ -1141,7 +1149,7 @@ def handle_error_response( # pylint: disable=too-many-statements, too-many-bran indent = " " if builder.non_default_errors else " " if builder.non_default_errors: retval.append(" else:") - if self.code_model.options["models-mode"] == "dpg": + if self.code_model.options["models-mode"] in ("dpg", "typeddict"): default_exception = next(e for e in builder.exceptions if "default" in e.status_codes and e.type) if xml_serializable(str(default_exception.default_content_type)): fn = "_failsafe_deserialize_xml" @@ -1410,7 +1418,7 @@ def _extract_data_callback( # pylint: disable=too-many-statements,too-many-bran f"self._deserialize(\n {deserialize_type},{pylint_disable}\n pipeline_response{suffix}\n)" ) retval.append(f" deserialized = {deserialized}") - elif self.code_model.options["models-mode"] == "dpg": + elif self.code_model.options["models-mode"] in ("dpg", "typeddict"): # we don't want to generate paging models for DPG retval.append(f" deserialized = {deserialized}") else: @@ -1428,7 +1436,7 @@ def _extract_data_callback( # pylint: disable=too-many-statements,too-many-bran "".join([f'.get("{i}", {{}})' for i in item_name_array[:-1]]) + f'.get("{item_name_array[-1]}", [])' ) pylint_disable = "" - if self.code_model.options["models-mode"] == "dpg": + if self.code_model.options["models-mode"] in ("dpg", "typeddict"): item_type = builder.item_type.type_annotation( is_operation_file=True, serialize_namespace=self.serialize_namespace ) @@ -1605,7 +1613,7 @@ def get_long_running_output(self, builder: LROOperationType) -> list[str]: retval.append(" response_headers = {}") if ( not self.code_model.options["models-mode"] - or self.code_model.options["models-mode"] == "dpg" + or self.code_model.options["models-mode"] in ("dpg", "typeddict") or builder.lro_response.headers ): retval.append(" response = pipeline_response.http_response") diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py index d428a113e5e..27d4451ae2f 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py @@ -382,3 +382,120 @@ def global_pylint_disables(self) -> str: if final_result: return "# pylint: disable=" + ", ".join(final_result) return "" + + +class TypedDictModelSerializer(_ModelSerializer): + def _is_parent_discriminated_base(self, model: ModelType) -> bool: + """Check if any parent of this model is a discriminated base (has discriminated_subtypes).""" + return any(p.discriminated_subtypes for p in model.parents) + + def _reorder_models(self, models: list[ModelType]) -> list[ModelType]: + """Reorder so discriminated base Union aliases come after all their subtypes.""" + bases = [m for m in models if m.discriminated_subtypes] + non_bases = [m for m in models if not m.discriminated_subtypes] + return non_bases + bases + + def serialize(self) -> str: + template = self.env.get_template("model_container.py.jinja2") + return template.render( + code_model=self.code_model, + imports=FileImportSerializer(self.imports()), + str=str, + serializer=self, + models=self._reorder_models(self.models), + ) + + def imports(self) -> FileImport: + file_import = FileImport(self.code_model) + has_required = False + has_discriminated_union = False + for model in self.models: + if model.base == "json": + continue + if model.discriminated_subtypes: + has_discriminated_union = True + file_import.merge( + model.imports( + is_operation_file=False, + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.MODEL, + ) + ) + for prop in model.properties: + file_import.merge( + prop.imports( + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.MODEL, + called_by_property=True, + ) + ) + if not (prop.optional or prop.client_default_value is not None): + has_required = True + for parent in model.parents: + if parent.client_namespace != model.client_namespace and not parent.discriminated_subtypes: + file_import.add_submodule_import( + self.code_model.get_relative_import_path( + self.serialize_namespace, + self.code_model.get_imported_namespace_for_model(parent.client_namespace), + ), + parent.name, + ImportType.LOCAL, + ) + file_import.add_submodule_import("typing_extensions", "TypedDict", ImportType.STDLIB) + if has_required: + file_import.add_submodule_import("typing_extensions", "Required", ImportType.STDLIB) + if has_discriminated_union: + file_import.add_submodule_import("typing", "Union", ImportType.STDLIB) + return file_import + + def declare_model(self, model: ModelType) -> str: + # If the model's parent is a discriminated base, don't inherit from it + non_discriminated_parents = [p for p in model.parents if not p.discriminated_subtypes] + if non_discriminated_parents: + basename = ", ".join([m.name for m in non_discriminated_parents]) + return f"class {model.name}({basename}):{model.pylint_disable()}" + return f"class {model.name}(TypedDict, total=False):{model.pylint_disable()}" + + @staticmethod + def get_properties_to_declare(model: ModelType) -> list[Property]: + # Only exclude inherited properties from non-discriminated parents + non_discriminated_parents = [p for p in model.parents if not p.discriminated_subtypes] + if non_discriminated_parents: + parent_properties = [p for bm in non_discriminated_parents for p in bm.properties] + properties_to_declare = [ + p + for p in model.properties + if not any( + p.client_name == pp.client_name + and p.type_annotation() == pp.type_annotation() + and not p.is_base_discriminator + for pp in parent_properties + ) + ] + else: + properties_to_declare = model.properties + return properties_to_declare + + def declare_property(self, prop: Property) -> str: + type_annotation = prop.type_annotation(serialize_namespace=self.serialize_namespace) + is_optional = prop.optional or prop.client_default_value is not None + if is_optional: + return f"{prop.wire_name}: {type_annotation}" + return f"{prop.wire_name}: Required[{type_annotation}]" + + def initialize_properties(self, model: ModelType) -> list[str]: + return [] + + def need_init(self, model: ModelType) -> bool: + return False + + def discriminated_subtypes_union(self, model: ModelType) -> str: + subtypes = list(model.discriminated_subtypes.values()) + subtype_names = [s.name for s in subtypes] + return f"{model.name} = Union[{', '.join(subtype_names)}]" + + def is_discriminated_base(self, model: ModelType) -> bool: + return bool(model.discriminated_subtypes) + + def global_pylint_disables(self) -> str: + return "" diff --git a/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 index dc98f999c45..6e033ba77fb 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 @@ -13,5 +13,7 @@ {% include "model_dpg.py.jinja2" %} {% elif model.base == "msrest" %} {% include "model_msrest.py.jinja2" %} +{% elif model.base == "typeddict" %} +{% include "model_typeddict.py.jinja2" %} {% endif %} {% endfor %} diff --git a/packages/http-client-python/generator/pygen/codegen/templates/model_typeddict.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/model_typeddict.py.jinja2 new file mode 100644 index 00000000000..8176626fe3f --- /dev/null +++ b/packages/http-client-python/generator/pygen/codegen/templates/model_typeddict.py.jinja2 @@ -0,0 +1,28 @@ +{# actual template starts here #} +{% import "macros.jinja2" as macros %} + +{% if serializer.is_discriminated_base(model) %} +{{ serializer.discriminated_subtypes_union(model) }} +{% else %} + +{{ serializer.declare_model(model) }} + """{{ op_tools.wrap_string(model.description(is_operation_file=False), "\n ") }} + + {% if model.properties != None %} + {% for p in model.properties %} + {% for line in serializer.variable_documentation_string(p) %} + {{ macros.wrap_model_string(line, '\n ') -}} + {% endfor %} + {% endfor %} + {% endif %} + """ + + {% for p in serializer.get_properties_to_declare(model)%} + {{ serializer.declare_property(p) }} + {% set prop_description = p.description(is_operation_file=False).replace('"', '\\"') %} + {% if prop_description %} + """{{ macros.wrap_model_string(prop_description, '\n ', '\"\"\"') -}} + {% endif %} + {% endfor %} +{% endif %} + diff --git a/packages/http-client-python/generator/pygen/preprocess/__init__.py b/packages/http-client-python/generator/pygen/preprocess/__init__.py index 6d3344059a3..5117a9227e5 100644 --- a/packages/http-client-python/generator/pygen/preprocess/__init__.py +++ b/packages/http-client-python/generator/pygen/preprocess/__init__.py @@ -216,7 +216,7 @@ def add_body_param_type( model_type = ( body_parameter["type"] if origin_type == "model" else body_parameter["type"].get("elementType", {}) ) - is_dpg_model = model_type.get("base") == "dpg" + is_dpg_model = model_type.get("base") in ("dpg", "typeddict") body_parameter["type"] = { "type": "combined", "types": [body_parameter["type"]], diff --git a/packages/http-client-python/tests/mock_api/azure/test_client_naming_typeddict.py b/packages/http-client-python/tests/mock_api/azure/test_client_naming_typeddict.py new file mode 100644 index 00000000000..40786f16e39 --- /dev/null +++ b/packages/http-client-python/tests/mock_api/azure/test_client_naming_typeddict.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from client.naming.typeddict import NamingClient, models + + +@pytest.fixture +def client(): + with NamingClient() as client: + yield client + + +def test_client(client: NamingClient): + """TypedDict uses wire name 'defaultName', not client name 'client_name'.""" + client.property.client(models.ClientNameModel(defaultName=True)) + + +def test_language(client: NamingClient): + """TypedDict uses wire name 'defaultName', not language-specific name 'python_name'.""" + client.property.language(models.LanguageClientNameModel(defaultName=True)) + + +def test_compatible_with_encoded_name(client: NamingClient): + """TypedDict uses encoded wire name 'wireName', not client name 'client_name'.""" + client.property.compatible_with_encoded_name( + models.ClientNameAndJsonEncodedNameModel(wireName=True) + ) + + +def test_operation(client: NamingClient): + client.client_name() + + +def test_parameter(client: NamingClient): + client.parameter(client_name="true") + + +def test_header_request(client: NamingClient): + client.header.request(client_name="true") + + +def test_header_response(client: NamingClient): + assert client.header.response(cls=lambda x, y, z: z)["default-name"] == "true" + + +def test_model_client(client: NamingClient): + """TypedDict uses wire name 'defaultName', not client name 'default_name'.""" + client.model_client.client(models.ClientModel(defaultName=True)) + + +def test_model_language(client: NamingClient): + """TypedDict uses wire name 'defaultName', not client name 'default_name'.""" + client.model_client.language(models.PythonModel(defaultName=True)) + + +def test_union_enum_member_name(client: NamingClient): + client.union_enum.union_enum_member_name(models.ExtensibleEnum.CLIENT_ENUM_VALUE1) + + +def test_union_enum_name(client: NamingClient): + client.union_enum.union_enum_name(models.ClientExtensibleEnum.ENUM_VALUE1) diff --git a/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_not_discriminated_typeddict.py b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_not_discriminated_typeddict.py new file mode 100644 index 00000000000..782791ab1e7 --- /dev/null +++ b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_not_discriminated_typeddict.py @@ -0,0 +1,37 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from typetest.model.notdiscriminated.typeddict import NotDiscriminatedClient +from typetest.model.notdiscriminated.typeddict.models import Siamese + + +@pytest.fixture +def client(): + with NotDiscriminatedClient() as client: + yield client + + +@pytest.fixture +def valid_body(): + return Siamese(name="abc", age=32, smart=True) + + +def test_get_valid(client, valid_body): + result = client.get_valid() + assert result["name"] == "abc" + assert result["age"] == 32 + assert result["smart"] is True + + +def test_post_valid(client, valid_body): + client.post_valid(valid_body) + + +def test_put_valid(client, valid_body): + result = client.put_valid(valid_body) + assert result["name"] == "abc" + assert result["age"] == 32 + assert result["smart"] is True diff --git a/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_single_discriminator_typeddict.py b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_single_discriminator_typeddict.py new file mode 100644 index 00000000000..19335d1bc69 --- /dev/null +++ b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_single_discriminator_typeddict.py @@ -0,0 +1,70 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from typetest.model.singlediscriminator.typeddict import SingleDiscriminatorClient +from typetest.model.singlediscriminator.typeddict.models import Sparrow, Eagle + + +@pytest.fixture +def client(): + with SingleDiscriminatorClient() as client: + yield client + + +@pytest.fixture +def valid_body(): + return Sparrow(wingspan=1, kind="sparrow") + + +def test_get_model(client): + result = client.get_model() + assert result["wingspan"] == 1 + assert result["kind"] == "sparrow" + + +def test_put_model(client, valid_body): + client.put_model(valid_body) + + +@pytest.fixture +def recursive_body(): + return Eagle( + wingspan=5, + kind="eagle", + partner={"wingspan": 2, "kind": "goose"}, + friends=[{"wingspan": 2, "kind": "seagull"}], + hate={"key3": {"wingspan": 1, "kind": "sparrow"}}, + ) + + +def test_get_recursive_model(client): + result = client.get_recursive_model() + assert result["wingspan"] == 5 + assert result["kind"] == "eagle" + assert result["partner"]["kind"] == "goose" + assert result["friends"][0]["kind"] == "seagull" + assert result["hate"]["key3"]["kind"] == "sparrow" + + +def test_put_recursive_model(client, recursive_body): + client.put_recursive_model(recursive_body) + + +def test_get_missing_discriminator(client): + result = client.get_missing_discriminator() + assert result["wingspan"] == 1 + + +def test_get_wrong_discriminator(client): + result = client.get_wrong_discriminator() + assert result["wingspan"] == 1 + assert result["kind"] == "wrongKind" + + +def test_get_legacy_model(client): + result = client.get_legacy_model() + assert result["size"] == 20 + assert result["kind"] == "t-rex"