diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 8f253b3220..f2fe908297 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -146,29 +146,71 @@ def _init_from_model_json(self, model_json_str: str) -> None: self._dpmodel = BaseModel.deserialize(model_data) self._is_spin = False - self.rcut = self._dpmodel.get_rcut() - self.type_map = self._dpmodel.get_type_map() + self._rcut = self._dpmodel.get_rcut() + self._type_map = self._dpmodel.get_type_map() + self._sel = list(self._dpmodel.get_sel()) + self._mixed_types = bool(self._dpmodel.mixed_types()) if self._is_spin: - self._model_output_def = ModelOutputDef( - FittingOutputDef( - [ - OutputVariableDef( - "energy", - shape=[1], - reducible=True, - r_differentiable=True, - c_differentiable=True, - atomic=True, - magnetic=True, - ) - ] - ) - ) + spin_fitting_defs = self._dpmodel.model_output_def().def_outp.get_data() + # Keep only physical fitting outputs; mask is derived by ModelOutputDef. + fitting_defs = [ + vdef for name, vdef in spin_fitting_defs.items() if name != "mask" + ] + self._model_output_def = ModelOutputDef(FittingOutputDef(fitting_defs)) else: self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def()) + def _init_from_metadata(self) -> None: + """Initialize DeepEval from ``metadata.json`` alone. + + Used when the ``.pt2`` / ``.pte`` archive ships no ``model.json`` + (e.g. for backends that do not travel through the dpmodel round-trip). + The metadata contract is the same one the C++ ``DeepPotPTExpt`` + reader consumes, so anything that validates against the C++ side + automatically validates here. + + ``self._dpmodel`` is left as ``None`` to signal the metadata-only + mode. Inference does not need it: it runs through + ``aoti_load_package`` / the exported module and uses plain + attributes (``self._rcut``, ``self._sel``, ``self._mixed_types``, + ``self._model_output_def``) for all metadata-level queries. + """ + self._dpmodel = None + self._is_spin = bool(self.metadata.get("is_spin", False)) + self._rcut = float(self.metadata["rcut"]) + self._type_map = list(self.metadata["type_map"]) + self._sel = [int(s) for s in self.metadata["sel"]] + self._mixed_types = bool(self.metadata["mixed_types"]) + + fitting_defs = [] + for vdef in self.metadata["fitting_output_defs"]: + fitting_defs.append( + OutputVariableDef( + name=vdef["name"], + shape=list(vdef["shape"]), + reducible=vdef.get("reducible", False), + r_differentiable=vdef.get("r_differentiable", False), + c_differentiable=vdef.get("c_differentiable", False), + atomic=vdef.get("atomic", True), + category=int( + vdef.get("category", OutputVariableCategory.OUT.value) + ), + r_hessian=vdef.get("r_hessian", False), + magnetic=vdef.get("magnetic", False), + intensive=vdef.get("intensive", False), + ) + ) + self._model_output_def = ModelOutputDef(FittingOutputDef(fitting_defs)) + def _load_pte(self, model_file: str) -> None: - """Load a .pte (torch.export) model file.""" + """Load a .pte (torch.export) model file. + + ``model.json`` is optional: when present it is used to reconstruct + the dpmodel instance (enabling dpmodel-level introspection such as + ``eval_descriptor``); when absent we fall back to pure metadata + mode via :meth:`_init_from_metadata`. ``metadata.json`` is the + only contract the inference path actually requires. + """ extra_files = { "model.json": "", "model_def_script.json": "", @@ -176,38 +218,69 @@ def _load_pte(self, model_file: str) -> None: } exported = torch.export.load(model_file, extra_files=extra_files) self.exported_module = exported.module() - self._init_from_model_json(extra_files["model.json"]) mds = extra_files["model_def_script.json"] self._model_def_script = json.loads(mds) if mds else {} md = extra_files["metadata.json"] - self.metadata = json.loads(md) if md else {} + if not md: + raise ValueError( + f"Invalid .pte file '{model_file}': missing 'metadata.json'" + ) + self.metadata = json.loads(md) + + model_json_str = extra_files["model.json"] + if model_json_str: + self._init_from_model_json(model_json_str) + else: + self._init_from_metadata() def _load_pt2(self, model_file: str) -> None: - """Load a .pt2 (AOTInductor) model file.""" + """Load a .pt2 (AOTInductor) model file. + + ``model.json`` is optional — it only enables the dpmodel + round-trip (used by ``eval_descriptor``, ``eval_typeebd``, etc.). + Pure AOTI inference (``DeepPot.eval`` / ``dp test`` / ASE + calculator) only needs ``metadata.json``, matching the contract + the C++ ``DeepPotPTExpt`` reader enforces. + + Archive entries are located under ``model/extra/`` so that the + PyTorch 2.11 ``load_pt2`` loader accepts the archive without the + "outdated pt2 file" fallback warning. + """ import zipfile from torch._inductor import ( aoti_load_package, ) + from deepmd.pt_expt.utils.serialization import ( + PT2_EXTRA_PREFIX, + ) + + md_entry = PT2_EXTRA_PREFIX + "metadata.json" + model_json_entry = PT2_EXTRA_PREFIX + "model.json" + mds_entry = PT2_EXTRA_PREFIX + "model_def_script.json" + # Read metadata from the .pt2 ZIP archive with zipfile.ZipFile(model_file, "r") as zf: names = zf.namelist() - if "extra/model.json" not in names: + if md_entry not in names: raise ValueError( - f"Invalid .pt2 file '{model_file}': missing 'extra/model.json'" + f"Invalid .pt2 file '{model_file}': missing '{md_entry}'" ) - model_json_str = zf.read("extra/model.json").decode("utf-8") + md = zf.read(md_entry).decode("utf-8") + model_json_str = "" + if model_json_entry in names: + model_json_str = zf.read(model_json_entry).decode("utf-8") mds = "" - if "extra/model_def_script.json" in names: - mds = zf.read("extra/model_def_script.json").decode("utf-8") - md = "" - if "extra/metadata.json" in names: - md = zf.read("extra/metadata.json").decode("utf-8") + if mds_entry in names: + mds = zf.read(mds_entry).decode("utf-8") - self._init_from_model_json(model_json_str) + self.metadata = json.loads(md) self._model_def_script = json.loads(mds) if mds else {} - self.metadata = json.loads(md) if md else {} + if model_json_str: + self._init_from_model_json(model_json_str) + else: + self._init_from_metadata() # Load the AOTInductor model package (.pt2 ZIP archive). # Uses torch._inductor.aoti_load_package (private API, stable since PyTorch 2.6). @@ -305,8 +378,10 @@ def _load_pt(self, model_file: str, head: str | None = None) -> None: self._is_spin = ( model_params.get("type") == "spin_ener" or "spin" in model_params ) - self.rcut = model.get_rcut() - self.type_map = model.get_type_map() + self._rcut = model.get_rcut() + self._type_map = model.get_type_map() + self._sel = list(model.get_sel()) + self._mixed_types = bool(model.mixed_types()) if self._is_spin: self._model_output_def = ModelOutputDef( FittingOutputDef( @@ -399,28 +474,41 @@ def _eager_runner( def get_rcut(self) -> float: """Get the cutoff radius of this model.""" - return self.rcut + return self._rcut def get_ntypes(self) -> int: """Get the number of atom types of this model.""" - return len(self.type_map) + return len(self._type_map) def get_type_map(self) -> list[str]: """Get the type map (element name of the atom types) of this model.""" - return self.type_map + return self._type_map def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this DP.""" - return self._dpmodel.get_dim_fparam() + if self._dpmodel is not None: + return self._dpmodel.get_dim_fparam() + return int(self.metadata["dim_fparam"]) def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this DP.""" - return self._dpmodel.get_dim_aparam() + if self._dpmodel is not None: + return self._dpmodel.get_dim_aparam() + return int(self.metadata["dim_aparam"]) @property def model_type(self) -> type["DeepEvalWrapper"]: - """The the evaluator of the model type.""" - model_output_type = self._dpmodel.model_output_type() + """The evaluator of the model type.""" + if self._dpmodel is not None: + model_output_type = self._dpmodel.model_output_type() + else: + # Metadata-only mode: derive the output-type set from the + # fitting_output_defs names. `model_output_type()` on a + # dpmodel is the same set — just the base output names, not + # their derived `*_redu` / `*_derv_*` twins. + model_output_type = [ + d.name for d in self._model_output_def.def_outp.get_data().values() + ] if "energy" in model_output_type: return DeepPot elif "dos" in model_output_type: @@ -441,7 +529,12 @@ def get_sel_type(self) -> list[int]: to the result of the model. If returning an empty list, all atom types are selected. """ - return self._dpmodel.get_sel_type() + if self._dpmodel is not None: + return self._dpmodel.get_sel_type() + # Metadata-only mode: read the `sel_type` field populated by + # `_collect_metadata`. Missing field → `[]` (every type + # selected), matching the dpmodel default for energy models. + return [int(t) for t in self.metadata.get("sel_type", [])] def get_numb_dos(self) -> int: """Get the number of DOS.""" @@ -453,13 +546,15 @@ def get_has_efield(self) -> bool: def get_has_spin(self) -> bool: """Check if the model has spin atom types.""" - return getattr(self, "_is_spin", False) + return self._is_spin def get_use_spin(self) -> list[bool]: """Get the per-type spin usage of this model.""" - if getattr(self, "_is_spin", False): + if not self._is_spin: + return [] + if self._dpmodel is not None: return self._dpmodel.spin.use_spin.tolist() - return [] + return [bool(v) for v in self.metadata.get("use_spin", [])] def get_ntypes_spin(self) -> int: """Get the number of spin atom types of this model. Only used in old implement.""" @@ -613,9 +708,9 @@ def _build_nlist_native( """ nframes = coords.shape[0] natoms = coords.shape[1] - rcut = self.rcut - sel = self._dpmodel.get_sel() - mixed_types = self._dpmodel.mixed_types() + rcut = self._rcut + sel = self._sel + mixed_types = self._mixed_types if cells is not None: box_input = cells.reshape(nframes, 3, 3) @@ -726,8 +821,8 @@ def _build_nlist_ase_single( nlist : np.ndarray, shape (nloc, nsel) mapping : np.ndarray, shape (nall,) """ - sel = self._dpmodel.get_sel() - mixed_types = self._dpmodel.mixed_types() + sel = self._sel + mixed_types = self._mixed_types nsel = sum(sel) natoms = positions.shape[0] @@ -770,7 +865,7 @@ def _build_nlist_ase_single( ghost_remap[out_mask] = np.arange(nloc, nloc + nghost, dtype=np.int64) # Build nlist: vectorized CSR-to-dense conversion - rcut = self.rcut + rcut = self._rcut counts = np.diff(first_neigh) max_nn = int(counts.max()) if counts.size > 0 else 0 @@ -1186,13 +1281,44 @@ def get_model(self) -> torch.nn.Module: return self.exported_module def _is_spin_model(self) -> bool: - """Check if the underlying dpmodel is a SpinModel.""" + """Check if the underlying model is a SpinModel. + + Primary path: the :attr:`_is_spin` attribute set by the loaders + — this works for both ``model.json`` and metadata-only archives + (a spin ``.pt2`` carries ``is_spin=true`` in its metadata). + + Legacy path: ``isinstance(_dpmodel, SpinModel)`` — retained for + tests that construct a non-spin archive and then swap + :attr:`_dpmodel` to a :class:`SpinModel` instance after load. + """ + if self._is_spin: + return True + if self._dpmodel is None: + return False from deepmd.dpmodel.model.spin_model import ( SpinModel, ) return isinstance(self._dpmodel, SpinModel) + def _require_dpmodel(self, feature: str) -> None: + """Guard for features that need a deserialised dpmodel instance. + + ``eval_descriptor`` / ``eval_typeebd`` / ``eval_fitting_last_layer`` + all introspect the dpmodel's internal sub-modules, which requires + ``model.json`` to have been present at load time. Archives + shipped without ``model.json`` (metadata-only mode) can still run + the main ``eval`` inference path but cannot expose these hooks. + """ + if self._dpmodel is None: + raise NotImplementedError( + f"{feature} requires the dpmodel instance, which is only " + "available when the .pt2 / .pte archive contains " + "'model.json'. The loaded archive is metadata-only; " + "re-export with the full dpmodel serialisation to enable " + "this feature." + ) + def eval_typeebd(self) -> np.ndarray: """Evaluate type embedding. @@ -1205,7 +1331,11 @@ def eval_typeebd(self) -> np.ndarray: ------ KeyError If the model has no type embedding networks. + NotImplementedError + If the archive was loaded in metadata-only mode. """ + self._require_dpmodel("eval_typeebd") + from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP model = self._dpmodel @@ -1249,6 +1379,8 @@ def eval_descriptor( np.ndarray Descriptor output, shape ``(nframes, nloc, dim_descrpt)``. """ + self._require_dpmodel("eval_descriptor") + coords = np.array(coords) atom_types = np.array(atom_types, dtype=np.int32) if cells is not None: @@ -1315,6 +1447,8 @@ def eval_fitting_last_layer( np.ndarray Middle-layer output, shape ``(nframes, nloc, neuron[-1])``. """ + self._require_dpmodel("eval_fitting_last_layer") + coords = np.array(coords) atom_types = np.array(atom_types, dtype=np.int32) if cells is not None: diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 04cdedd6cf..7b2559db4f 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -15,6 +15,25 @@ traverse_model_dict, ) +# --------------------------------------------------------------------------- +# AOTInductor ``.pt2`` archive layout. +# +# PyTorch 2.11 tightened the single-model ``.pt2`` convention so that every +# entry in the ZIP archive must live under the top-level ``model/`` directory. +# Any stray root-level file makes +# ``torch.export.pt2_archive._package.load_pt2`` raise ``RuntimeError`` at +# load time; the upper-level ``torch._inductor.package.package.load_package`` +# then emits a misleading ``Loading outdated pt2 file. Please regenerate +# your package.`` warning and falls back to the legacy C++ loader. +# +# deepmd-kit therefore stores its metadata JSON blobs under ``model/extra/`` +# so that the strict ``load_pt2`` loader accepts the archive without +# complaint. The C++ reader (``commonPTExpt.h::read_zip_entry``) resolves +# this layout transparently because it matches ``entry_name`` as a +# ``/``-delimited suffix. +# --------------------------------------------------------------------------- +PT2_EXTRA_PREFIX = "model/extra/" + def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: """Neutralise shape-guard assertion nodes in a spin model's exported graph. @@ -251,9 +270,15 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict: The ``fitting_output_defs`` list is also included so that ``ModelOutputDef`` can be reconstructed without loading the full model. """ - fitting_output_def = model.atomic_output_def() + if is_spin: + fitting_output_def = model.model_output_def().def_outp + else: + fitting_output_def = model.atomic_output_def() fitting_output_defs = [] for vdef in fitting_output_def.get_data().values(): + # Keep metadata aligned with physical fitting outputs only. + if is_spin and vdef.name == "mask": + continue fitting_output_defs.append( { "name": vdef.name, @@ -262,7 +287,9 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict: "r_differentiable": vdef.r_differentiable, "c_differentiable": vdef.c_differentiable, "atomic": vdef.atomic, - "category": vdef.category, + # OutputVariableCategory is an IntEnum; force plain int for + # deterministic JSON serialisation across Python versions. + "category": int(vdef.category), "r_hessian": vdef.r_hessian, "magnetic": vdef.magnetic, "intensive": vdef.intensive, @@ -279,6 +306,10 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict: "has_default_fparam": model.has_default_fparam(), "default_fparam": model.get_default_fparam(), "fitting_output_defs": fitting_output_defs, + # sel_type enables `DeepEval.get_sel_type()` without a dpmodel + # round-trip; required for dipole/polar/wfc models in metadata-only + # inference (energy models return []). + "sel_type": [int(t) for t in model.get_sel_type()], "is_spin": is_spin, } if is_spin: @@ -326,21 +357,23 @@ def _serialize_from_file_pte(model_file: str) -> dict: def _serialize_from_file_pt2(model_file: str) -> dict: """Serialize a .pt2 model file to a dictionary. - Reads the model dict stored in the extra/ directory of the .pt2 ZIP archive. + Reads the model dict stored in the ``model/extra/`` directory of the + ``.pt2`` ZIP archive. """ import zipfile + model_json_entry = PT2_EXTRA_PREFIX + "model.json" + model_def_script_entry = PT2_EXTRA_PREFIX + "model_def_script.json" with zipfile.ZipFile(model_file, "r") as zf: - if "extra/model.json" not in zf.namelist(): + names = zf.namelist() + if model_json_entry not in names: raise ValueError( - f"Invalid .pt2 file '{model_file}': missing 'extra/model.json'" + f"Invalid .pt2 file '{model_file}': missing '{model_json_entry}'" ) - model_json = zf.read("extra/model.json").decode("utf-8") + model_json = zf.read(model_json_entry).decode("utf-8") model_def_script_json = "" - if "extra/model_def_script.json" in zf.namelist(): - model_def_script_json = zf.read("extra/model_def_script.json").decode( - "utf-8" - ) + if model_def_script_entry in names: + model_def_script_json = zf.read(model_def_script_entry).decode("utf-8") model_dict = json.loads(model_json) model_dict = _json_to_numpy(model_dict) if model_def_script_json: @@ -609,13 +642,20 @@ def _deserialize_to_file_pt2( finally: _inductor_config.realize_opcount_threshold = saved_threshold - # Embed metadata into the .pt2 ZIP archive + # Embed metadata into the .pt2 ZIP archive. Entries are placed under + # ``model/extra/`` so the strict PyTorch 2.11 ``load_pt2`` loader + # accepts the archive without emitting the "outdated pt2 file" + # fallback warning. See the module-level comment on + # ``PT2_EXTRA_PREFIX`` for the rationale. model_def_script = data.get("model_def_script") or {} metadata["output_keys"] = output_keys with zipfile.ZipFile(model_file, "a") as zf: - zf.writestr("extra/metadata.json", json.dumps(metadata)) - zf.writestr("extra/model_def_script.json", json.dumps(model_def_script)) + zf.writestr(PT2_EXTRA_PREFIX + "metadata.json", json.dumps(metadata)) + zf.writestr( + PT2_EXTRA_PREFIX + "model_def_script.json", + json.dumps(model_def_script), + ) zf.writestr( - "extra/model.json", + PT2_EXTRA_PREFIX + "model.json", json.dumps(data_for_json, separators=(",", ":")), ) diff --git a/source/tests/infer/gen_sea.py b/source/tests/infer/gen_sea.py index 02f4e7ee63..905d537c62 100644 --- a/source/tests/infer/gen_sea.py +++ b/source/tests/infer/gen_sea.py @@ -78,13 +78,17 @@ def _patch_no_atomic_virial(pt2_path: str) -> None: """Flip do_atomic_virial=False in the metadata.json of a .pt2 archive. The .pt2 is a ZIP archive; the metadata blob lives at - ``extra/metadata.json``. We rewrite the archive with that one entry + ``model/extra/metadata.json``. We rewrite the archive with that one entry replaced and all other entries preserved verbatim. """ import json import zipfile - metadata_name = "extra/metadata.json" + from deepmd.pt_expt.utils.serialization import ( + PT2_EXTRA_PREFIX, + ) + + metadata_name = PT2_EXTRA_PREFIX + "metadata.json" tmp_path = pt2_path + ".tmp" # PyTorch .pt2 archives use ZIP_STORED (uncompressed) so that the C++ # reader (read_zip_entry in commonPTExpt.h) and torch's mmap-based diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index f96f08ae28..f77b882b7c 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -651,6 +651,17 @@ def setUpClass(cls) -> None: finally: torch.set_default_device("cuda:9999999") + cls.meta_tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) + cls.meta_tmpfile.close() + with ( + zipfile.ZipFile(cls.tmpfile.name, "r") as zin, + zipfile.ZipFile(cls.meta_tmpfile.name, "w") as zout, + ): + for info in zin.infolist(): + if info.filename == "model/extra/model.json": + continue + zout.writestr(info, zin.read(info.filename)) + # Also save to .pte for cross-format comparison cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) cls.pte_tmpfile.close() @@ -658,6 +669,8 @@ def setUpClass(cls) -> None: # Create DeepPot for .pt2 cls.dp = DeepPot(cls.tmpfile.name) + # Create DeepPot for metadata-only .pt2 + cls.dp_meta = DeepPot(cls.meta_tmpfile.name) # Create DeepPot for .pte reference cls.dp_pte = DeepPot(cls.pte_tmpfile.name) @@ -666,6 +679,7 @@ def tearDownClass(cls) -> None: import os os.unlink(cls.tmpfile.name) + os.unlink(cls.meta_tmpfile.name) os.unlink(cls.pte_tmpfile.name) def test_get_rcut(self) -> None: @@ -738,14 +752,52 @@ def test_pt2_file_is_zip(self) -> None: self.assertTrue(zipfile.is_zipfile(self.tmpfile.name)) def test_pt2_has_metadata(self) -> None: - """The .pt2 ZIP should contain metadata entries.""" + """The .pt2 ZIP should contain metadata entries under ``model/extra/``.""" with zipfile.ZipFile(self.tmpfile.name, "r") as zf: names = zf.namelist() - self.assertIn("extra/metadata.json", names) - self.assertIn("extra/model_def_script.json", names) - self.assertIn("extra/model.json", names) - self.assertNotIn("extra/output_keys.json", names) - self.assertNotIn("extra/model_params.json", names) + self.assertIn("model/extra/metadata.json", names) + self.assertIn("model/extra/model_def_script.json", names) + self.assertIn("model/extra/model.json", names) + self.assertNotIn("model/extra/output_keys.json", names) + self.assertNotIn("model/extra/model_params.json", names) + + def test_metadata_only_pt2_has_no_model_json(self) -> None: + """The metadata-only .pt2 keeps metadata but drops model.json.""" + with zipfile.ZipFile(self.meta_tmpfile.name, "r") as zf: + names = zf.namelist() + self.assertIn("model/extra/metadata.json", names) + self.assertNotIn("model/extra/model.json", names) + + def test_metadata_only_pt2_accessors_match(self) -> None: + """Metadata-only .pt2 archives expose the same metadata API.""" + full = self.dp.deep_eval + meta = self.dp_meta.deep_eval + self.assertIsNotNone(full._dpmodel) + self.assertIsNone(meta._dpmodel) + self.assertEqual(full.get_rcut(), meta.get_rcut()) + self.assertEqual(full.get_ntypes(), meta.get_ntypes()) + self.assertEqual(full.get_type_map(), meta.get_type_map()) + self.assertEqual(full.get_dim_fparam(), meta.get_dim_fparam()) + self.assertEqual(full.get_dim_aparam(), meta.get_dim_aparam()) + self.assertEqual(full.get_sel_type(), meta.get_sel_type()) + self.assertEqual(full.get_has_spin(), meta.get_has_spin()) + self.assertEqual(full.get_use_spin(), meta.get_use_spin()) + self.assertIs(full.model_type, meta.model_type) + + def test_metadata_only_pt2_eval_parity(self) -> None: + """Metadata-only .pt2 inference matches the full archive exactly.""" + rng = np.random.default_rng(GLOBAL_SEED + 29) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + full_out = self.dp.eval(coords, cells, atom_types, atomic=True) + meta_out = self.dp_meta.eval(coords, cells, atom_types, atomic=True) + + self.assertEqual(len(full_out), len(meta_out)) + for ref, test in zip(full_out, meta_out, strict=True): + np.testing.assert_array_equal(test, ref) def test_eval_consistency(self) -> None: """Test that DeepPot.eval gives same results as direct model forward.""" diff --git a/source/tests/pt_expt/infer/test_deep_eval_metadata_only.py b/source/tests/pt_expt/infer/test_deep_eval_metadata_only.py new file mode 100644 index 0000000000..8ea48ab821 --- /dev/null +++ b/source/tests/pt_expt/infer/test_deep_eval_metadata_only.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Metadata-only loading tests for the pt_expt DeepEval. + +Exercises the "no ``model.json``" fallback path added to +:class:`deepmd.pt_expt.infer.deep_eval.DeepEval`: pt_expt ``.pte`` / +``.pt2`` archives are loadable when they only ship ``extra/metadata.json`` +(matching the contract the C++ ``DeepPotPTExpt`` reader enforces). + +Strategy +-------- +1. Build a tiny pt_expt SeA energy model and freeze it to a regular + ``.pte`` (the fast path; ``.pt2`` AOTInductor compilation is too + heavy for a routine unit test). +2. Read back that ``.pte`` and record the reference outputs. +3. Copy all archive entries except ``extra/model.json`` into a + metadata-only variant. +4. Load the metadata-only archive via ``DeepPot`` and assert that the + metadata-level accessors and the numeric ``eval`` result are + **bitwise identical** to the reference. +5. Verify that the dpmodel-only hooks (``eval_descriptor``, + ``eval_typeebd``, ``eval_fitting_last_layer``) raise + :class:`NotImplementedError` in metadata-only mode, since they + inherently need the deserialised dpmodel instance. +""" + +from __future__ import ( + annotations, +) + +import tempfile +import unittest +import zipfile +from pathlib import ( + Path, +) + +import numpy as np +import torch + +from deepmd.infer import ( + DeepPot, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils.serialization import ( + deserialize_to_file, +) + + +def _strip_extra_model_json(src: Path, dst: Path) -> None: + """Copy ``src`` to ``dst`` dropping any ``extra/model.json`` entry. + + ``torch.export.save`` lays the archive out as + ``/extra/{model,metadata,model_def_script}.json``; the + tmp prefix is chosen at save time so we match by suffix rather than + by an exact path. Every other entry (including the AOTI-compiled + binaries for ``.pt2``) is copied through unmodified. + """ + with zipfile.ZipFile(src, "r") as zin, zipfile.ZipFile(dst, "w") as zout: + for info in zin.infolist(): + if info.filename.endswith("extra/model.json"): + continue + zout.writestr(info, zin.read(info.filename)) + + +class TestDeepEvalMetadataOnlyPte(unittest.TestCase): + """End-to-end parity between full and metadata-only ``.pte`` archives.""" + + @classmethod + def setUpClass(cls) -> None: + torch.manual_seed(0) + + # ----- build a tiny fp64 SeA energy model ----- + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [6, 6] + cls.type_map = ["O", "H"] + cls.ntypes = len(cls.type_map) + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.ntypes, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + seed=7, + ) + model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = model.to(torch.float64).eval() + cls.model_data = {"model": cls.model.serialize()} + + # ----- freeze to .pte (full + metadata-only variants) ----- + cls._tmpdir = tempfile.TemporaryDirectory() + tmp_root = Path(cls._tmpdir.name) + cls.full_path = tmp_root / "full.pte" + cls.meta_only_path = tmp_root / "meta_only.pte" + deserialize_to_file(str(cls.full_path), cls.model_data) + _strip_extra_model_json(cls.full_path, cls.meta_only_path) + + cls.dp_full = DeepPot(str(cls.full_path)) + cls.dp_meta = DeepPot(str(cls.meta_only_path)) + + # ----- a deterministic sample for numeric parity ----- + rng = np.random.default_rng(42) + cls.natoms = 5 + cls.coord = rng.random((1, cls.natoms, 3), dtype=np.float64) * 6.0 + cls.cell = (np.eye(3, dtype=np.float64) * 12.0).reshape(1, 9) + cls.atype = np.array([0, 1, 0, 1, 0], dtype=np.int32) + + @classmethod + def tearDownClass(cls) -> None: + cls._tmpdir.cleanup() + + # ----- archive layout sanity ------------------------------------ + + def test_meta_only_archive_has_no_extra_model_json(self) -> None: + with zipfile.ZipFile(self.meta_only_path, "r") as zf: + names = zf.namelist() + self.assertFalse( + any(n.endswith("extra/model.json") for n in names), + msg="extra/model.json must be absent in the metadata-only archive", + ) + self.assertTrue( + any(n.endswith("extra/metadata.json") for n in names), + msg="extra/metadata.json is mandatory and must survive zip surgery", + ) + + # ----- metadata-level parity ------------------------------------ + + def test_metadata_level_accessors_match(self) -> None: + """All metadata-level queries agree between the two archives.""" + full = self.dp_full.deep_eval + meta = self.dp_meta.deep_eval + self.assertEqual(full.get_rcut(), meta.get_rcut()) + self.assertEqual(full.get_ntypes(), meta.get_ntypes()) + self.assertEqual(full.get_type_map(), meta.get_type_map()) + self.assertEqual(full.get_dim_fparam(), meta.get_dim_fparam()) + self.assertEqual(full.get_dim_aparam(), meta.get_dim_aparam()) + self.assertEqual(full.get_sel_type(), meta.get_sel_type()) + self.assertEqual(full.get_has_spin(), meta.get_has_spin()) + self.assertEqual(full.get_use_spin(), meta.get_use_spin()) + self.assertIs(full.model_type, meta.model_type) + + def test_internal_attributes_match(self) -> None: + """The hot-path attributes hoisted in both init paths must agree.""" + full = self.dp_full.deep_eval + meta = self.dp_meta.deep_eval + self.assertEqual(list(full._sel), list(meta._sel)) + self.assertEqual(bool(full._mixed_types), bool(meta._mixed_types)) + self.assertEqual(full._rcut, meta._rcut) + self.assertEqual(list(full._type_map), list(meta._type_map)) + + def test_dpmodel_presence(self) -> None: + """``_dpmodel`` is the single signal that separates the two modes.""" + self.assertIsNotNone(self.dp_full.deep_eval._dpmodel) + self.assertIsNone(self.dp_meta.deep_eval._dpmodel) + + # ----- numeric parity ------------------------------------------- + + def test_eval_numeric_parity(self) -> None: + """``DeepPot.eval`` must be bitwise identical across the two archives.""" + e_full, f_full, v_full = self.dp_full.eval( + self.coord, self.cell, self.atype, atomic=False + )[:3] + e_meta, f_meta, v_meta = self.dp_meta.eval( + self.coord, self.cell, self.atype, atomic=False + )[:3] + np.testing.assert_array_equal( + e_meta, e_full, err_msg="energy mismatch between full / meta-only" + ) + np.testing.assert_array_equal( + f_meta, f_full, err_msg="force mismatch between full / meta-only" + ) + np.testing.assert_array_equal( + v_meta, v_full, err_msg="virial mismatch between full / meta-only" + ) + + def test_eval_atomic_parity(self) -> None: + """Atomic outputs (atom_energy / atom_virial) match as well.""" + full_out = self.dp_full.eval(self.coord, self.cell, self.atype, atomic=True) + meta_out = self.dp_meta.eval(self.coord, self.cell, self.atype, atomic=True) + self.assertEqual(len(full_out), len(meta_out)) + for ref, test in zip(full_out, meta_out, strict=True): + np.testing.assert_array_equal(test, ref) + + # ----- dpmodel-only hooks must degrade to NotImplementedError --- + + def test_eval_descriptor_requires_dpmodel(self) -> None: + with self.assertRaises(NotImplementedError): + self.dp_meta.deep_eval.eval_descriptor(self.coord, self.cell, self.atype) + + def test_eval_fitting_last_layer_requires_dpmodel(self) -> None: + with self.assertRaises(NotImplementedError): + self.dp_meta.deep_eval.eval_fitting_last_layer( + self.coord, self.cell, self.atype + ) + + def test_eval_typeebd_requires_dpmodel(self) -> None: + with self.assertRaises(NotImplementedError): + self.dp_meta.deep_eval.eval_typeebd() + + +class TestDeepEvalMetadataOnlyGuards(unittest.TestCase): + """Error-path coverage that is independent of the .pte fixture.""" + + def test_missing_metadata_json_is_rejected(self) -> None: + """A ``.pte`` stripped of ``metadata.json`` must raise on load. + + Metadata is the minimum contract — unlike ``model.json`` it + must always be present. + """ + torch.manual_seed(0) + ds = DescrptSeA(4.0, 0.5, [6, 6]) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=ds.mixed_types(), seed=1) + model = EnergyModel(ds, ft, type_map=["a", "b"]).to(torch.float64).eval() + with tempfile.TemporaryDirectory() as tmp: + full = Path(tmp) / "full.pte" + broken = Path(tmp) / "no_metadata.pte" + deserialize_to_file(str(full), {"model": model.serialize()}) + with ( + zipfile.ZipFile(full, "r") as zin, + zipfile.ZipFile(broken, "w") as zout, + ): + for info in zin.infolist(): + if info.filename.endswith("extra/metadata.json"): + continue + zout.writestr(info, zin.read(info.filename)) + with self.assertRaises(ValueError): + DeepPot(str(broken)) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/infer/test_deep_eval_spin.py b/source/tests/pt_expt/infer/test_deep_eval_spin.py index 829b1f5666..6a3be99319 100644 --- a/source/tests/pt_expt/infer/test_deep_eval_spin.py +++ b/source/tests/pt_expt/infer/test_deep_eval_spin.py @@ -8,6 +8,7 @@ import copy import os import tempfile +import zipfile import numpy as np import pytest @@ -105,6 +106,37 @@ BOX = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=np.float64) +def _strip_extra_model_json(src: str, dst: str) -> None: + """Copy ``src`` to ``dst`` dropping any ``extra/model.json`` entry.""" + with zipfile.ZipFile(src, "r") as zin, zipfile.ZipFile(dst, "w") as zout: + for info in zin.infolist(): + if info.filename.endswith("extra/model.json"): + continue + zout.writestr(info, zin.read(info.filename)) + + +def _assert_fitting_output_defs_match(full_eval, meta_eval) -> None: + """Assert that metadata rebuilds the same fitting output definitions.""" + full_defs = full_eval._model_output_def.def_outp.get_data() + meta_defs = meta_eval._model_output_def.def_outp.get_data() + assert full_defs.keys() == meta_defs.keys() + attrs = ( + "shape", + "reducible", + "r_differentiable", + "c_differentiable", + "atomic", + "category", + "r_hessian", + "magnetic", + "intensive", + ) + for name, full_def in full_defs.items(): + meta_def = meta_defs[name] + for attr in attrs: + assert getattr(meta_def, attr) == getattr(full_def, attr) + + def _build_reference(): """Build pt_expt model and run eager reference inference. @@ -163,6 +195,9 @@ def spin_model_files(): finally: torch.set_default_device(prev) files[ext] = path + meta_path = os.path.join(tmpdir, "spin_test_metadata_only.pte") + _strip_extra_model_json(files[".pte"], meta_path) + files[".pte.meta"] = meta_path yield files, ref_pbc, ref_nopbc for path in files.values(): if os.path.exists(path): @@ -341,6 +376,32 @@ def test_eval_nopbc_nonatomic(self, spin_model_files, ext) -> None: ) +class TestSpinMetadataOnly: + """Test metadata-only spin model inference through DeepPot.""" + + def test_metadata_only_spin_pte_parity(self, spin_model_files) -> None: + """Metadata-only spin .pte matches full archive metadata and outputs.""" + from deepmd.infer import ( + DeepPot, + ) + + files, _, _ = spin_model_files + full_dp = DeepPot(files[".pte"]) + meta_dp = DeepPot(files[".pte.meta"]) + + assert meta_dp.has_spin == full_dp.has_spin + assert meta_dp.use_spin == full_dp.use_spin + assert meta_dp.get_ntypes_spin() == full_dp.get_ntypes_spin() + _assert_fitting_output_defs_match(full_dp.deep_eval, meta_dp.deep_eval) + + full_out = full_dp.eval(COORD, BOX, ATYPE, atomic=True, spin=SPIN) + meta_out = meta_dp.eval(COORD, BOX, ATYPE, atomic=True, spin=SPIN) + + assert len(full_out) == len(meta_out) + for ref, test in zip(full_out, meta_out, strict=True): + np.testing.assert_array_equal(test, ref) + + SPIN_FPARAM_CONFIG = copy.deepcopy(SPIN_CONFIG) SPIN_FPARAM_CONFIG["fitting_net"]["numb_fparam"] = 1 SPIN_FPARAM_CONFIG["fitting_net"]["default_fparam"] = [0.5]