diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index ac6963b435..d735443383 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -166,6 +166,16 @@ def get_has_efield(self) -> bool: """Check if the model has efield.""" return False + def get_has_spin(self) -> bool: + """Check if the model has spin atom types.""" + return hasattr(self.dp, "spin") + + def get_use_spin(self) -> list[bool]: + """Get the per-type spin usage of this model.""" + if hasattr(self.dp, "spin"): + return self.dp.spin.use_spin.tolist() + return [] + def get_ntypes_spin(self) -> int: """Get the number of spin atom types of this model.""" return 0 diff --git a/deepmd/dpmodel/model/base_model.py b/deepmd/dpmodel/model/base_model.py index b89172c4f6..9da6f8e585 100644 --- a/deepmd/dpmodel/model/base_model.py +++ b/deepmd/dpmodel/model/base_model.py @@ -127,6 +127,14 @@ def deserialize(cls, data: dict) -> "BaseBaseModel": model_type = data.get("type", "standard") if model_type == "standard": model_type = data.get("fitting", {}).get("type", "ener") + if model_type == "spin_ener": + # SpinModel is not a BaseModel subclass and cannot be + # registered via the plugin registry. Dispatch directly. + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + + return SpinModel.deserialize(data) return cls.get_class_by_type(model_type).deserialize(data) raise NotImplementedError(f"Not implemented in class {cls.__name__}") diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index d9f185a60d..e60bf809f9 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -549,12 +549,15 @@ def __getattr__(self, name: str) -> Any: def serialize(self) -> dict: return { + "type": "spin_ener", "backbone_model": self.backbone_model.serialize(), "spin": self.spin.serialize(), } @classmethod def deserialize(cls, data: dict) -> "SpinModel": + data = data.copy() + data.pop("type", None) backbone_model_obj = make_model( DPAtomicModel, T_Bases=(NativeOP, BaseModel) ).deserialize(data["backbone_model"]) @@ -646,7 +649,7 @@ def call_common( ) = self.process_spin_output( atype, model_ret[f"{var_name}_derv_c"], - add_mag=False, + add_mag=True, virtual_scale=False, ) # Always compute mask_mag from atom types (even when forces are unavailable) @@ -823,7 +826,7 @@ def call_common_lower( extended_atype, model_ret[f"{var_name}_derv_c"], nloc, - add_mag=False, + add_mag=True, virtual_scale=False, ) # Always compute mask_mag from atom types (even when forces are unavailable) diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index b697898896..b8ca99469f 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -16,6 +16,7 @@ ModelOutputDef, OutputVariableDef, get_deriv_name, + get_deriv_name_mag, get_hessian_name, get_reduce_name, ) @@ -128,6 +129,21 @@ def communicate_extended_output( model_ret[kk_derv_r], ) new_ret[kk_derv_r] = force + if vdef.magnetic: + kk_derv_r_mag = get_deriv_name_mag(kk)[0] + if model_ret.get(kk_derv_r_mag) is not None: + force_mag = xp.zeros( + vldims + derv_r_ext_dims, + dtype=vv.dtype, + device=device, + ) + force_mag = xp_scatter_sum( + force_mag, + 1, + mapping, + model_ret[kk_derv_r_mag], + ) + new_ret[kk_derv_r_mag] = force_mag else: # name holders new_ret[kk_derv_r] = None @@ -235,10 +251,29 @@ def communicate_extended_output( ) new_ret[kk_derv_c] = virial new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1) + if vdef.magnetic: + kk_derv_c_mag = get_deriv_name_mag(kk)[1] + if model_ret.get(kk_derv_c_mag) is not None: + virial_mag = xp.zeros( + vldims + derv_c_ext_dims, + dtype=vv.dtype, + device=device, + ) + virial_mag = xp_scatter_sum( + virial_mag, + 1, + mapping, + model_ret[kk_derv_c_mag], + ) + new_ret[kk_derv_c_mag] = virial_mag else: new_ret[kk_derv_c] = None new_ret[kk_derv_c + "_redu"] = None if not do_atomic_virial: # pop atomic virial, because it is not correctly calculated. new_ret.pop(kk_derv_c) + # Slice mask_mag from extended to local atoms + if "mask_mag" in model_ret: + nloc = new_ret[next(iter(model_output_def.keys_outp()))].shape[1] + new_ret["mask_mag"] = model_ret["mask_mag"][:, :nloc] return new_ret diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index d375a2ecd7..807414fa5d 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -323,6 +323,17 @@ def get_has_spin(self) -> bool: """Check if the model has spin atom types.""" return False + def get_use_spin(self) -> list[bool]: + """Get the per-type spin usage of this model. + + Returns + ------- + list[bool] + A list of bool indicating whether each atom type uses spin. + Empty list if the model does not have spin. + """ + return [] + def get_has_hessian(self) -> bool: """Check if the model has hessian.""" return False @@ -705,6 +716,18 @@ def has_spin(self) -> bool: """Check if the model has spin.""" return self.deep_eval.get_has_spin() + @property + def use_spin(self) -> list[bool]: + """Get the per-type spin usage of this model. + + Returns + ------- + list[bool] + A list of bool indicating whether each atom type uses spin. + Empty list if the model does not have spin. + """ + return self.deep_eval.get_use_spin() + @property def has_hessian(self) -> bool: """Check if the model has hessian.""" diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py index 6c0ffed7ec..8f4dda7199 100644 --- a/deepmd/pd/infer/deep_eval.py +++ b/deepmd/pd/infer/deep_eval.py @@ -297,6 +297,13 @@ def get_has_spin(self) -> bool: """Check if the model has spin atom types.""" return self._has_spin + def get_use_spin(self) -> list[bool]: + """Get the per-type spin usage of this model.""" + if self._has_spin: + model = self.dp.model["Default"] + return model.spin.use_spin.tolist() + return [] + def get_has_hessian(self) -> bool: """Check if the model has hessian.""" return self._has_hessian diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 97908c873a..2e30b8574a 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -314,6 +314,13 @@ def get_has_spin(self) -> bool: """Check if the model has spin atom types.""" return self._has_spin + def get_use_spin(self) -> list[bool]: + """Get the per-type spin usage of this model.""" + if self._has_spin: + model = self.dp.model["Default"] + return model.spin.use_spin.tolist() + return [] + def get_has_hessian(self) -> bool: """Check if the model has hessian.""" return self._has_hessian diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index a7b83c0fe0..91c6e2ea71 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -637,12 +637,15 @@ def forward_common_lower( def serialize(self) -> dict: return { + "type": "spin_ener", "backbone_model": self.backbone_model.serialize(), "spin": self.spin.serialize(), } @classmethod def deserialize(cls, data: dict[str, Any]) -> "SpinModel": + data = data.copy() + data.pop("type", None) backbone_model_obj = make_model(DPAtomicModel).deserialize( data["backbone_model"] ) diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index 5d3b02482a..e54ec9c76d 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -72,7 +72,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None: """ if not model_file.endswith(".pth"): raise ValueError("PyTorch backend only supports converting .pth file") - model = BaseModel.deserialize(data["model"]) + model_data = data["model"] + if model_data.get("type") == "spin_ener": + from deepmd.pt.model.model.spin_model import ( + SpinEnergyModel, + ) + + model = SpinEnergyModel.deserialize(model_data) + else: + model = BaseModel.deserialize(model_data) # JIT will happy in this way... model.model_def_script = json.dumps(data["model_def_script"]) if "min_nbor_dist" in data.get("@variables", {}): diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 290f2ec923..dd1831a4ba 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -16,6 +16,7 @@ communicate_extended_output, ) from deepmd.dpmodel.output_def import ( + FittingOutputDef, ModelOutputDef, OutputVariableCategory, OutputVariableDef, @@ -124,10 +125,39 @@ def _init_from_model_json(self, model_json_str: str) -> None: model_dict = json.loads(model_json_str) model_dict = _json_to_numpy(model_dict) - self._dpmodel = BaseModel.deserialize(model_dict["model"]) + model_data = model_dict["model"] + + if model_data.get("type") == "spin_ener": + from deepmd.pt_expt.model.spin_model import ( + SpinModel, + ) + + self._dpmodel = SpinModel.deserialize(model_data) + self._is_spin = True + else: + self._dpmodel = BaseModel.deserialize(model_data) + self._is_spin = False + self.rcut = self._dpmodel.get_rcut() self.type_map = self._dpmodel.get_type_map() - self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def()) + if self._is_spin: + self._model_output_def = ModelOutputDef( + FittingOutputDef( + [ + OutputVariableDef( + "energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + atomic=True, + magnetic=True, + ) + ] + ) + ) + else: + self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def()) def _load_pte(self, model_file: str) -> None: """Load a .pte (torch.export) model file.""" @@ -230,8 +260,18 @@ def get_has_efield(self) -> bool: """Check if the model has efield.""" return False + def get_has_spin(self) -> bool: + """Check if the model has spin atom types.""" + return getattr(self, "_is_spin", False) + + def get_use_spin(self) -> list[bool]: + """Get the per-type spin usage of this model.""" + if getattr(self, "_is_spin", False): + return self._dpmodel.spin.use_spin.tolist() + return [] + def get_ntypes_spin(self) -> int: - """Get the number of spin atom types of this model.""" + """Get the number of spin atom types of this model. Only used in old implement.""" return 0 def eval( @@ -283,9 +323,26 @@ def eval( coords, atom_types, len(atom_types.shape) > 1 ) request_defs = self._get_request_defs(atomic) - out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, cells, atom_types, fparam, aparam, request_defs - ) + spins = kwargs.get("spin") + if self._is_spin and spins is None: + raise ValueError( + "This is a spin model but no `spin` argument was provided. " + "Please call eval(..., spin=spin_array)." + ) + if not self._is_spin and spins is not None: + raise ValueError( + "This is not a spin model but a `spin` argument was provided. " + "Please call eval(...) without the `spin` argument." + ) + if spins is not None: + spins = np.array(spins) + out = self._eval_func(self._eval_model_spin, numb_test, natoms)( + coords, cells, atom_types, spins, fparam, aparam, request_defs + ) + else: + out = self._eval_func(self._eval_model, numb_test, natoms)( + coords, cells, atom_types, fparam, aparam, request_defs + ) return dict( zip( [x.name for x in request_defs], @@ -304,6 +361,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]: for x in self.output_def.var_defs.values() if x.category in ( + OutputVariableCategory.OUT, OutputVariableCategory.REDU, OutputVariableCategory.DERV_R, OutputVariableCategory.DERV_C_REDU, @@ -622,9 +680,9 @@ def _eval_model( dtype=torch.float64, device=DEVICE, ) - elif self._is_pt2 and self.get_dim_fparam() > 0: - # .pt2 models are compiled with fparam as a required input. - # When the user omits fparam, fill with default values from metadata. + elif self.get_dim_fparam() > 0: + # Exported models (.pt2/.pte) are compiled with fparam as a + # required input. Fill with default values from metadata. default_fp = self.metadata.get("default_fparam") if default_fp is not None: fparam_t = ( @@ -647,6 +705,13 @@ def _eval_model( dtype=torch.float64, device=DEVICE, ) + elif self.get_dim_aparam() > 0: + # Exported models (.pt2/.pte) are compiled with aparam as a + # required positional input. Unlike fparam, there is no default. + raise ValueError( + f"aparam is required for this model (dim_aparam={self.get_dim_aparam()}) " + "but was not provided." + ) else: aparam_t = None @@ -695,6 +760,164 @@ def _eval_model( ) return tuple(results) + def _eval_model_spin( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + spins: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + request_defs: list[OutputVariableDef], + ) -> tuple[np.ndarray, ...]: + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + coord_input = coords.reshape(nframes, natoms, 3) + if self.neighbor_list is not None: + extended_coord, extended_atype, nlist, mapping = self._build_nlist_ase( + coord_input, + cells, + atom_types, + ) + ext_coord_t = torch.tensor( + extended_coord, dtype=torch.float64, device=DEVICE + ) + ext_atype_t = torch.tensor(extended_atype, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=DEVICE) + else: + coord_t = torch.tensor(coord_input, dtype=torch.float64, device=DEVICE) + atype_t = torch.tensor(atom_types, dtype=torch.int64, device=DEVICE) + cells_t = ( + torch.tensor(cells, dtype=torch.float64, device=DEVICE) + if cells is not None + else None + ) + ext_coord_t, ext_atype_t, nlist_t, mapping_t = self._build_nlist_native( + coord_t, + cells_t, + atype_t, + ) + + # Extend spin to ghost atoms using mapping + spin_t = torch.tensor( + spins.reshape(nframes, natoms, 3), dtype=torch.float64, device=DEVICE + ) + batch_idx = ( + torch.arange(nframes, dtype=torch.long, device=DEVICE) + .unsqueeze(1) + .expand_as(mapping_t) + ) + ext_spin_t = spin_t[batch_idx, mapping_t] + + if fparam is not None: + fparam_t = torch.tensor( + fparam.reshape(nframes, self.get_dim_fparam()), + dtype=torch.float64, + device=DEVICE, + ) + elif self.get_dim_fparam() > 0: + # Exported models (.pt2/.pte) are compiled with fparam as a + # required input. Fill with default values from metadata. + default_fp = self.metadata.get("default_fparam") + if default_fp is not None: + fparam_t = ( + torch.tensor(default_fp, dtype=torch.float64, device=DEVICE) + .unsqueeze(0) + .expand(nframes, -1) + .contiguous() + ) + else: + raise ValueError( + f"fparam is required for this model (dim_fparam={self.get_dim_fparam()}) " + "but was not provided, and no default_fparam is stored in the model." + ) + else: + fparam_t = None + + if aparam is not None: + aparam_t = torch.tensor( + aparam.reshape(nframes, natoms, self.get_dim_aparam()), + dtype=torch.float64, + device=DEVICE, + ) + elif self.get_dim_aparam() > 0: + raise ValueError( + f"aparam is required for this model (dim_aparam={self.get_dim_aparam()}) " + "but was not provided." + ) + else: + aparam_t = None + + # Call the model with spin (7 args) + if self._is_pt2: + model_ret = self._pt2_runner( + ext_coord_t, + ext_atype_t, + ext_spin_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + ) + else: + model_ret = self.exported_module( + ext_coord_t, + ext_atype_t, + ext_spin_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + ) + + # Apply communicate_extended_output to map extended atoms → local atoms + do_atomic_virial = any( + x.category == OutputVariableCategory.DERV_C for x in request_defs + ) + + # Save pre-computed reduced virial: it includes both real and virtual + # atom contributions. communicate_extended_output would recompute it + # from only the real-atom per-atom virial, losing the virtual part. + saved_virial_redu = model_ret.get("energy_derv_c_redu") + + model_predict = communicate_extended_output( + model_ret, + self._model_output_def, + mapping_t, + do_atomic_virial=do_atomic_virial, + ) + + # Restore the correct reduced virial (includes virtual atom contribution) + if saved_virial_redu is not None: + model_predict["energy_derv_c_redu"] = saved_virial_redu + + # Translate internal keys to backend names and collect results + results = [] + for odef in request_defs: + if odef.name in model_predict: + shape = self._get_output_shape(odef, nframes, natoms) + if model_predict[odef.name] is not None: + out = model_predict[odef.name].detach().cpu().numpy().reshape(shape) + else: + out = np.full(shape, np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + results.append(out) + else: + shape = self._get_output_shape(odef, nframes, natoms) + results.append( + np.full(np.abs(shape), np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + ) + return tuple(results) + def _get_output_shape( self, odef: OutputVariableDef, nframes: int, natoms: int ) -> list[int]: diff --git a/deepmd/pt_expt/model/spin_model.py b/deepmd/pt_expt/model/spin_model.py index 259ff10698..70f41f0701 100644 --- a/deepmd/pt_expt/model/spin_model.py +++ b/deepmd/pt_expt/model/spin_model.py @@ -129,6 +129,8 @@ def deserialize(cls, data: dict) -> "SpinModel": DPEnergyAtomicModel, ) + data = data.copy() + data.pop("type", None) backbone_model_obj = make_model( DPEnergyAtomicModel, T_Bases=(BaseModel,) ).deserialize(data["backbone_model"]) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index f23d0bb025..f59c397525 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -16,6 +16,36 @@ ) +def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: + """Remove shape-guard assertion nodes from a spin model's exported graph. + + ``torch.export`` inserts ``aten._assert_scalar`` nodes for symbolic shape + relationships discovered during tracing. For the spin model, the atom- + doubling logic creates slice patterns that depend on ``(nall - nloc)``, + producing guards like ``Ne(nall, nloc)``. These guards are spurious: the + model computes correct results even when ``nall == nloc`` (NoPBC, no ghost + atoms). + + This function is **only called for spin models** (guarded by ``if is_spin`` + in ``_trace_and_export``). The assertion messages use opaque symbolic + variable names (e.g. ``Ne(s22, s96)``) rather than human-readable names, + so filtering by message content is not reliable. Since + ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape + guards into these deferred assertions, and the only shape relationships in + the spin model involve nall/nloc, removing all of them is safe in this + context. + """ + graph = graph_module.graph + for node in list(graph.nodes): + if ( + node.op == "call_function" + and node.target is torch.ops.aten._assert_scalar.default + ): + graph.erase_node(node) + graph.eliminate_dead_code() + graph_module.recompile() + + def _numpy_to_json_serializable(model_obj: dict) -> dict: """Convert numpy arrays in a model dict to JSON-serializable lists.""" return traverse_model_dict( @@ -49,6 +79,7 @@ def _make_sample_inputs( model: torch.nn.Module, nframes: int = 1, nloc: int = 7, + has_spin: bool = False, ) -> tuple[torch.Tensor, ...]: """Create sample inputs for tracing forward_lower. @@ -60,11 +91,14 @@ def _make_sample_inputs( Number of frames. nloc : int Number of local atoms. + has_spin : bool + If True, create an extended spin tensor and return 7 tensors. Returns ------- tuple - (ext_coord, ext_atype, nlist, mapping, fparam, aparam) + (ext_coord, ext_atype, nlist, mapping, fparam, aparam) or + (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) when has_spin. """ rcut = model.get_rcut() sel = model.get_sel() @@ -131,22 +165,31 @@ def _make_sample_inputs( else: aparam = None + if has_spin: + nall = extended_coord.shape[1] + ext_spin = torch.zeros( + nframes, nall, 3, dtype=torch.float64, device=_env.DEVICE + ) + return ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam + return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam def _build_dynamic_shapes( - _ext_coord: torch.Tensor, - _ext_atype: torch.Tensor, - _nlist: torch.Tensor, - _mapping: torch.Tensor, - fparam: torch.Tensor | None, - aparam: torch.Tensor | None, + *sample_inputs: torch.Tensor | None, + has_spin: bool = False, ) -> tuple: """Build dynamic shape specifications for torch.export. Marks nframes, nloc and nall as dynamic dimensions so the exported program handles arbitrary frame and atom counts. + Parameters + ---------- + *sample_inputs : torch.Tensor | None + Sample inputs: either 6 tensors (non-spin) or 7 tensors (spin). + has_spin : bool + Whether the inputs include an extended_spin tensor. Returns a tuple (not dict) to match positional args of the make_fx traced module, whose arg names may have suffixes like ``_1``. """ @@ -154,17 +197,34 @@ def _build_dynamic_shapes( nall_dim = torch.export.Dim("nall", min=1) nloc_dim = torch.export.Dim("nloc", min=1) - return ( - {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) - {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) - {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) - {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) - {0: nframes_dim} if fparam is not None else None, # fparam - {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam - ) + if has_spin: + # (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) + fparam = sample_inputs[5] + aparam = sample_inputs[6] + return ( + {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) + {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) + {0: nframes_dim, 1: nall_dim}, # extended_spin: (nframes, nall, 3) + {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) + {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) + {0: nframes_dim} if fparam is not None else None, # fparam + {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam + ) + else: + # (ext_coord, ext_atype, nlist, mapping, fparam, aparam) + fparam = sample_inputs[4] + aparam = sample_inputs[5] + return ( + {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) + {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) + {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) + {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) + {0: nframes_dim} if fparam is not None else None, # fparam + {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam + ) -def _collect_metadata(model: torch.nn.Module) -> dict: +def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict: """Collect metadata from the model for C++ inference. This metadata is stored as ``metadata.json`` in both .pt2 and .pte archives. @@ -193,7 +253,7 @@ def _collect_metadata(model: torch.nn.Module) -> dict: "intensive": vdef.intensive, } ) - return { + meta = { "type_map": model.get_type_map(), "rcut": model.get_rcut(), "sel": model.get_sel(), @@ -203,7 +263,12 @@ def _collect_metadata(model: torch.nn.Module) -> dict: "has_default_fparam": model.has_default_fparam(), "default_fparam": model.get_default_fparam(), "fitting_output_defs": fitting_output_defs, + "is_spin": is_spin, } + if is_spin: + meta["ntypes_spin"] = model.spin.get_ntypes_spin() + meta["use_spin"] = [bool(v) for v in model.spin.use_spin] + return meta def serialize_from_file(model_file: str) -> dict: @@ -317,51 +382,94 @@ def _trace_and_export( target_device = _env.DEVICE + # Detect spin model + is_spin = data["model"].get("type") == "spin_ener" + # 1. Deserialize model on CPU for make_fx tracing. # make_fx with _allow_non_fake_inputs=True keeps real model parameters; # on CUDA the autograd engine requires CUDA streams for those real # tensors during torch.autograd.grad, but proxy-tensor dispatch doesn't # set streams up → assertion failure. Tracing on CPU avoids this. - model = BaseModel.deserialize(data["model"]) + if is_spin: + from deepmd.pt_expt.model.spin_model import ( + SpinModel, + ) + + model = SpinModel.deserialize(data["model"]) + else: + model = BaseModel.deserialize(data["model"]) model.to("cpu") model.eval() # 2. Collect metadata - metadata = _collect_metadata(model) + metadata = _collect_metadata(model, is_spin=is_spin) # 3. Create sample inputs on CPU for tracing - # Use nframes=5 to avoid two specialization traps: - # - nframes=1 causes make_fx to specialize on the scalar case - # - nframes=N where N == numb_fparam or numb_aparam causes PyTorch's - # symbolic tracer to merge symbols (e.g. fparam.shape=(2,2) when - # nframes=2 and numb_fparam=2), so a guard on one dim constrains - # the other. 5 is unlikely to collide with typical param counts. + # torch.export's duck-sizing unifies dimensions with the same sample value, + # so nframes must differ from every other dimension in the sample tensors. + # We first build with nframes=2, collect all non-batch dimension sizes, + # then rebuild if there is a collision. _orig_device = _env.DEVICE _env.DEVICE = torch.device("cpu") try: - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs( - model, nframes=5 - ) + nframes = 2 + sample_inputs = _make_sample_inputs(model, nframes=nframes, has_spin=is_spin) + # Collect all dimension sizes except dim-0 (nframes) from every tensor + other_dims: set[int] = set() + for t in sample_inputs: + if t is not None: + other_dims.update(t.shape[1:]) + while nframes in other_dims: + nframes += 1 + if nframes != 2: + sample_inputs = _make_sample_inputs( + model, nframes=nframes, has_spin=is_spin + ) finally: _env.DEVICE = _orig_device + if is_spin: + ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam = ( + sample_inputs + ) + else: + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = sample_inputs + # 4. Trace via make_fx on CPU. # This decomposes torch.autograd.grad into aten ops so the resulting # GraphModule no longer contains autograd calls. - traced = model.forward_common_lower_exportable( - ext_coord, - ext_atype, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - do_atomic_virial=True, - tracing_mode="symbolic", - _allow_non_fake_inputs=True, - ) + if is_spin: + traced = model.forward_common_lower_exportable( + ext_coord, + ext_atype, + ext_spin, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + # 5. Extract output keys from the CPU-traced module. + sample_out = traced( + ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam + ) + else: + traced = model.forward_common_lower_exportable( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + # 5. Extract output keys from the CPU-traced module. + sample_out = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) - # 5. Extract output keys from the CPU-traced module. - sample_out = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) output_keys = list(sample_out.keys()) # 6. Export on CPU. @@ -369,17 +477,25 @@ def _trace_and_export( # graph. Exporting on CPU keeps devices consistent; we move the # ExportedProgram to the target device afterwards via the official # move_to_device_pass (avoids FakeTensor device-propagation errors). - dynamic_shapes = _build_dynamic_shapes( - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam - ) + dynamic_shapes = _build_dynamic_shapes(*sample_inputs, has_spin=is_spin) exported = torch.export.export( traced, - (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), + sample_inputs, dynamic_shapes=dynamic_shapes, strict=False, prefer_deferred_runtime_asserts_over_guards=True, ) + if is_spin: + # torch.export re-introduces shape-guard assertions even when + # the make_fx graph has none. The spin model's atom-doubling + # logic creates slice patterns that depend on (nall - nloc), + # producing guards like Ne(nall, nloc). These guards are + # spurious: the model is correct when nall == nloc (NoPBC). + # Strip them from the exported graph so the model can be + # used with any valid nall >= nloc. + _strip_shape_assertions(exported.graph_module) + # 7. Move the exported program to the target device if needed. if target_device.type != "cpu": from torch.export.passes import ( diff --git a/source/api_c/tests/test_deepspin_a.cc b/source/api_c/tests/test_deepspin_a.cc index 7f39638577..bfd30f1ec4 100644 --- a/source/api_c/tests/test_deepspin_a.cc +++ b/source/api_c/tests/test_deepspin_a.cc @@ -20,24 +20,24 @@ class TestInferDeepSpinA : public ::testing::Test { double box[9] = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; float boxf[9] = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; std::vector expected_e = { - -1.8626545229251095e+00, -2.3502165071948093e+00, -2.3500944968573521e+00, - -2.0688274735854710e+00, -2.3485113271625320e+00, -2.3489022338537353e+00, + 7.020322773655288e-03, 1.099636038493644e-01, 1.093176595258250e-01, + 4.865300228001564e-02, 1.096547558413134e-01, 1.099754340356070e-01, }; std::vector expected_f = { - 3.7989110974834261e-02, -6.8203560994098300e-02, 3.1554995279414300e-02, - -6.0769407958790114e-02, 5.6658432967656878e-03, 2.1814741358389407e-02, - 1.5027739412753049e-02, 6.2090755323245192e-02, -5.3346442187326704e-02, - -5.2134406995188787e-02, 4.0990812807417676e-02, -1.6987454510304811e-02, - -6.7153786204261134e-03, -5.3801784772022326e-02, 5.6707773168242034e-02, - 6.6602343186817375e-02, 1.3257934338691726e-02, -3.9743613108414025e-02, + 2.980086586841411e-02, 2.670602118823960e-03, -6.205408022135627e-03, + -7.946653268248605e-03, 4.217792180550986e-03, 1.822080579891798e-03, + -3.416928812442276e-03, -6.992749479424899e-03, 4.728288289346775e-03, + 5.049869641953204e-03, 1.550913149717830e-02, 1.801899070929784e-02, + -1.411871008097311e-02, -8.283139367982638e-03, -7.058623315726573e-03, + -9.368443348703582e-03, -7.121636949145627e-03, -1.130532824067422e-02, }; std::vector expected_fm = { - 4.8385521455777196e+00, 5.3158441514550137e-01, 1.0855626815019124e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 1.2140862110260138e+00, 9.6823434985033552e-01, 1.0689000529371890e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.112646578617150e+00, -2.239176906831133e-01, -2.513101985142691e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.763058480695873e-02, 1.564710428447471e-02, -3.735332673990924e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, }; int natoms; double expected_tot_e; @@ -163,7 +163,7 @@ TEST_F(TestInferDeepSpinA, float_infer) { TEST_F(TestInferDeepSpinA, cutoff) { double cutoff = DP_DeepSpinGetCutoff(dp); - EXPECT_EQ(cutoff, 4.0); + EXPECT_EQ(cutoff, 6.0); } TEST_F(TestInferDeepSpinA, numb_types) { @@ -195,24 +195,25 @@ class TestInferDeepSpinANoPBC : public ::testing::Test { 0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.}; int atype[6] = {0, 1, 1, 0, 1, 1}; std::vector expected_e = { - -1.9136796509970209e+00, -2.3532121417832528e+00, - -2.3589759416772553e+00, -2.0689533840218703e+00, - -2.3485273598793084e+00, -2.3489022338537353e+00}; + 1.298915294144196e-02, 1.095576145701290e-01, 1.083914166945241e-01, + 4.932338375417146e-02, 1.099860785812512e-01, 1.100478936528533e-01, + }; std::vector expected_f = { - 5.2440246818294511e-02, -8.2643189092284075e-03, -1.6057110078610215e-02, - -5.2440246818295698e-02, 8.2643189092281334e-03, 1.6057110078610277e-02, - -1.6724663644564395e-03, 7.9346065821642349e-05, -2.5251632397208987e-04, - -5.6934098675373246e-02, 4.0398593044712161e-02, -1.6520316500527876e-02, - -7.9878577602028808e-03, -5.3736758888210570e-02, 5.6516778947603999e-02, - 6.6594422800032166e-02, 1.3258819777676990e-02, -3.9743946123104140e-02, + 1.300765817095240e-02, -1.593967210478553e-03, -3.196759265340465e-03, + -1.300765817095220e-02, 1.593967210478477e-03, 3.196759265340509e-03, + 9.196695370628910e-03, -9.044559760114149e-04, 5.658266727670325e-04, + 1.012744085443978e-02, 1.680427054831429e-02, 1.807036969424208e-02, + -1.133453822298158e-02, -8.941333904804914e-03, -6.627672717506913e-03, + -7.989598002087154e-03, -6.958480667497971e-03, -1.200852364950221e-02, }; std::vector expected_fm = { - 4.5904360179010135e+00, 6.2821415259365443e-01, 9.2483695213043082e-01, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 1.2125967529512662e+00, 9.6807902483755459e-01, 1.0691011858092361e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00}; + -9.651705644713781e-01, -1.704326891282164e-01, -2.605677204117113e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.168034653189444e-02, 1.736913887115685e-02, -3.908906640474424e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + }; int natoms; double expected_tot_e; // std::vector expected_tot_v; diff --git a/source/api_c/tests/test_deepspin_a_hpp.cc b/source/api_c/tests/test_deepspin_a_hpp.cc index 29fc201b1b..821bf94bb7 100644 --- a/source/api_c/tests/test_deepspin_a_hpp.cc +++ b/source/api_c/tests/test_deepspin_a_hpp.cc @@ -20,24 +20,24 @@ class TestInferDeepSpinAHPP : public ::testing::Test { std::vector atype = {0, 1, 1, 0, 1, 1}; std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; std::vector expected_e = { - -1.8626545229251095e+00, -2.3502165071948093e+00, -2.3500944968573521e+00, - -2.0688274735854710e+00, -2.3485113271625320e+00, -2.3489022338537353e+00, + 7.020322773655288e-03, 1.099636038493644e-01, 1.093176595258250e-01, + 4.865300228001564e-02, 1.096547558413134e-01, 1.099754340356070e-01, }; std::vector expected_f = { - 3.7989110974834261e-02, -6.8203560994098300e-02, 3.1554995279414300e-02, - -6.0769407958790114e-02, 5.6658432967656878e-03, 2.1814741358389407e-02, - 1.5027739412753049e-02, 6.2090755323245192e-02, -5.3346442187326704e-02, - -5.2134406995188787e-02, 4.0990812807417676e-02, -1.6987454510304811e-02, - -6.7153786204261134e-03, -5.3801784772022326e-02, 5.6707773168242034e-02, - 6.6602343186817375e-02, 1.3257934338691726e-02, -3.9743613108414025e-02, + 2.980086586841411e-02, 2.670602118823960e-03, -6.205408022135627e-03, + -7.946653268248605e-03, 4.217792180550986e-03, 1.822080579891798e-03, + -3.416928812442276e-03, -6.992749479424899e-03, 4.728288289346775e-03, + 5.049869641953204e-03, 1.550913149717830e-02, 1.801899070929784e-02, + -1.411871008097311e-02, -8.283139367982638e-03, -7.058623315726573e-03, + -9.368443348703582e-03, -7.121636949145627e-03, -1.130532824067422e-02, }; std::vector expected_fm = { - 4.8385521455777196e+00, 5.3158441514550137e-01, 1.0855626815019124e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 1.2140862110260138e+00, 9.6823434985033552e-01, 1.0689000529371890e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.112646578617150e+00, -2.239176906831133e-01, -2.513101985142691e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.763058480695873e-02, 1.564710428447471e-02, -3.735332673990924e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, }; unsigned int natoms; double expected_tot_e; @@ -167,24 +167,24 @@ class TestInferDeepSpinANoPbcHPP : public ::testing::Test { std::vector atype = {0, 1, 1, 0, 1, 1}; std::vector box = {}; std::vector expected_e = { - -1.9136796509970209e+00, -2.3532121417832528e+00, - -2.3589759416772553e+00, -2.0689533840218703e+00, - -2.3485273598793084e+00, -2.3489022338537353e+00}; + 1.298915294144196e-02, 1.095576145701290e-01, 1.083914166945241e-01, + 4.932338375417146e-02, 1.099860785812512e-01, 1.100478936528533e-01, + }; std::vector expected_f = { - 5.2440246818294511e-02, -8.2643189092284075e-03, -1.6057110078610215e-02, - -5.2440246818295698e-02, 8.2643189092281334e-03, 1.6057110078610277e-02, - -1.6724663644564395e-03, 7.9346065821642349e-05, -2.5251632397208987e-04, - -5.6934098675373246e-02, 4.0398593044712161e-02, -1.6520316500527876e-02, - -7.9878577602028808e-03, -5.3736758888210570e-02, 5.6516778947603999e-02, - 6.6594422800032166e-02, 1.3258819777676990e-02, -3.9743946123104140e-02, + 1.300765817095240e-02, -1.593967210478553e-03, -3.196759265340465e-03, + -1.300765817095220e-02, 1.593967210478477e-03, 3.196759265340509e-03, + 9.196695370628910e-03, -9.044559760114149e-04, 5.658266727670325e-04, + 1.012744085443978e-02, 1.680427054831429e-02, 1.807036969424208e-02, + -1.133453822298158e-02, -8.941333904804914e-03, -6.627672717506913e-03, + -7.989598002087154e-03, -6.958480667497971e-03, -1.200852364950221e-02, }; std::vector expected_fm = { - 4.5904360179010135e+00, 6.2821415259365443e-01, 9.2483695213043082e-01, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 1.2125967529512662e+00, 9.6807902483755459e-01, 1.0691011858092361e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, + -9.651705644713781e-01, -1.704326891282164e-01, -2.605677204117113e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.168034653189444e-02, 1.736913887115685e-02, -3.908906640474424e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, }; unsigned int natoms; double expected_tot_e; diff --git a/source/api_cc/include/DeepSpin.h b/source/api_cc/include/DeepSpin.h index 07a447cd74..a4f38461ca 100644 --- a/source/api_cc/include/DeepSpin.h +++ b/source/api_cc/include/DeepSpin.h @@ -160,6 +160,13 @@ class DeepSpinBackend : public DeepBaseModelBackend { const std::vector& aparam, const bool atomic) = 0; /** @} */ + + /** + * @brief Get the per-type use_spin flags. + * @return A vector of booleans indicating which atom types have spin enabled. + * Empty if the backend does not provide this information. + **/ + virtual std::vector get_use_spin() const { return {}; }; }; /** @@ -414,6 +421,13 @@ class DeepSpin : public DeepBaseModel { const std::vector& fparam = std::vector(), const std::vector& aparam = std::vector()); /** @} */ + + /** + * @brief Get the per-type use_spin flags. + * @return A vector of booleans indicating which atom types have spin enabled. + **/ + std::vector get_use_spin() const; + protected: std::shared_ptr dp; }; @@ -610,6 +624,12 @@ class DeepSpinModelDevi : public DeepBaseModelDevi { const std::vector& fparam = std::vector(), const std::vector& aparam = std::vector()); + /** + * @brief Get the per-type use_spin flags from the first model. + * @return A vector of booleans indicating which atom types have spin enabled. + **/ + std::vector get_use_spin() const; + protected: std::vector> dps; }; diff --git a/source/api_cc/include/DeepSpinPT.h b/source/api_cc/include/DeepSpinPT.h index 1cc35997e9..6570887d88 100644 --- a/source/api_cc/include/DeepSpinPT.h +++ b/source/api_cc/include/DeepSpinPT.h @@ -183,6 +183,11 @@ class DeepSpinPT : public DeepSpinBackend { assert(inited); return aparam_nall; }; + /** + * @brief Get the per-type use_spin flags. + * @return Empty vector — .pth backend does not store use_spin in metadata. + **/ + std::vector get_use_spin() const override { return {}; }; /** * @brief Check if the model has default frame parameters. * @return true if the model has default frame parameters. diff --git a/source/api_cc/include/DeepSpinPTExpt.h b/source/api_cc/include/DeepSpinPTExpt.h new file mode 100644 index 0000000000..af108c7690 --- /dev/null +++ b/source/api_cc/include/DeepSpinPTExpt.h @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#pragma once + +#ifdef BUILD_PYTORCH +#if __has_include() +#define BUILD_PT_EXPT_SPIN 1 +#else +#define BUILD_PT_EXPT_SPIN 0 +#endif + +#if BUILD_PT_EXPT_SPIN + +#include + +#include "DeepSpin.h" + +namespace torch::inductor { +class AOTIModelPackageLoader; +} + +namespace deepmd { +/** + * @brief PyTorch Exportable (AOTInductor .pt2) implementation for Deep + *Potential with spin. + **/ +class DeepSpinPTExpt : public DeepSpinBackend { + public: + DeepSpinPTExpt(); + virtual ~DeepSpinPTExpt(); + DeepSpinPTExpt(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + void init(const std::string& model, + const int& gpu_rank = 0, + const std::string& file_content = ""); + + private: + /** + * @brief Evaluate with nlist (LAMMPS path — extended forces). + **/ + template + void compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + /** + * @brief Evaluate without nlist (standalone — builds nlist, folds back). + **/ + template + void compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + + public: + double cutoff() const { + assert(inited); + return rcut; + }; + int numb_types() const { + assert(inited); + return ntypes; + }; + int numb_types_spin() const { + assert(inited); + return ntypes_spin; + }; + int dim_fparam() const { + assert(inited); + return dfparam; + }; + int dim_aparam() const { + assert(inited); + return daparam; + }; + void get_type_map(std::string& type_map); + bool is_aparam_nall() const { + assert(inited); + return aparam_nall; + }; + std::vector get_use_spin() const override { + assert(inited); + return use_spin_; + }; + bool has_default_fparam() const { + assert(inited); + return has_default_fparam_; + }; + + // forward to template class + void computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + + private: + bool inited; + int ntypes; + int ntypes_spin; + int dfparam; + int daparam; + bool aparam_nall; + bool has_default_fparam_; + std::vector default_fparam_; + std::vector use_spin_; + double rcut; + int gpu_id; + bool gpu_enabled; + std::vector type_map; + std::vector output_keys; + bool mixed_types; + std::vector sel; + NeighborListData nlist_data; + std::unique_ptr loader; + + std::vector run_model(const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& spin, + const torch::Tensor& nlist, + const torch::Tensor& mapping, + const torch::Tensor& fparam, + const torch::Tensor& aparam); + + void extract_outputs(std::map& output_map, + const std::vector& flat_outputs); + + void translate_error(std::function f); +}; + +} // namespace deepmd + +#endif // BUILD_PT_EXPT_SPIN +#endif // BUILD_PYTORCH diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index b30405ba55..c1f3d9d674 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -12,550 +12,17 @@ #include "SimulationRegion.h" #include "common.h" +#include "commonPTExpt.h" #include "device.h" #include "errors.h" #include "neighbor_list.h" -// Minimal JSON value parser for reading metadata from .pt2 archives. -// Supports: strings, numbers, booleans, arrays, objects. -// This avoids adding a dependency on nlohmann/json for the api_cc library. -namespace { - -struct JsonValue; -using JsonObject = std::map; -using JsonArray = std::vector; - -struct JsonValue { - enum Type { Null, Bool, Number, String, Array, Object }; - Type type = Null; - bool bool_val = false; - double num_val = 0.0; - std::string str_val; - JsonArray arr_val; - JsonObject obj_val; - - std::string as_string() const { return str_val; } - double as_double() const { return num_val; } - int as_int() const { return static_cast(num_val); } - bool as_bool() const { return bool_val; } - const JsonArray& as_array() const { return arr_val; } - const JsonObject& as_object() const { return obj_val; } - const JsonValue& operator[](const std::string& key) const { - return obj_val.at(key); - } - const JsonValue& operator[](size_t idx) const { return arr_val.at(idx); } - bool has(const std::string& key) const { - return obj_val.find(key) != obj_val.end(); - } -}; - -class JsonParser { - public: - explicit JsonParser(const std::string& s) : s_(s), pos_(0) {} - JsonValue parse() { - skip_ws(); - auto val = parse_value(); - return val; - } - - private: - const std::string& s_; - size_t pos_; - - char peek() const { return pos_ < s_.size() ? s_[pos_] : '\0'; } - char get() { - if (pos_ >= s_.size()) { - throw std::runtime_error("JSON parse error: unexpected end of input"); - } - return s_[pos_++]; - } - void skip_ws() { - while (pos_ < s_.size() && (s_[pos_] == ' ' || s_[pos_] == '\t' || - s_[pos_] == '\n' || s_[pos_] == '\r')) { - ++pos_; - } - } - - JsonValue parse_value() { - skip_ws(); - char c = peek(); - if (c == '"') { - return parse_string_val(); - } else if (c == '{') { - return parse_object(); - } else if (c == '[') { - return parse_array(); - } else if (c == 't' || c == 'f') { - return parse_bool(); - } else if (c == 'n') { - return parse_null(); - } else { - return parse_number(); - } - } - - std::string parse_string_raw() { - get(); // consume '"' - std::string result; - while (pos_ < s_.size() && peek() != '"') { - if (peek() == '\\') { - get(); - char esc = get(); - switch (esc) { - case '"': - result += '"'; - break; - case '\\': - result += '\\'; - break; - case '/': - result += '/'; - break; - case 'n': - result += '\n'; - break; - case 't': - result += '\t'; - break; - case 'r': - result += '\r'; - break; - default: - result += esc; - break; - } - } else { - result += get(); - } - } - get(); // consume closing '"' - return result; - } - - JsonValue parse_string_val() { - JsonValue v; - v.type = JsonValue::String; - v.str_val = parse_string_raw(); - return v; - } - - JsonValue parse_number() { - size_t start = pos_; - if (peek() == '-') { - get(); - } - while (pos_ < s_.size() && - (std::isdigit(s_[pos_]) || s_[pos_] == '.' || s_[pos_] == 'e' || - s_[pos_] == 'E' || s_[pos_] == '+' || s_[pos_] == '-')) { - // handle sign only if after e/E - if ((s_[pos_] == '+' || s_[pos_] == '-') && pos_ > start && - s_[pos_ - 1] != 'e' && s_[pos_ - 1] != 'E') { - break; - } - ++pos_; - } - JsonValue v; - v.type = JsonValue::Number; - try { - v.num_val = std::stod(s_.substr(start, pos_ - start)); - } catch (const std::exception& e) { - throw std::runtime_error("JSON parse error: invalid number at position " + - std::to_string(start)); - } - return v; - } - - JsonValue parse_bool() { - JsonValue v; - v.type = JsonValue::Bool; - if (s_.substr(pos_, 4) == "true") { - v.bool_val = true; - pos_ += 4; - } else if (s_.substr(pos_, 5) == "false") { - v.bool_val = false; - pos_ += 5; - } else { - throw std::runtime_error( - "JSON parse error: expected 'true' or 'false' at position " + - std::to_string(pos_)); - } - return v; - } - - JsonValue parse_null() { - if (s_.substr(pos_, 4) != "null") { - throw std::runtime_error( - "JSON parse error: expected 'null' at position " + - std::to_string(pos_)); - } - pos_ += 4; - return JsonValue(); - } - - JsonValue parse_array() { - get(); // consume '[' - JsonValue v; - v.type = JsonValue::Array; - skip_ws(); - if (peek() == ']') { - get(); - return v; - } - while (true) { - v.arr_val.push_back(parse_value()); - skip_ws(); - if (peek() == ',') { - get(); - } else { - break; - } - } - skip_ws(); - get(); // consume ']' - return v; - } - - JsonValue parse_object() { - get(); // consume '{' - JsonValue v; - v.type = JsonValue::Object; - skip_ws(); - if (peek() == '}') { - get(); - return v; - } - while (true) { - skip_ws(); - std::string key = parse_string_raw(); - skip_ws(); - get(); // consume ':' - v.obj_val[key] = parse_value(); - skip_ws(); - if (peek() == ',') { - get(); - } else { - break; - } - } - skip_ws(); - get(); // consume '}' - return v; - } -}; - -JsonValue parse_json(const std::string& s) { - JsonParser parser(s); - return parser.parse(); -} - -// Read a file from a ZIP archive using caffe2::serialize::PyTorchStreamReader. -// We avoid depending on caffe2 headers by using a simpler approach: -// just read the file directly as a ZIP file. -std::string read_zip_entry(const std::string& zip_path, - const std::string& entry_name) { - // Use a simple approach: scan all possible prefixed names. - // .pt2 files from AOTInductor store extra files at "extra/" - // within the ZIP archive. - std::ifstream ifs(zip_path, std::ios::binary); - if (!ifs.is_open()) { - throw deepmd::deepmd_exception("Cannot open file: " + zip_path); - } - - // Read entire file - std::string content((std::istreambuf_iterator(ifs)), - std::istreambuf_iterator()); - ifs.close(); - - // Simple ZIP central directory parser - // Find End of Central Directory Record (EOCD) - // EOCD signature: 0x06054b50 - // Minimum EOCD size is 22 bytes - if (content.size() < 22) { - throw deepmd::deepmd_exception( - "File too small to be a valid ZIP archive: " + zip_path); - } - size_t eocd_pos = std::string::npos; - for (int64_t i = static_cast(content.size()) - 22; - i >= 0 && static_cast(i) + 3 < content.size(); --i) { - if (content[i] == 0x50 && content[i + 1] == 0x4b && - content[i + 2] == 0x05 && content[i + 3] == 0x06) { - eocd_pos = static_cast(i); - break; - } - } - if (eocd_pos == std::string::npos) { - throw deepmd::deepmd_exception("Invalid ZIP file: " + zip_path); - } - - // Parse EOCD to get central directory offset and size - auto read_u16 = [&](size_t offset) -> uint16_t { - return static_cast(static_cast(content[offset])) | - (static_cast( - static_cast(content[offset + 1])) - << 8); - }; - auto read_u32 = [&](size_t offset) -> uint32_t { - return static_cast(static_cast(content[offset])) | - (static_cast( - static_cast(content[offset + 1])) - << 8) | - (static_cast( - static_cast(content[offset + 2])) - << 16) | - (static_cast( - static_cast(content[offset + 3])) - << 24); - }; - - uint64_t num_entries = read_u16(eocd_pos + 10); - uint64_t cd_offset = read_u32(eocd_pos + 16); - - // If this is a ZIP64 file, look for the ZIP64 EOCD locator - if (cd_offset == 0xFFFFFFFF || num_entries == 0xFFFF) { - // ZIP64 EOCD locator signature: 0x07064b50 - // It should be right before the EOCD (20 bytes) - if (eocd_pos < 20) { - throw deepmd::deepmd_exception( - "Invalid ZIP64 file (truncated EOCD locator): " + zip_path); - } - size_t zip64_locator_pos = eocd_pos - 20; - if (content[zip64_locator_pos] == 0x50 && - content[zip64_locator_pos + 1] == 0x4b && - content[zip64_locator_pos + 2] == 0x06 && - content[zip64_locator_pos + 3] == 0x07) { - // Read ZIP64 EOCD offset from locator - uint64_t zip64_eocd_offset = 0; - for (int b = 0; b < 8; ++b) { - zip64_eocd_offset |= static_cast(static_cast( - content[zip64_locator_pos + 8 + b])) - << (8 * b); - } - // Parse ZIP64 EOCD - // ZIP64 EOCD signature: 0x06064b50 - size_t z64_pos = static_cast(zip64_eocd_offset); - if (z64_pos + 56 > content.size()) { - throw deepmd::deepmd_exception( - "Invalid ZIP64 file (truncated EOCD record): " + zip_path); - } - // num entries at offset 32 (8 bytes in ZIP64) - num_entries = 0; - for (int b = 0; b < 8; ++b) { - num_entries |= static_cast(static_cast( - content[z64_pos + 32 + b])) - << (8 * b); - } - // cd offset at offset 48 (8 bytes in ZIP64) - cd_offset = 0; - for (int b = 0; b < 8; ++b) { - cd_offset |= static_cast( - static_cast(content[z64_pos + 48 + b])) - << (8 * b); - } - } - } - - // Iterate central directory entries - size_t pos = cd_offset; - for (uint64_t i = 0; i < num_entries; ++i) { - // Central directory entry signature: 0x02014b50 - if (pos + 46 > content.size()) { - break; - } - uint16_t name_len = read_u16(pos + 28); - uint16_t extra_len = read_u16(pos + 30); - uint16_t comment_len = read_u16(pos + 32); - uint32_t compressed_size_u32 = read_u32(pos + 20); - uint32_t uncompressed_size_u32 = read_u32(pos + 24); - uint32_t local_header_offset_u32 = read_u32(pos + 42); - - // Use 64-bit types so ZIP64 values are not truncated - uint64_t compressed_size = compressed_size_u32; - uint64_t uncompressed_size = uncompressed_size_u32; - uint64_t local_header_offset = local_header_offset_u32; - - std::string name = content.substr(pos + 46, name_len); - - // Handle ZIP64 extra field for large files - if (uncompressed_size_u32 == 0xFFFFFFFF || - local_header_offset_u32 == 0xFFFFFFFF) { - // Parse ZIP64 extended information extra field - size_t extra_pos = pos + 46 + name_len; - size_t extra_end = extra_pos + extra_len; - while (extra_pos + 4 <= extra_end) { - uint16_t field_id = read_u16(extra_pos); - uint16_t field_size = read_u16(extra_pos + 2); - if (field_id == 0x0001) { // ZIP64 extra field - size_t field_data = extra_pos + 4; - int offset_in_field = 0; - if (uncompressed_size_u32 == 0xFFFFFFFF) { - uncompressed_size = 0; - for (int b = 0; b < 8; ++b) { - uncompressed_size |= - static_cast(static_cast( - content[field_data + offset_in_field + b])) - << (8 * b); - } - offset_in_field += 8; - } - if (compressed_size_u32 == 0xFFFFFFFF) { - compressed_size = 0; - for (int b = 0; b < 8; ++b) { - compressed_size |= - static_cast(static_cast( - content[field_data + offset_in_field + b])) - << (8 * b); - } - offset_in_field += 8; - } - if (local_header_offset_u32 == 0xFFFFFFFF) { - local_header_offset = 0; - for (int b = 0; b < 8; ++b) { - local_header_offset |= - static_cast(static_cast( - content[field_data + offset_in_field + b])) - << (8 * b); - } - } - break; - } - extra_pos += 4 + field_size; - } - } - - // Match exact name or suffix (handles archives with directory prefixes, - // e.g. "model/extra/metadata.json" matches "extra/metadata.json") - bool match = (name == entry_name); - if (!match && name.size() > entry_name.size()) { - size_t suffix_start = name.size() - entry_name.size(); - if (name[suffix_start - 1] == '/' && - name.substr(suffix_start) == entry_name) { - match = true; - } - } - if (match) { - // Read from local file header - uint16_t local_name_len = read_u16(local_header_offset + 26); - uint16_t local_extra_len = read_u16(local_header_offset + 28); - size_t data_offset = - local_header_offset + 30 + local_name_len + local_extra_len; - return content.substr(data_offset, uncompressed_size); - } - - pos += 46 + name_len + extra_len + comment_len; - } - - throw deepmd::deepmd_exception("Entry not found in ZIP: " + entry_name + - " in " + zip_path); -} - -} // namespace +using deepmd::ptexpt::buildTypeSortedNlist; +using deepmd::ptexpt::parse_json; +using deepmd::ptexpt::read_zip_entry; using namespace deepmd; -/** - * @brief Convert a raw neighbor list to the sel-limited format expected by the - * pt_expt model. - * - * For non-mixed-type models (distinguish_types=true): the nlist has shape - * (nframes, nloc, sum(sel)), where the first sel[0] entries are neighbors of - * type 0, the next sel[1] are type 1, etc. Within each type group neighbors - * are sorted by distance (ascending). - * - * For mixed-type models (distinguish_types=false): all neighbors go into a - * single group sorted by distance, truncated to sum(sel). - * - * Missing slots are filled with -1. - * - * @param[in] raw_nlist Raw neighbor list (nloc x variable-nnei). - * @param[in] coord_ext Extended coordinates (nall x 3), flat. - * @param[in] atype_ext Extended atom types (nall). - * @param[in] sel Per-type neighbor selection counts. - * @param[in] nloc Number of local atoms. - * @param[in] mixed_types Whether the model uses mixed types - * (distinguish_types=false). - * @return Tensor of shape (1, nloc, sum(sel)), dtype int64. - */ -template -torch::Tensor buildTypeSortedNlist( - const std::vector>& raw_nlist, - const std::vector& coord_ext, - const std::vector& atype_ext, - const std::vector& sel, - int nloc, - bool mixed_types) { - int nsel = 0; - for (auto s : sel) { - nsel += s; - } - int ntypes = sel.size(); - std::vector result(static_cast(nloc) * nsel, -1); - - for (int ii = 0; ii < nloc; ++ii) { - const auto& neighbors = raw_nlist[ii]; - VALUETYPE xi = coord_ext[ii * 3 + 0]; - VALUETYPE yi = coord_ext[ii * 3 + 1]; - VALUETYPE zi = coord_ext[ii * 3 + 2]; - int offset = ii * nsel; - - if (mixed_types) { - // Mixed-type: all neighbors in one group, sort by distance - std::vector> all_neighbors; - for (int jj : neighbors) { - if (jj < 0) { - continue; - } - int jtype = atype_ext[jj]; - if (jtype < 0) { - continue; // skip invalid atoms - } - VALUETYPE dx = coord_ext[jj * 3 + 0] - xi; - VALUETYPE dy = coord_ext[jj * 3 + 1] - yi; - VALUETYPE dz = coord_ext[jj * 3 + 2] - zi; - VALUETYPE rr = dx * dx + dy * dy + dz * dz; - all_neighbors.emplace_back(rr, jj); - } - std::sort(all_neighbors.begin(), all_neighbors.end()); - int count = std::min(static_cast(all_neighbors.size()), nsel); - for (int kk = 0; kk < count; ++kk) { - result[offset + kk] = all_neighbors[kk].second; - } - } else { - // Non-mixed-type: group by type, sort each group - std::vector>> by_type(ntypes); - for (int jj : neighbors) { - if (jj < 0) { - continue; - } - int jtype = atype_ext[jj]; - if (jtype < 0 || jtype >= ntypes) { - continue; // skip virtual/unknown type atoms - } - VALUETYPE dx = coord_ext[jj * 3 + 0] - xi; - VALUETYPE dy = coord_ext[jj * 3 + 1] - yi; - VALUETYPE dz = coord_ext[jj * 3 + 2] - zi; - VALUETYPE rr = dx * dx + dy * dy + dz * dz; - by_type[jtype].emplace_back(rr, jj); - } - int col = 0; - for (int tt = 0; tt < ntypes; ++tt) { - auto& group = by_type[tt]; - std::sort(group.begin(), group.end()); - int count = std::min(static_cast(group.size()), sel[tt]); - for (int kk = 0; kk < count; ++kk) { - result[offset + col + kk] = group[kk].second; - } - col += sel[tt]; - } - } - } - - torch::Tensor tensor = - torch::from_blob(result.data(), {1, nloc, nsel}, - torch::TensorOptions().dtype(torch::kInt64)) - .clone(); - return tensor; -} - void DeepPotPTExpt::translate_error(std::function f) { try { f(); @@ -1169,31 +636,6 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, } } -template void DeepPotPTExpt::compute>( - std::vector& ener, - std::vector& force, - std::vector& virial, - std::vector& atom_energy, - std::vector& atom_virial, - const std::vector& coord, - const std::vector& atype, - const std::vector& box, - const std::vector& fparam, - const std::vector& aparam, - const bool atomic); -template void DeepPotPTExpt::compute>( - std::vector& ener, - std::vector& force, - std::vector& virial, - std::vector& atom_energy, - std::vector& atom_virial, - const std::vector& coord, - const std::vector& atype, - const std::vector& box, - const std::vector& fparam, - const std::vector& aparam, - const bool atomic); - template void DeepPotPTExpt::compute_nframes(ENERGYVTYPE& ener, std::vector& force, @@ -1253,6 +695,31 @@ void DeepPotPTExpt::compute_nframes(ENERGYVTYPE& ener, } } +template void DeepPotPTExpt::compute>( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); +template void DeepPotPTExpt::compute>( + std::vector& ener, + std::vector& force, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + void DeepPotPTExpt::get_type_map(std::string& type_map_str) { for (const auto& t : type_map) { type_map_str += t; diff --git a/source/api_cc/src/DeepSpin.cc b/source/api_cc/src/DeepSpin.cc index eb37828410..0ed694f280 100644 --- a/source/api_cc/src/DeepSpin.cc +++ b/source/api_cc/src/DeepSpin.cc @@ -11,6 +11,7 @@ #endif #ifdef BUILD_PYTORCH #include "DeepSpinPT.h" +#include "DeepSpinPTExpt.h" #endif #include "device.h" @@ -48,6 +49,14 @@ void DeepSpin::init(const std::string& model, dp = std::make_shared(model, gpu_rank, file_content); #else throw deepmd::deepmd_exception("PyTorch backend is not built"); +#endif + } else if (deepmd::DPBackend::PyTorchExportable == backend) { +#if defined(BUILD_PYTORCH) && BUILD_PT_EXPT_SPIN + dp = + std::make_shared(model, gpu_rank, file_content); +#else + throw deepmd::deepmd_exception( + "PyTorch Exportable backend is not available"); #endif } else if (deepmd::DPBackend::Paddle == backend) { throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet"); @@ -442,6 +451,13 @@ template void DeepSpin::compute(std::vector& dener, const std::vector& fparam, const std::vector& aparam_); +std::vector DeepSpin::get_use_spin() const { + if (dp) { + return dp->get_use_spin(); + } + return {}; +} + DeepSpinModelDevi::DeepSpinModelDevi() { inited = false; numb_models = 0; @@ -723,3 +739,10 @@ template void DeepSpinModelDevi::compute( const int& ago, const std::vector& fparam, const std::vector& aparam); + +std::vector DeepSpinModelDevi::get_use_spin() const { + if (!dps.empty()) { + return dps[0]->get_use_spin(); + } + return {}; +} diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc new file mode 100644 index 0000000000..ae4ef423ed --- /dev/null +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -0,0 +1,810 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include "DeepSpinPTExpt.h" + +#if defined(BUILD_PYTORCH) && BUILD_PT_EXPT_SPIN +#include + +#include +#include +#include +#include +#include + +#include "SimulationRegion.h" +#include "common.h" +#include "commonPTExpt.h" +#include "device.h" +#include "errors.h" +#include "neighbor_list.h" + +using deepmd::ptexpt::buildTypeSortedNlist; +using deepmd::ptexpt::parse_json; +using deepmd::ptexpt::read_zip_entry; + +using namespace deepmd; + +void DeepSpinPTExpt::translate_error(std::function f) { + try { + f(); + } catch (const c10::Error& e) { + throw deepmd::deepmd_exception( + "DeePMD-kit PyTorch Exportable backend error: " + + std::string(e.what())); + } catch (const deepmd::deepmd_exception&) { + throw; + } catch (const std::exception& e) { + throw deepmd::deepmd_exception( + "DeePMD-kit PyTorch Exportable backend error: " + + std::string(e.what())); + } +} + +DeepSpinPTExpt::DeepSpinPTExpt() : inited(false) {} + +DeepSpinPTExpt::DeepSpinPTExpt(const std::string& model, + const int& gpu_rank, + const std::string& file_content) + : inited(false) { + try { + translate_error([&] { init(model, gpu_rank, file_content); }); + } catch (...) { + throw; + } +} + +void DeepSpinPTExpt::init(const std::string& model, + const int& gpu_rank, + const std::string& file_content) { + if (inited) { + std::cerr << "WARNING: deepmd-kit should not be initialized twice, do " + "nothing at the second call of initializer" + << std::endl; + return; + } + + if (!file_content.empty()) { + throw deepmd::deepmd_exception( + "In-memory file_content loading is not supported for .pt2 models. " + "Please provide a file path instead."); + } + + int gpu_num = torch::cuda::device_count(); + gpu_id = (gpu_num > 0) ? (gpu_rank % gpu_num) : 0; + gpu_enabled = torch::cuda::is_available(); + + std::string device_str; + if (!gpu_enabled) { + device_str = "cpu"; + std::cout << "load model from: " << model << " to cpu" << std::endl; + } else { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + DPErrcheck(DPSetDevice(gpu_id)); +#endif + device_str = "cuda:" + std::to_string(gpu_id); + std::cout << "load model from: " << model << " to gpu " << gpu_id + << std::endl; + } + + // Read metadata from the .pt2 ZIP archive + std::string metadata_json = read_zip_entry(model, "extra/metadata.json"); + + auto metadata = parse_json(metadata_json); + rcut = metadata["rcut"].as_double(); + ntypes = static_cast(metadata["type_map"].as_array().size()); + dfparam = metadata["dim_fparam"].as_int(); + daparam = metadata["dim_aparam"].as_int(); + mixed_types = metadata["mixed_types"].as_bool(); + aparam_nall = false; + + // Spin-specific metadata + if (metadata.obj_val.count("ntypes_spin")) { + ntypes_spin = metadata["ntypes_spin"].as_int(); + } else { + ntypes_spin = 0; + } + + use_spin_.clear(); + if (metadata.obj_val.count("use_spin")) { + for (const auto& v : metadata["use_spin"].as_array()) { + use_spin_.push_back(v.as_bool()); + } + } + + if (metadata.obj_val.count("has_default_fparam")) { + has_default_fparam_ = metadata["has_default_fparam"].as_bool(); + } else { + has_default_fparam_ = false; + } + if (has_default_fparam_) { + if (metadata.obj_val.count("default_fparam")) { + default_fparam_.clear(); + for (const auto& v : metadata["default_fparam"].as_array()) { + default_fparam_.push_back(v.as_double()); + } + if (static_cast(default_fparam_.size()) != dfparam) { + throw deepmd::deepmd_exception( + "default_fparam length (" + std::to_string(default_fparam_.size()) + + ") does not match dim_fparam (" + std::to_string(dfparam) + ")."); + } + } else { + std::cerr << "WARNING: Model has has_default_fparam=true but " + "default_fparam values are missing from metadata." + << std::endl; + } + } + + type_map.clear(); + for (const auto& v : metadata["type_map"].as_array()) { + type_map.push_back(v.as_string()); + } + + sel.clear(); + for (const auto& v : metadata["sel"].as_array()) { + sel.push_back(v.as_int()); + } + + output_keys.clear(); + for (const auto& v : metadata["output_keys"].as_array()) { + output_keys.push_back(v.as_string()); + } + + // Load the AOTInductor model package + loader = std::make_unique( + model, "model", false, 1, + gpu_enabled ? static_cast(gpu_id) + : static_cast(-1)); + + int num_intra_nthreads, num_inter_nthreads; + get_env_nthreads(num_intra_nthreads, num_inter_nthreads); + if (num_inter_nthreads) { + try { + at::set_num_interop_threads(num_inter_nthreads); + } catch (...) { + } + } + if (num_intra_nthreads) { + try { + at::set_num_threads(num_intra_nthreads); + } catch (...) { + } + } + + inited = true; +} + +DeepSpinPTExpt::~DeepSpinPTExpt() {} + +std::vector DeepSpinPTExpt::run_model( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& spin, + const torch::Tensor& nlist, + const torch::Tensor& mapping, + const torch::Tensor& fparam, + const torch::Tensor& aparam) { + // Spin model has 7 positional args: coord, atype, spin, nlist, mapping, + // fparam, aparam Only include fparam/aparam if the model was exported with + // them. + std::vector inputs = {coord, atype, spin, nlist, mapping}; + if (dfparam > 0) { + inputs.push_back(fparam); + } + if (daparam > 0) { + inputs.push_back(aparam); + } + return loader->run(inputs); +} + +void DeepSpinPTExpt::extract_outputs( + std::map& output_map, + const std::vector& flat_outputs) { + if (flat_outputs.size() != output_keys.size()) { + throw deepmd::deepmd_exception( + "Model returned " + std::to_string(flat_outputs.size()) + + " outputs but expected " + std::to_string(output_keys.size()) + + " (from metadata.json)"); + } + for (size_t i = 0; i < output_keys.size(); ++i) { + output_map[output_keys[i]] = flat_outputs[i]; + } +} + +// ============================================================================ +// LAMMPS path: compute with pre-built neighbor list +// ============================================================================ + +template +void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + torch::Device device(torch::kCUDA, gpu_id); + if (!gpu_enabled) { + device = torch::Device(torch::kCPU); + } + int natoms = atype.size(); + auto options = torch::TensorOptions().dtype(torch::kFloat64); + torch::ScalarType floatType = torch::kFloat64; + if (std::is_same::value) { + floatType = torch::kFloat32; + } + auto int_option = + torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64); + + // Select real atoms (filter NULL-type atoms) + std::vector dcoord, dforce, dforce_mag, aparam_, datom_energy, + datom_virial; + std::vector datype, fwd_map, bkw_map; + int nghost_real, nall_real, nloc_real; + int nall = natoms; + select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map, + bkw_map, nall_real, nloc_real, coord, atype, aparam, + nghost, ntypes, 1, daparam, nall, aparam_nall); + int nloc = nall_real - nghost_real; + int nframes = 1; + + // Build spin tensor for real atoms using bkw_map + std::vector dspin(static_cast(nall_real) * 3); + for (int ii = 0; ii < nall_real; ++ii) { + for (int dd = 0; dd < 3; ++dd) { + dspin[static_cast(ii) * 3 + dd] = + spin[static_cast(bkw_map[ii]) * 3 + dd]; + } + } + + // Convert coord and spin to float64 + std::vector coord_d(dcoord.begin(), dcoord.end()); + std::vector spin_d(dspin.begin(), dspin.end()); + at::Tensor coord_Tensor = + torch::from_blob(coord_d.data(), {1, nall_real, 3}, options) + .clone() + .to(device); + at::Tensor spin_Tensor = + torch::from_blob(spin_d.data(), {1, nall_real, 3}, options) + .clone() + .to(device); + std::vector atype_64(datype.begin(), datype.end()); + at::Tensor atype_Tensor = + torch::from_blob(atype_64.data(), {1, nall_real}, int_option) + .clone() + .to(device); + + if (ago == 0) { + nlist_data.copy_from_nlist(lmp_list, nall - nghost); + nlist_data.shuffle_exclude_empty(fwd_map); + nlist_data.padding(); + } + at::Tensor firstneigh_tensor = + buildTypeSortedNlist(nlist_data.jlist, coord_d, datype, sel, nloc, + mixed_types) + .to(device); + + // Build mapping tensor + at::Tensor mapping_tensor; + if (lmp_list.mapping) { + std::vector mapping(nall_real); + for (int ii = 0; ii < nall_real; ii++) { + mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]]; + } + mapping_tensor = + torch::from_blob(mapping.data(), {1, nall_real}, int_option) + .clone() + .to(device); + } else { + std::vector mapping(nall_real); + for (int ii = 0; ii < nall_real; ii++) { + mapping[ii] = ii; + } + mapping_tensor = + torch::from_blob(mapping.data(), {1, nall_real}, int_option) + .clone() + .to(device); + } + + // Build fparam/aparam tensors + auto valuetype_options = std::is_same::value + ? torch::TensorOptions().dtype(torch::kFloat32) + : torch::TensorOptions().dtype(torch::kFloat64); + at::Tensor fparam_tensor; + if (!fparam.empty()) { + fparam_tensor = + torch::from_blob(const_cast(fparam.data()), + {1, static_cast(fparam.size())}, + valuetype_options) + .to(torch::kFloat64) + .to(device); + } else if (has_default_fparam_ && !default_fparam_.empty()) { + fparam_tensor = + torch::from_blob(const_cast(default_fparam_.data()), + {1, static_cast(default_fparam_.size())}, + options) + .clone() + .to(device); + } else if (has_default_fparam_) { + throw deepmd::deepmd_exception( + "fparam is empty and default_fparam values are missing from the .pt2 " + "metadata. Please regenerate the model or provide fparam explicitly."); + } else { + fparam_tensor = torch::zeros({0}, options).to(device); + } + + at::Tensor aparam_tensor; + if (!aparam_.empty()) { + aparam_tensor = + torch::from_blob( + const_cast(aparam_.data()), + {1, nloc, static_cast(aparam_.size()) / nloc}, + valuetype_options) + .to(torch::kFloat64) + .to(device); + } else { + aparam_tensor = torch::zeros({0}, options).to(device); + } + + // Run the .pt2 model (7 args for spin) + auto flat_outputs = + run_model(coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, + mapping_tensor, fparam_tensor, aparam_tensor); + + std::map output_map; + extract_outputs(output_map, flat_outputs); + + // Extract energy + torch::Tensor flat_energy_ = + output_map["energy_redu"].view({-1}).to(torch::kCPU); + ener.assign(flat_energy_.data_ptr(), + flat_energy_.data_ptr() + flat_energy_.numel()); + + // Extract force: energy_derv_r (nf, nall, 1, 3) -> (nf, nall, 3) + torch::Tensor force_tensor = + output_map["energy_derv_r"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_force_ = force_tensor.to(torch::kCPU); + dforce.assign(cpu_force_.data_ptr(), + cpu_force_.data_ptr() + cpu_force_.numel()); + + // Extract force_mag: energy_derv_r_mag (nf, nall, 1, 3) -> (nf, nall, 3) + torch::Tensor force_mag_tensor = + output_map["energy_derv_r_mag"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_force_mag_ = force_mag_tensor.to(torch::kCPU); + dforce_mag.assign( + cpu_force_mag_.data_ptr(), + cpu_force_mag_.data_ptr() + cpu_force_mag_.numel()); + + // Extract virial + torch::Tensor virial_tensor = + output_map["energy_derv_c_redu"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_virial_ = virial_tensor.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); + + // bkw map: map force from real atoms back to full atom list + force.resize(static_cast(nframes) * fwd_map.size() * 3); + force_mag.resize(static_cast(nframes) * fwd_map.size() * 3); + select_map(force, dforce, bkw_map, 3, nframes, fwd_map.size(), + nall_real); + select_map(force_mag, dforce_mag, bkw_map, 3, nframes, + fwd_map.size(), nall_real); + + if (atomic) { + torch::Tensor atom_energy_tensor = + output_map["energy"].view({-1}).to(floatType); + torch::Tensor cpu_atom_energy_ = atom_energy_tensor.to(torch::kCPU); + datom_energy.resize(nall_real, 0.0); + datom_energy.assign( + cpu_atom_energy_.data_ptr(), + cpu_atom_energy_.data_ptr() + cpu_atom_energy_.numel()); + + torch::Tensor atom_virial_tensor = + output_map["energy_derv_c"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_atom_virial_ = atom_virial_tensor.to(torch::kCPU); + datom_virial.assign( + cpu_atom_virial_.data_ptr(), + cpu_atom_virial_.data_ptr() + cpu_atom_virial_.numel()); + + atom_energy.resize(static_cast(nframes) * fwd_map.size()); + atom_virial.resize(static_cast(nframes) * fwd_map.size() * 9); + select_map(atom_energy, datom_energy, bkw_map, 1, nframes, + fwd_map.size(), nall_real); + select_map(atom_virial, datom_virial, bkw_map, 9, nframes, + fwd_map.size(), nall_real); + } +} + +template void DeepSpinPTExpt::compute>( + std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); +template void DeepSpinPTExpt::compute>( + std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& lmp_list, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + +// ============================================================================ +// Standalone path: compute without pre-built neighbor list +// ============================================================================ + +template +void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + int natoms = atype.size(); + + torch::Device device(torch::kCUDA, gpu_id); + if (!gpu_enabled) { + device = torch::Device(torch::kCPU); + } + + auto options = torch::TensorOptions().dtype(torch::kFloat64); + torch::ScalarType floatType = torch::kFloat64; + if (std::is_same::value) { + floatType = torch::kFloat32; + } + auto int_options = torch::TensorOptions().dtype(torch::kInt64); + int nframes = 1; + + // 1. Handle box: if empty (NoPbc), create a fake box large enough + std::vector coord_d(coord.begin(), coord.end()); + std::vector spin_d(spin.begin(), spin.end()); + std::vector box_d(box.begin(), box.end()); + if (box_d.empty()) { + // Create a fake orthorhombic box that contains all atoms with margin + double min_x = coord_d[0], max_x = coord_d[0]; + double min_y = coord_d[1], max_y = coord_d[1]; + double min_z = coord_d[2], max_z = coord_d[2]; + for (int ii = 1; ii < natoms; ++ii) { + min_x = std::min(min_x, coord_d[ii * 3 + 0]); + max_x = std::max(max_x, coord_d[ii * 3 + 0]); + min_y = std::min(min_y, coord_d[ii * 3 + 1]); + max_y = std::max(max_y, coord_d[ii * 3 + 1]); + min_z = std::min(min_z, coord_d[ii * 3 + 2]); + max_z = std::max(max_z, coord_d[ii * 3 + 2]); + } + // Shift coords so minimum is at rcut (ensures all atoms are in [0, L)) + double shift_x = rcut - min_x; + double shift_y = rcut - min_y; + double shift_z = rcut - min_z; + for (int ii = 0; ii < natoms; ++ii) { + coord_d[ii * 3 + 0] += shift_x; + coord_d[ii * 3 + 1] += shift_y; + coord_d[ii * 3 + 2] += shift_z; + } + box_d.resize(9, 0.0); + box_d[0] = (max_x - min_x) + 2.0 * rcut; + box_d[4] = (max_y - min_y) + 2.0 * rcut; + box_d[8] = (max_z - min_z) + 2.0 * rcut; + } + + // 2. Extend coords with ghosts + std::vector coord_cpy_d; + std::vector atype_cpy, mapping_vec; + std::vector ncell, ngcell; + { + SimulationRegion region; + region.reinitBox(&box_d[0]); + copy_coord(coord_cpy_d, atype_cpy, mapping_vec, ncell, ngcell, coord_d, + atype, static_cast(rcut), region); + } + + int nloc = natoms; + int nall = coord_cpy_d.size() / 3; + + // 2b. Extend spin to ghost atoms using mapping + std::vector spin_cpy_d(static_cast(nall) * 3, 0.0); + for (int ii = 0; ii < nloc; ++ii) { + for (int dd = 0; dd < 3; ++dd) { + spin_cpy_d[static_cast(ii) * 3 + dd] = + spin_d[static_cast(ii) * 3 + dd]; + } + } + for (int ii = nloc; ii < nall; ++ii) { + int li = mapping_vec[ii]; + for (int dd = 0; dd < 3; ++dd) { + spin_cpy_d[static_cast(ii) * 3 + dd] = + spin_d[static_cast(li) * 3 + dd]; + } + } + + // 3. Build neighbor list on extended coords + std::vector> nlist_raw, nlist_r_cpy; + { + SimulationRegion region; + region.reinitBox(&box_d[0]); + std::vector nat_stt(3, 0), ext_stt(3), ext_end(3); + for (int dd = 0; dd < 3; ++dd) { + ext_stt[dd] = -ngcell[dd]; + ext_end[dd] = ncell[dd] + ngcell[dd]; + } + build_nlist(nlist_raw, nlist_r_cpy, coord_cpy_d, nloc, rcut, rcut, nat_stt, + ncell, ext_stt, ext_end, region, ncell); + } + + // 4. Convert to tensors + at::Tensor coord_Tensor = + torch::from_blob(coord_cpy_d.data(), {1, nall, 3}, options) + .clone() + .to(device); + at::Tensor spin_Tensor = + torch::from_blob(spin_cpy_d.data(), {1, nall, 3}, options) + .clone() + .to(device); + std::vector atype_64(atype_cpy.begin(), atype_cpy.end()); + at::Tensor atype_Tensor = + torch::from_blob(atype_64.data(), {1, nall}, int_options) + .clone() + .to(device); + at::Tensor nlist_tensor = + buildTypeSortedNlist(nlist_raw, coord_cpy_d, atype_cpy, sel, nloc, + mixed_types) + .to(device); + std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); + at::Tensor mapping_tensor = + torch::from_blob(mapping_64.data(), {1, nall}, int_options) + .clone() + .to(device); + + // Build fparam/aparam tensors + auto valuetype_options = std::is_same::value + ? torch::TensorOptions().dtype(torch::kFloat32) + : torch::TensorOptions().dtype(torch::kFloat64); + at::Tensor fparam_tensor; + if (!fparam.empty()) { + fparam_tensor = + torch::from_blob(const_cast(fparam.data()), + {1, static_cast(fparam.size())}, + valuetype_options) + .to(torch::kFloat64) + .to(device); + } else if (has_default_fparam_ && !default_fparam_.empty()) { + fparam_tensor = + torch::from_blob(const_cast(default_fparam_.data()), + {1, static_cast(default_fparam_.size())}, + options) + .clone() + .to(device); + } else if (has_default_fparam_) { + throw deepmd::deepmd_exception( + "fparam is empty and default_fparam values are missing from the .pt2 " + "metadata. Please regenerate the model or provide fparam explicitly."); + } else { + fparam_tensor = torch::zeros({0}, options).to(device); + } + + at::Tensor aparam_tensor; + if (!aparam.empty()) { + aparam_tensor = + torch::from_blob( + const_cast(aparam.data()), + {1, natoms, static_cast(aparam.size()) / natoms}, + valuetype_options) + .to(torch::kFloat64) + .to(device); + } else { + aparam_tensor = torch::zeros({0}, options).to(device); + } + + // 5. Run the .pt2 model (7 args for spin) + auto flat_outputs = + run_model(coord_Tensor, atype_Tensor, spin_Tensor, nlist_tensor, + mapping_tensor, fparam_tensor, aparam_tensor); + + // 6. Extract outputs + std::map output_map; + extract_outputs(output_map, flat_outputs); + + // 7. Extract energy + torch::Tensor flat_energy_ = + output_map["energy_redu"].view({-1}).to(torch::kCPU); + ener.assign(flat_energy_.data_ptr(), + flat_energy_.data_ptr() + flat_energy_.numel()); + + // 8. Extract virial + torch::Tensor virial_tensor = + output_map["energy_derv_c_redu"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_virial_ = virial_tensor.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); + + // 9. Extract force and fold back: energy_derv_r (nf, nall, 1, 3) + torch::Tensor force_ext = + output_map["energy_derv_r"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_force_ext = force_ext.to(torch::kCPU); + std::vector extended_force( + cpu_force_ext.data_ptr(), + cpu_force_ext.data_ptr() + cpu_force_ext.numel()); + fold_back(force, extended_force, mapping_vec, nloc, nall, 3, nframes); + + // 10. Extract force_mag and fold back: energy_derv_r_mag (nf, nall, 1, 3) + torch::Tensor force_mag_ext = + output_map["energy_derv_r_mag"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_force_mag_ext = force_mag_ext.to(torch::kCPU); + std::vector extended_force_mag( + cpu_force_mag_ext.data_ptr(), + cpu_force_mag_ext.data_ptr() + cpu_force_mag_ext.numel()); + fold_back(force_mag, extended_force_mag, mapping_vec, nloc, nall, 3, nframes); + + if (atomic) { + // atom_energy: energy (nf, nloc, 1) + torch::Tensor atom_energy_tensor = + output_map["energy"].view({-1}).to(floatType); + torch::Tensor cpu_atom_energy_ = atom_energy_tensor.to(torch::kCPU); + atom_energy.assign( + cpu_atom_energy_.data_ptr(), + cpu_atom_energy_.data_ptr() + cpu_atom_energy_.numel()); + + // atom_virial: energy_derv_c (nf, nall, 1, 9) -> fold back + torch::Tensor atom_virial_ext = + output_map["energy_derv_c"].squeeze(-2).view({-1}).to(floatType); + torch::Tensor cpu_atom_virial_ext = atom_virial_ext.to(torch::kCPU); + std::vector extended_atom_virial( + cpu_atom_virial_ext.data_ptr(), + cpu_atom_virial_ext.data_ptr() + + cpu_atom_virial_ext.numel()); + fold_back(atom_virial, extended_atom_virial, mapping_vec, nloc, nall, 9, + nframes); + } +} + +template void DeepSpinPTExpt::compute>( + std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); +template void DeepSpinPTExpt::compute>( + std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic); + +void DeepSpinPTExpt::get_type_map(std::string& type_map_str) { + for (const auto& t : type_map) { + type_map_str += t; + type_map_str += " "; + } +} + +// forward to template method +void DeepSpinPTExpt::computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + translate_error([&] { + compute(ener, force, force_mag, virial, atom_energy, atom_virial, coord, + spin, atype, box, fparam, aparam, atomic); + }); +} +void DeepSpinPTExpt::computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + translate_error([&] { + compute(ener, force, force_mag, virial, atom_energy, atom_virial, coord, + spin, atype, box, fparam, aparam, atomic); + }); +} +void DeepSpinPTExpt::computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + translate_error([&] { + compute(ener, force, force_mag, virial, atom_energy, atom_virial, coord, + spin, atype, box, nghost, inlist, ago, fparam, aparam, atomic); + }); +} +void DeepSpinPTExpt::computew(std::vector& ener, + std::vector& force, + std::vector& force_mag, + std::vector& virial, + std::vector& atom_energy, + std::vector& atom_virial, + const std::vector& coord, + const std::vector& spin, + const std::vector& atype, + const std::vector& box, + const int nghost, + const InputNlist& inlist, + const int& ago, + const std::vector& fparam, + const std::vector& aparam, + const bool atomic) { + translate_error([&] { + compute(ener, force, force_mag, virial, atom_energy, atom_virial, coord, + spin, atype, box, nghost, inlist, ago, fparam, aparam, atomic); + }); +} +#endif diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h new file mode 100644 index 0000000000..7dd02d09a9 --- /dev/null +++ b/source/api_cc/src/commonPTExpt.h @@ -0,0 +1,538 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// Shared utilities for pt_expt (.pt2 / AOTInductor) backend classes. +// Provides: JSON parser, ZIP archive reader, and type-sorted nlist builder. +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "errors.h" + +namespace deepmd { +namespace ptexpt { + +// ============================================================================ +// Minimal JSON value parser for reading metadata from .pt2 archives. +// Supports: strings, numbers, booleans, arrays, objects. +// ============================================================================ + +struct JsonValue; +using JsonObject = std::map; +using JsonArray = std::vector; + +struct JsonValue { + enum Type { Null, Bool, Number, String, Array, Object }; + Type type = Null; + bool bool_val = false; + double num_val = 0.0; + std::string str_val; + JsonArray arr_val; + JsonObject obj_val; + + std::string as_string() const { return str_val; } + double as_double() const { return num_val; } + int as_int() const { return static_cast(num_val); } + bool as_bool() const { return bool_val; } + const JsonArray& as_array() const { return arr_val; } + const JsonObject& as_object() const { return obj_val; } + const JsonValue& operator[](const std::string& key) const { + return obj_val.at(key); + } + const JsonValue& operator[](size_t idx) const { return arr_val.at(idx); } + bool has(const std::string& key) const { + return obj_val.find(key) != obj_val.end(); + } +}; + +class JsonParser { + public: + explicit JsonParser(const std::string& s) : s_(s), pos_(0) {} + JsonValue parse() { + skip_ws(); + auto val = parse_value(); + return val; + } + + private: + const std::string& s_; + size_t pos_; + + char peek() const { return pos_ < s_.size() ? s_[pos_] : '\0'; } + char get() { + if (pos_ >= s_.size()) { + throw std::runtime_error("JSON parse error: unexpected end of input"); + } + return s_[pos_++]; + } + void skip_ws() { + while (pos_ < s_.size() && (s_[pos_] == ' ' || s_[pos_] == '\t' || + s_[pos_] == '\n' || s_[pos_] == '\r')) { + ++pos_; + } + } + + JsonValue parse_value() { + skip_ws(); + char c = peek(); + if (c == '"') { + return parse_string_val(); + } else if (c == '{') { + return parse_object(); + } else if (c == '[') { + return parse_array(); + } else if (c == 't' || c == 'f') { + return parse_bool(); + } else if (c == 'n') { + return parse_null(); + } else { + return parse_number(); + } + } + + std::string parse_string_raw() { + get(); // consume '"' + std::string result; + while (pos_ < s_.size() && peek() != '"') { + if (peek() == '\\') { + get(); + char esc = get(); + switch (esc) { + case '"': + result += '"'; + break; + case '\\': + result += '\\'; + break; + case '/': + result += '/'; + break; + case 'n': + result += '\n'; + break; + case 't': + result += '\t'; + break; + case 'r': + result += '\r'; + break; + default: + result += esc; + break; + } + } else { + result += get(); + } + } + get(); // consume closing '"' + return result; + } + + JsonValue parse_string_val() { + JsonValue v; + v.type = JsonValue::String; + v.str_val = parse_string_raw(); + return v; + } + + JsonValue parse_number() { + size_t start = pos_; + if (peek() == '-') { + get(); + } + while (pos_ < s_.size() && + (std::isdigit(s_[pos_]) || s_[pos_] == '.' || s_[pos_] == 'e' || + s_[pos_] == 'E' || s_[pos_] == '+' || s_[pos_] == '-')) { + // handle sign only if after e/E + if ((s_[pos_] == '+' || s_[pos_] == '-') && pos_ > start && + s_[pos_ - 1] != 'e' && s_[pos_ - 1] != 'E') { + break; + } + ++pos_; + } + JsonValue v; + v.type = JsonValue::Number; + try { + v.num_val = std::stod(s_.substr(start, pos_ - start)); + } catch (const std::exception& e) { + throw std::runtime_error("JSON parse error: invalid number at position " + + std::to_string(start)); + } + return v; + } + + JsonValue parse_bool() { + JsonValue v; + v.type = JsonValue::Bool; + if (s_.substr(pos_, 4) == "true") { + v.bool_val = true; + pos_ += 4; + } else if (s_.substr(pos_, 5) == "false") { + v.bool_val = false; + pos_ += 5; + } else { + throw std::runtime_error( + "JSON parse error: expected 'true' or 'false' at position " + + std::to_string(pos_)); + } + return v; + } + + JsonValue parse_null() { + if (s_.substr(pos_, 4) != "null") { + throw std::runtime_error( + "JSON parse error: expected 'null' at position " + + std::to_string(pos_)); + } + pos_ += 4; + return JsonValue(); + } + + JsonValue parse_array() { + get(); // consume '[' + JsonValue v; + v.type = JsonValue::Array; + skip_ws(); + if (peek() == ']') { + get(); + return v; + } + while (true) { + v.arr_val.push_back(parse_value()); + skip_ws(); + if (peek() == ',') { + get(); + } else { + break; + } + } + skip_ws(); + get(); // consume ']' + return v; + } + + JsonValue parse_object() { + get(); // consume '{' + JsonValue v; + v.type = JsonValue::Object; + skip_ws(); + if (peek() == '}') { + get(); + return v; + } + while (true) { + skip_ws(); + std::string key = parse_string_raw(); + skip_ws(); + get(); // consume ':' + v.obj_val[key] = parse_value(); + skip_ws(); + if (peek() == ',') { + get(); + } else { + break; + } + } + skip_ws(); + get(); // consume '}' + return v; + } +}; + +inline JsonValue parse_json(const std::string& s) { + JsonParser parser(s); + return parser.parse(); +} + +// ============================================================================ +// ZIP archive reader — reads a file from a ZIP archive. +// ============================================================================ + +inline std::string read_zip_entry(const std::string& zip_path, + const std::string& entry_name) { + std::ifstream ifs(zip_path, std::ios::binary); + if (!ifs.is_open()) { + throw deepmd::deepmd_exception("Cannot open file: " + zip_path); + } + + // Read entire file + std::string content((std::istreambuf_iterator(ifs)), + std::istreambuf_iterator()); + ifs.close(); + + // Simple ZIP central directory parser + // Find End of Central Directory Record (EOCD) + if (content.size() < 22) { + throw deepmd::deepmd_exception( + "File too small to be a valid ZIP archive: " + zip_path); + } + size_t eocd_pos = std::string::npos; + for (int64_t i = static_cast(content.size()) - 22; + i >= 0 && static_cast(i) + 3 < content.size(); --i) { + if (content[i] == 0x50 && content[i + 1] == 0x4b && + content[i + 2] == 0x05 && content[i + 3] == 0x06) { + eocd_pos = static_cast(i); + break; + } + } + if (eocd_pos == std::string::npos) { + throw deepmd::deepmd_exception("Invalid ZIP file: " + zip_path); + } + + auto read_u16 = [&](size_t offset) -> uint16_t { + return static_cast(static_cast(content[offset])) | + (static_cast( + static_cast(content[offset + 1])) + << 8); + }; + auto read_u32 = [&](size_t offset) -> uint32_t { + return static_cast(static_cast(content[offset])) | + (static_cast( + static_cast(content[offset + 1])) + << 8) | + (static_cast( + static_cast(content[offset + 2])) + << 16) | + (static_cast( + static_cast(content[offset + 3])) + << 24); + }; + + uint64_t num_entries = read_u16(eocd_pos + 10); + uint64_t cd_offset = read_u32(eocd_pos + 16); + + // Handle ZIP64 + if (cd_offset == 0xFFFFFFFF || num_entries == 0xFFFF) { + if (eocd_pos < 20) { + throw deepmd::deepmd_exception( + "Invalid ZIP64 file (truncated EOCD locator): " + zip_path); + } + size_t zip64_locator_pos = eocd_pos - 20; + if (content[zip64_locator_pos] == 0x50 && + content[zip64_locator_pos + 1] == 0x4b && + content[zip64_locator_pos + 2] == 0x06 && + content[zip64_locator_pos + 3] == 0x07) { + uint64_t zip64_eocd_offset = 0; + for (int b = 0; b < 8; ++b) { + zip64_eocd_offset |= static_cast(static_cast( + content[zip64_locator_pos + 8 + b])) + << (8 * b); + } + size_t z64_pos = static_cast(zip64_eocd_offset); + if (z64_pos + 56 > content.size()) { + throw deepmd::deepmd_exception( + "Invalid ZIP64 file (truncated EOCD record): " + zip_path); + } + num_entries = 0; + for (int b = 0; b < 8; ++b) { + num_entries |= static_cast(static_cast( + content[z64_pos + 32 + b])) + << (8 * b); + } + cd_offset = 0; + for (int b = 0; b < 8; ++b) { + cd_offset |= static_cast( + static_cast(content[z64_pos + 48 + b])) + << (8 * b); + } + } + } + + // Iterate central directory entries + size_t pos = cd_offset; + for (uint64_t i = 0; i < num_entries; ++i) { + if (pos + 46 > content.size()) { + break; + } + uint16_t name_len = read_u16(pos + 28); + uint16_t extra_len = read_u16(pos + 30); + uint16_t comment_len = read_u16(pos + 32); + uint32_t compressed_size_u32 = read_u32(pos + 20); + uint32_t uncompressed_size_u32 = read_u32(pos + 24); + uint32_t local_header_offset_u32 = read_u32(pos + 42); + + uint64_t compressed_size = compressed_size_u32; + uint64_t uncompressed_size = uncompressed_size_u32; + uint64_t local_header_offset = local_header_offset_u32; + + std::string name = content.substr(pos + 46, name_len); + + // Handle ZIP64 extra field + if (uncompressed_size_u32 == 0xFFFFFFFF || + local_header_offset_u32 == 0xFFFFFFFF) { + size_t extra_pos = pos + 46 + name_len; + size_t extra_end = extra_pos + extra_len; + while (extra_pos + 4 <= extra_end) { + uint16_t field_id = read_u16(extra_pos); + uint16_t field_size = read_u16(extra_pos + 2); + if (field_id == 0x0001) { + size_t field_data = extra_pos + 4; + int offset_in_field = 0; + if (uncompressed_size_u32 == 0xFFFFFFFF) { + uncompressed_size = 0; + for (int b = 0; b < 8; ++b) { + uncompressed_size |= + static_cast(static_cast( + content[field_data + offset_in_field + b])) + << (8 * b); + } + offset_in_field += 8; + } + if (compressed_size_u32 == 0xFFFFFFFF) { + compressed_size = 0; + for (int b = 0; b < 8; ++b) { + compressed_size |= + static_cast(static_cast( + content[field_data + offset_in_field + b])) + << (8 * b); + } + offset_in_field += 8; + } + if (local_header_offset_u32 == 0xFFFFFFFF) { + local_header_offset = 0; + for (int b = 0; b < 8; ++b) { + local_header_offset |= + static_cast(static_cast( + content[field_data + offset_in_field + b])) + << (8 * b); + } + } + break; + } + extra_pos += 4 + field_size; + } + } + + // Match exact name or suffix + bool match = (name == entry_name); + if (!match && name.size() > entry_name.size()) { + size_t suffix_start = name.size() - entry_name.size(); + if (name[suffix_start - 1] == '/' && + name.substr(suffix_start) == entry_name) { + match = true; + } + } + if (match) { + uint16_t local_name_len = read_u16(local_header_offset + 26); + uint16_t local_extra_len = read_u16(local_header_offset + 28); + size_t data_offset = + local_header_offset + 30 + local_name_len + local_extra_len; + // PyTorch archives (.pth, .pte, .pt2) always use ZIP STORED (compression + // method 0) for every entry. PyTorch needs to mmap tensor data directly + // from the archive without decompression, so its C++ writer + // (caffe2::serialize::PyTorchStreamWriter) and torch.export.save both + // write uncompressed entries with 64-byte alignment. No decompression is + // needed. + return content.substr(data_offset, uncompressed_size); + } + + pos += 46 + name_len + extra_len + comment_len; + } + + throw deepmd::deepmd_exception("Entry not found in ZIP: " + entry_name + + " in " + zip_path); +} + +// ============================================================================ +// Build type-sorted, sel-limited neighbor list tensor. +// ============================================================================ + +/** + * @brief Convert a raw neighbor list to the sel-limited format expected by the + * pt_expt model. + * + * For non-mixed-type models (distinguish_types=true): the nlist has shape + * (nframes, nloc, sum(sel)), where the first sel[0] entries are neighbors of + * type 0, the next sel[1] are type 1, etc. Within each type group neighbors + * are sorted by distance (ascending). + * + * For mixed-type models (distinguish_types=false): all neighbors go into a + * single group sorted by distance, truncated to sum(sel). + * + * Missing slots are filled with -1. + */ +template +inline torch::Tensor buildTypeSortedNlist( + const std::vector>& raw_nlist, + const std::vector& coord_ext, + const std::vector& atype_ext, + const std::vector& sel, + int nloc, + bool mixed_types) { + int nsel = 0; + for (auto s : sel) { + nsel += s; + } + int ntypes = sel.size(); + std::vector result(static_cast(nloc) * nsel, -1); + + for (int ii = 0; ii < nloc; ++ii) { + const auto& neighbors = raw_nlist[ii]; + VALUETYPE xi = coord_ext[ii * 3 + 0]; + VALUETYPE yi = coord_ext[ii * 3 + 1]; + VALUETYPE zi = coord_ext[ii * 3 + 2]; + int offset = ii * nsel; + + if (mixed_types) { + std::vector> all_neighbors; + for (int jj : neighbors) { + if (jj < 0) { + continue; + } + int jtype = atype_ext[jj]; + if (jtype < 0) { + continue; + } + VALUETYPE dx = coord_ext[jj * 3 + 0] - xi; + VALUETYPE dy = coord_ext[jj * 3 + 1] - yi; + VALUETYPE dz = coord_ext[jj * 3 + 2] - zi; + VALUETYPE rr = dx * dx + dy * dy + dz * dz; + all_neighbors.emplace_back(rr, jj); + } + std::sort(all_neighbors.begin(), all_neighbors.end()); + int count = std::min(static_cast(all_neighbors.size()), nsel); + for (int kk = 0; kk < count; ++kk) { + result[offset + kk] = all_neighbors[kk].second; + } + } else { + std::vector>> by_type(ntypes); + for (int jj : neighbors) { + if (jj < 0) { + continue; + } + int jtype = atype_ext[jj]; + if (jtype < 0 || jtype >= ntypes) { + continue; + } + VALUETYPE dx = coord_ext[jj * 3 + 0] - xi; + VALUETYPE dy = coord_ext[jj * 3 + 1] - yi; + VALUETYPE dz = coord_ext[jj * 3 + 2] - zi; + VALUETYPE rr = dx * dx + dy * dy + dz * dz; + by_type[jtype].emplace_back(rr, jj); + } + int col = 0; + for (int tt = 0; tt < ntypes; ++tt) { + auto& group = by_type[tt]; + std::sort(group.begin(), group.end()); + int count = std::min(static_cast(group.size()), sel[tt]); + for (int kk = 0; kk < count; ++kk) { + result[offset + col + kk] = group[kk].second; + } + col += sel[tt]; + } + } + } + + torch::Tensor tensor = + torch::from_blob(result.data(), {1, nloc, nsel}, + torch::TensorOptions().dtype(torch::kInt64)) + .clone(); + return tensor; +} + +} // namespace ptexpt +} // namespace deepmd diff --git a/source/api_cc/tests/test_deeppot_dpa_pt_spin.cc b/source/api_cc/tests/test_deeppot_dpa_pt_spin.cc index d079635565..20318504a2 100644 --- a/source/api_cc/tests/test_deeppot_dpa_pt_spin.cc +++ b/source/api_cc/tests/test_deeppot_dpa_pt_spin.cc @@ -49,49 +49,49 @@ class TestInferDeepSpinDpaPt : public ::testing::Test { // {ae.ravel()=}") std::vector expected_e = { - -1.8626545229251095e+00, -2.3502165071948093e+00, -2.3500944968573521e+00, - -2.0688274735854710e+00, -2.3485113271625320e+00, -2.3489022338537353e+00, + 7.020322773655288e-03, 1.099636038493644e-01, 1.093176595258250e-01, + 4.865300228001564e-02, 1.096547558413134e-01, 1.099754340356070e-01, }; std::vector expected_f = { - 3.7989110974834261e-02, -6.8203560994098300e-02, 3.1554995279414300e-02, - -6.0769407958790114e-02, 5.6658432967656878e-03, 2.1814741358389407e-02, - 1.5027739412753049e-02, 6.2090755323245192e-02, -5.3346442187326704e-02, - -5.2134406995188787e-02, 4.0990812807417676e-02, -1.6987454510304811e-02, - -6.7153786204261134e-03, -5.3801784772022326e-02, 5.6707773168242034e-02, - 6.6602343186817375e-02, 1.3257934338691726e-02, -3.9743613108414025e-02, + 2.980086586841411e-02, 2.670602118823960e-03, -6.205408022135627e-03, + -7.946653268248605e-03, 4.217792180550986e-03, 1.822080579891798e-03, + -3.416928812442276e-03, -6.992749479424899e-03, 4.728288289346775e-03, + 5.049869641953204e-03, 1.550913149717830e-02, 1.801899070929784e-02, + -1.411871008097311e-02, -8.283139367982638e-03, -7.058623315726573e-03, + -9.368443348703582e-03, -7.121636949145627e-03, -1.130532824067422e-02, }; std::vector expected_fm = { - 4.8385521455777196e+00, 5.3158441514550137e-01, 1.0855626815019124e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 1.2140862110260138e+00, 9.6823434985033552e-01, 1.0689000529371890e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.112646578617150e+00, -2.239176906831133e-01, -2.513101985142691e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.763058480695873e-02, 1.564710428447471e-02, -3.735332673990924e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, }; std::vector expected_tot_v = { - 1.3824836617610417e-01, 1.3936741842895785e-02, -6.6368919371499843e-02, - 2.7457622909082817e-02, 7.8421669005063782e-02, -7.3855775048417324e-02, - -6.6291165687501666e-02, -7.0321379535767850e-02, 9.3633139281050631e-02, + -7.128128841285764e-02, -8.315622874246201e-03, -4.731536549332887e-03, + -1.341290773830820e-02, -1.469079587670170e-03, 1.056456080782730e-03, + -1.192681253424153e-02, 4.978114518702803e-03, -3.966123865761229e-03, }; std::vector expected_atom_v = { - 1.5062258683232639e-02, 1.2127480962864944e-03, -1.0162650648943013e-02, - 1.5209474134091445e-03, 1.1506421293176305e-02, -3.1879228754013102e-03, - -9.8921319202609839e-03, -3.2539620487815153e-03, 8.8506629253760523e-03, - 3.5246401477549733e-02, -1.2095530141782164e-02, -2.7192583359447204e-02, - -8.0687617219887668e-04, 8.7274423879230294e-04, 5.0653534823779300e-04, - -8.8959039912723942e-03, 1.4282170948379993e-03, 7.1793794296111050e-03, - 1.0454610564796012e-02, 4.9337241996913368e-03, -5.1956142767175060e-03, - 2.1322349032843767e-02, 3.6589953171335281e-02, -2.5233310907452358e-02, - -1.8978868056705719e-02, -3.1879121076347043e-02, 2.2823386401396493e-02, - 3.0685989273260134e-02, 3.3832992712807563e-03, -2.6187478868556556e-03, - 3.0642173785931769e-03, 6.0642419846717189e-03, -9.2730679711597770e-03, - -2.9147581597328805e-03, -8.9463130935045487e-03, 1.3848366583449021e-02, - 4.3089321022873554e-03, 4.6119905760196971e-03, -6.9703908253700266e-03, - -4.1294359596699753e-03, 2.0744263944150482e-02, -3.3186693638875019e-02, - 2.8946777313109276e-03, -1.9223536156324161e-02, 3.0643617186671183e-02, - 4.2490174074978321e-02, 1.1890509841399666e-02, -1.4228932374166448e-02, - 6.4864212161055840e-03, 2.6440443729376928e-03, -3.4813150037666628e-03, - -2.8504181290840619e-02, -8.4466642556485783e-03, 1.0287726754546783e-02, + -3.619151252102697e-03, -1.915456807909512e-03, -1.193634026800710e-03, + -1.785172610866009e-03, 4.134184812978909e-03, -6.166519257514488e-05, + -2.157590794793008e-03, -2.693885020778555e-04, 1.587443946069703e-03, + 2.056276410594536e-03, -1.371452525359132e-03, -3.000322809812329e-03, + -2.204910668541369e-03, -2.235579614907452e-04, 1.766238823282138e-03, + 1.610936535839619e-04, -2.104451576692437e-04, -1.804779669204414e-06, + -3.221463421325033e-02, -7.637227088481205e-03, 5.718223298429893e-03, + -2.919861817651578e-04, -7.691370555547318e-03, 4.007185089114448e-03, + 4.811410568749122e-04, 4.415618310181523e-03, -2.774565137306631e-03, + -5.845767386476995e-03, -9.718890770743884e-04, 2.149933205072918e-03, + -1.901140917532940e-03, 7.116473330687919e-04, -1.348184104909009e-03, + 2.280135921727739e-03, -1.281729466483908e-03, 2.259352377835986e-03, + -1.358169386053482e-02, 6.217682155741852e-03, -1.084162229060967e-02, + 1.479497817858454e-04, 3.526517898332583e-03, -5.588222116376142e-03, + -3.264343884716279e-03, 4.301831360658648e-03, -7.109217034201159e-03, + -1.807631811108731e-02, -2.637279531163817e-03, 2.435886074387011e-03, + -7.377647141388572e-03, -1.926501115012390e-03, 2.281103582246440e-03, + -9.427248486918859e-03, -1.977772025906360e-03, 2.072666761510075e-03, }; int natoms; @@ -227,49 +227,49 @@ class TestInferDeepSpinDpaPtNopbc : public ::testing::Test { // {ae.ravel()=}") std::vector expected_e = { - -1.9136796509970209e+00, -2.3532121417832528e+00, - -2.3589759416772553e+00, -2.0689533840218703e+00, - -2.3485273598793084e+00, -2.3489022338537353e+00}; + 1.298915294144196e-02, 1.095576145701290e-01, 1.083914166945241e-01, + 4.932338375417146e-02, 1.099860785812512e-01, 1.100478936528533e-01, + }; std::vector expected_f = { - 5.2440246818294511e-02, -8.2643189092284075e-03, -1.6057110078610215e-02, - -5.2440246818295698e-02, 8.2643189092281334e-03, 1.6057110078610277e-02, - -1.6724663644564395e-03, 7.9346065821642349e-05, -2.5251632397208987e-04, - -5.6934098675373246e-02, 4.0398593044712161e-02, -1.6520316500527876e-02, - -7.9878577602028808e-03, -5.3736758888210570e-02, 5.6516778947603999e-02, - 6.6594422800032166e-02, 1.3258819777676990e-02, -3.9743946123104140e-02, + 1.300765817095240e-02, -1.593967210478553e-03, -3.196759265340465e-03, + -1.300765817095220e-02, 1.593967210478477e-03, 3.196759265340509e-03, + 9.196695370628910e-03, -9.044559760114149e-04, 5.658266727670325e-04, + 1.012744085443978e-02, 1.680427054831429e-02, 1.807036969424208e-02, + -1.133453822298158e-02, -8.941333904804914e-03, -6.627672717506913e-03, + -7.989598002087154e-03, -6.958480667497971e-03, -1.200852364950221e-02, }; std::vector expected_fm = { - 4.5904360179010135e+00, 6.2821415259365443e-01, 9.2483695213043082e-01, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 1.2125967529512662e+00, 9.6807902483755459e-01, 1.0691011858092361e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00, + -9.651705644713781e-01, -1.704326891282164e-01, -2.605677204117113e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.168034653189444e-02, 1.736913887115685e-02, -3.908906640474424e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, }; std::vector expected_tot_v = { - 1.0340989912297673e-01, 1.7731575682021676e-03, -5.2108130921436818e-02, - -2.3573500930805756e-03, 3.1835281809236504e-02, -4.1149040865495201e-02, - -3.8786409820505775e-02, -3.1539990930710131e-02, 6.3609050665518849e-02, + -2.794677047149250e-02, 2.706780654409338e-03, -1.543674466954783e-02, + -6.040095143502448e-03, 3.055091789429398e-03, -4.313832703172029e-03, + -1.604723023732680e-02, 1.777744336229196e-03, -5.171028133240485e-04, }; std::vector expected_atom_v = { - 7.4120023328214657e-03, -3.1050280043021516e-03, -5.6090828464736842e-03, - -2.9955764634068682e-03, 1.2549036535866968e-03, 2.2669227290622353e-03, - -5.4203732932148890e-03, 2.2706969201286079e-03, 4.1019041137797078e-03, - 3.1393780312700348e-02, -1.3151448509374553e-02, -2.3757455371773359e-02, - -3.1200195294236038e-03, 1.3070352082720678e-03, 2.3610958601044327e-03, - -6.4618881649612936e-03, 2.7070072042405802e-03, 4.8900775302411284e-03, - 5.0796934697206701e-03, -6.2368435756216933e-04, 3.7915727217144920e-04, - -2.3541456235528688e-04, 4.7172356421924691e-05, -4.7741044651678936e-05, - 7.8040516599487829e-04, -1.0555907766399088e-04, 7.4337745788048400e-05, - 1.6774660014217457e-02, 1.9900026394719939e-03, -1.3757156420267647e-03, - 1.4918414385930168e-03, 5.8652659168653282e-03, -9.1012492701524200e-03, - -1.5633397202315658e-03, -8.7822339826287010e-03, 1.3702299440830762e-02, - 2.6638349513850509e-04, 4.7758974965528556e-03, -7.5202570595042720e-03, - -3.9854902437910233e-03, 2.0716934356864808e-02, -3.3146966232075720e-02, - 2.3833618251221655e-03, -1.9183462719327554e-02, 3.0553102668094204e-02, - 4.2483379498378271e-02, 1.1887418303416192e-02, -1.4224777273830189e-02, - 6.4873092673031901e-03, 2.6439703172256787e-03, -3.4811029077820542e-03, - -2.8504575633215068e-02, -8.4464392754590722e-03, 1.0287329166785000e-02, + 2.306155467987514e-03, -9.660921555042744e-04, -1.745198732526081e-03, + -9.476983959938838e-04, 3.970087875108642e-04, 7.171771645359225e-04, + -1.713485354054099e-03, 7.178114321056426e-04, 1.296691619287112e-03, + 7.319511578522508e-03, -3.066281877489238e-03, -5.539089843206328e-03, + -2.318373397601943e-04, 9.712104773733529e-05, 1.754444733319547e-04, + -6.521165022942005e-04, 2.731839401501884e-04, 4.934935693035620e-04, + -2.688582999243521e-02, 2.879540148031285e-03, -1.319893926076197e-03, + 2.654650285707574e-03, -3.069982439259916e-04, 1.657790217325663e-04, + -1.692335959139460e-03, 2.136332667243511e-04, -1.340912177342011e-04, + 4.337039816878673e-03, -4.892467685548227e-04, 1.504317649360216e-03, + -5.929947364721471e-04, 9.606430484810017e-04, -1.606901551850972e-03, + 1.319957388254711e-03, -1.503407517480697e-03, 2.524877227122832e-03, + -3.426950456916239e-03, 6.054673894090045e-03, -9.902054694804960e-03, + -4.978554214552650e-04, 3.694171795596980e-03, -5.898283613865629e-03, + -2.340485739136391e-03, 4.340998169957810e-03, -7.094248555283199e-03, + -1.159669688552975e-02, -1.705812586163657e-03, 1.565174877705520e-03, + -6.424359535528532e-03, -1.786854645970792e-03, 2.132951802944129e-03, + -1.096876407095736e-02, -2.264474955228100e-03, 2.396174543979845e-03, }; int natoms; diff --git a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc new file mode 100644 index 0000000000..27c1836aa1 --- /dev/null +++ b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc @@ -0,0 +1,472 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// Test C++ inference for pt_expt (.pt2) backend with DPA1 spin model. +#include + +#include +#include +#include +#include + +#include "DeepSpin.h" +#include "neighbor_list.h" +#include "test_utils.h" + +// Spin models need relaxed epsilon +#undef EPSILON +#define EPSILON (std::is_same::value ? 1e-6 : 1e-1) + +// ============================================================================ +// PBC test fixture +// ============================================================================ + +template +class TestInferDeepSpinDpaPtExpt : public ::testing::Test { + protected: + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector spin = {0.13, 0.02, 0.03, 0., 0., 0., 0., 0., 0., + 0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + + // Reference values generated by source/tests/infer/gen_spin.py + std::vector expected_e; + std::vector expected_f; + std::vector expected_fm; + std::vector expected_tot_v; + std::vector expected_atom_v; + + int natoms; + double expected_tot_e; + + deepmd::DeepSpin dp; + + void SetUp() override { + // The .pt2 spin model requires the BUILD_PT_EXPT guard from the header. + // If AOTInductor headers are missing, skip. + std::string model_path = "../../tests/infer/deeppot_dpa_spin.pt2"; + { + std::ifstream f(model_path); + if (!f.good()) { + GTEST_SKIP() << "Skipping: " << model_path << " not found."; + } + } +#ifndef BUILD_PYTORCH + GTEST_SKIP() << "Skip because PyTorch support is not enabled."; +#endif + dp.init(model_path); + + // PBC reference values from gen_spin.py + expected_e = { + 7.020322773655288e-03, 1.099636038493644e-01, 1.093176595258250e-01, + 4.865300228001564e-02, 1.096547558413134e-01, 1.099754340356070e-01, + }; + expected_f = { + 2.980086586841411e-02, 2.670602118823960e-03, -6.205408022135627e-03, + -7.946653268248605e-03, 4.217792180550986e-03, 1.822080579891798e-03, + -3.416928812442276e-03, -6.992749479424899e-03, 4.728288289346775e-03, + 5.049869641953204e-03, 1.550913149717830e-02, 1.801899070929784e-02, + -1.411871008097311e-02, -8.283139367982638e-03, -7.058623315726573e-03, + -9.368443348703582e-03, -7.121636949145627e-03, -1.130532824067422e-02, + }; + expected_fm = { + -1.112646578617150e+00, -2.239176906831133e-01, -2.513101985142691e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.763058480695873e-02, 1.564710428447471e-02, -3.735332673990924e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + }; + expected_tot_v = { + -7.128128841285764e-02, -8.315622874246201e-03, -4.731536549332887e-03, + -1.341290773830820e-02, -1.469079587670170e-03, 1.056456080782730e-03, + -1.192681253424153e-02, 4.978114518702803e-03, -3.966123865761229e-03, + }; + expected_atom_v = { + -3.619151252102697e-03, -1.915456807909512e-03, -1.193634026800710e-03, + -1.785172610866009e-03, 4.134184812978909e-03, -6.166519257514488e-05, + -2.157590794793008e-03, -2.693885020778555e-04, 1.587443946069703e-03, + 2.056276410594536e-03, -1.371452525359132e-03, -3.000322809812329e-03, + -2.204910668541369e-03, -2.235579614907452e-04, 1.766238823282138e-03, + 1.610936535839619e-04, -2.104451576692437e-04, -1.804779669204414e-06, + -3.221463421325033e-02, -7.637227088481205e-03, 5.718223298429893e-03, + -2.919861817651578e-04, -7.691370555547318e-03, 4.007185089114448e-03, + 4.811410568749122e-04, 4.415618310181523e-03, -2.774565137306631e-03, + -5.845767386476995e-03, -9.718890770743884e-04, 2.149933205072918e-03, + -1.901140917532940e-03, 7.116473330687919e-04, -1.348184104909009e-03, + 2.280135921727739e-03, -1.281729466483908e-03, 2.259352377835986e-03, + -1.358169386053482e-02, 6.217682155741852e-03, -1.084162229060967e-02, + 1.479497817858454e-04, 3.526517898332583e-03, -5.588222116376142e-03, + -3.264343884716279e-03, 4.301831360658648e-03, -7.109217034201159e-03, + -1.807631811108731e-02, -2.637279531163817e-03, 2.435886074387011e-03, + -7.377647141388572e-03, -1.926501115012390e-03, 2.281103582246440e-03, + -9.427248486918859e-03, -1.977772025906360e-03, 2.072666761510075e-03, + }; + + natoms = expected_e.size(); + EXPECT_EQ(natoms * 3, expected_f.size()); + EXPECT_EQ(natoms * 3, expected_fm.size()); + EXPECT_EQ(9, expected_tot_v.size()); + EXPECT_EQ(natoms * 9, expected_atom_v.size()); + expected_tot_e = 0.; + for (int ii = 0; ii < natoms; ++ii) { + expected_tot_e += expected_e[ii]; + } + }; + + void TearDown() override {}; +}; + +TYPED_TEST_SUITE(TestInferDeepSpinDpaPtExpt, ValueTypes); + +TYPED_TEST(TestInferDeepSpinDpaPtExpt, test_get_use_spin) { + deepmd::DeepSpin& dp = this->dp; + std::vector use_spin = dp.get_use_spin(); + EXPECT_EQ(use_spin.size(), 3); + EXPECT_TRUE(use_spin[0]); // Ni has spin + EXPECT_FALSE(use_spin[1]); // O has no spin + EXPECT_FALSE(use_spin[2]); // H has no spin +} + +TYPED_TEST(TestInferDeepSpinDpaPtExpt, cpu_build_nlist) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial; + dp.compute(ener, force, force_mag, virial, coord, spin, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(force_mag.size(), natoms * 3); + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + EXPECT_LT(fabs(force_mag[ii] - expected_fm[ii]), EPSILON); + } + EXPECT_FALSE(virial.empty()) << "Virial should not be empty"; + EXPECT_EQ(virial.size(), 9); + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepSpinDpaPtExpt, cpu_build_nlist_atomic) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + std::vector& expected_atom_v = this->expected_atom_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial, atom_ener, atom_vir; + dp.compute(ener, force, force_mag, virial, atom_ener, atom_vir, coord, spin, + atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(force_mag.size(), natoms * 3); + EXPECT_EQ(atom_ener.size(), natoms); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + EXPECT_LT(fabs(force_mag[ii] - expected_fm[ii]), EPSILON); + } + EXPECT_FALSE(virial.empty()) << "Virial should not be empty"; + EXPECT_EQ(virial.size(), 9); + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + EXPECT_FALSE(atom_vir.empty()) << "Atomic virial should not be empty"; + EXPECT_EQ(atom_vir.size(), natoms * 9); + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_atom_v[ii]), EPSILON); + } +} + +// ============================================================================ +// NoPBC test fixture +// ============================================================================ + +template +class TestInferDeepSpinDpaPtExptNopbc : public ::testing::Test { + protected: + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector spin = {0.13, 0.02, 0.03, 0., 0., 0., 0., 0., 0., + 0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {}; + + // Reference values for NoPBC from gen_spin.py + std::vector expected_e; + std::vector expected_f; + std::vector expected_fm; + std::vector expected_tot_v; + std::vector expected_atom_v; + + int natoms; + double expected_tot_e; + + deepmd::DeepSpin dp; + + void SetUp() override { + std::string model_path = "../../tests/infer/deeppot_dpa_spin.pt2"; + { + std::ifstream f(model_path); + if (!f.good()) { + GTEST_SKIP() << "Skipping: " << model_path << " not found."; + } + } +#ifndef BUILD_PYTORCH + GTEST_SKIP() << "Skip because PyTorch support is not enabled."; +#endif + dp.init(model_path); + + // NoPBC reference values from gen_spin.py + expected_e = { + 1.298915294144196e-02, 1.095576145701290e-01, 1.083914166945241e-01, + 4.932338375417146e-02, 1.099860785812512e-01, 1.100478936528533e-01, + }; + expected_f = { + 1.300765817095240e-02, -1.593967210478553e-03, -3.196759265340465e-03, + -1.300765817095220e-02, 1.593967210478477e-03, 3.196759265340509e-03, + 9.196695370628910e-03, -9.044559760114149e-04, 5.658266727670325e-04, + 1.012744085443978e-02, 1.680427054831429e-02, 1.807036969424208e-02, + -1.133453822298158e-02, -8.941333904804914e-03, -6.627672717506913e-03, + -7.989598002087154e-03, -6.958480667497971e-03, -1.200852364950221e-02, + }; + expected_fm = { + -9.651705644713781e-01, -1.704326891282164e-01, -2.605677204117113e-01, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + -9.168034653189444e-02, 1.736913887115685e-02, -3.908906640474424e-02, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + 0.000000000000000e+00, 0.000000000000000e+00, 0.000000000000000e+00, + }; + expected_tot_v = { + -2.794677047149250e-02, 2.706780654409338e-03, -1.543674466954783e-02, + -6.040095143502448e-03, 3.055091789429398e-03, -4.313832703172029e-03, + -1.604723023732680e-02, 1.777744336229196e-03, -5.171028133240485e-04, + }; + expected_atom_v = { + 2.306155467987514e-03, -9.660921555042744e-04, -1.745198732526081e-03, + -9.476983959938838e-04, 3.970087875108642e-04, 7.171771645359225e-04, + -1.713485354054099e-03, 7.178114321056426e-04, 1.296691619287112e-03, + 7.319511578522508e-03, -3.066281877489238e-03, -5.539089843206328e-03, + -2.318373397601943e-04, 9.712104773733529e-05, 1.754444733319547e-04, + -6.521165022942005e-04, 2.731839401501884e-04, 4.934935693035620e-04, + -2.688582999243521e-02, 2.879540148031285e-03, -1.319893926076197e-03, + 2.654650285707574e-03, -3.069982439259916e-04, 1.657790217325663e-04, + -1.692335959139460e-03, 2.136332667243511e-04, -1.340912177342011e-04, + 4.337039816878673e-03, -4.892467685548227e-04, 1.504317649360216e-03, + -5.929947364721471e-04, 9.606430484810017e-04, -1.606901551850972e-03, + 1.319957388254711e-03, -1.503407517480697e-03, 2.524877227122832e-03, + -3.426950456916239e-03, 6.054673894090045e-03, -9.902054694804960e-03, + -4.978554214552650e-04, 3.694171795596980e-03, -5.898283613865629e-03, + -2.340485739136391e-03, 4.340998169957810e-03, -7.094248555283199e-03, + -1.159669688552975e-02, -1.705812586163657e-03, 1.565174877705520e-03, + -6.424359535528532e-03, -1.786854645970792e-03, 2.132951802944129e-03, + -1.096876407095736e-02, -2.264474955228100e-03, 2.396174543979845e-03, + }; + + natoms = expected_e.size(); + EXPECT_EQ(natoms * 3, expected_f.size()); + EXPECT_EQ(natoms * 3, expected_fm.size()); + EXPECT_EQ(9, expected_tot_v.size()); + EXPECT_EQ(natoms * 9, expected_atom_v.size()); + expected_tot_e = 0.; + for (int ii = 0; ii < natoms; ++ii) { + expected_tot_e += expected_e[ii]; + } + }; + + void TearDown() override {}; +}; + +TYPED_TEST_SUITE(TestInferDeepSpinDpaPtExptNopbc, ValueTypes); + +TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, cpu_build_nlist) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial; + dp.compute(ener, force, force_mag, virial, coord, spin, atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(force_mag.size(), natoms * 3); + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + EXPECT_LT(fabs(force_mag[ii] - expected_fm[ii]), EPSILON); + } + EXPECT_FALSE(virial.empty()) << "Virial should not be empty"; + EXPECT_EQ(virial.size(), 9); + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, cpu_build_nlist_atomic) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + std::vector& expected_atom_v = this->expected_atom_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial, atom_ener, atom_vir; + dp.compute(ener, force, force_mag, virial, atom_ener, atom_vir, coord, spin, + atype, box); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(force_mag.size(), natoms * 3); + EXPECT_EQ(atom_ener.size(), natoms); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + EXPECT_LT(fabs(force_mag[ii] - expected_fm[ii]), EPSILON); + } + EXPECT_FALSE(virial.empty()) << "Virial should not be empty"; + EXPECT_EQ(virial.size(), 9); + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + EXPECT_FALSE(atom_vir.empty()) << "Atomic virial should not be empty"; + EXPECT_EQ(atom_vir.size(), natoms * 9); + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_atom_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, cpu_lmp_nlist) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial; + + std::vector > nlist_data = { + {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5}, + {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}}; + std::vector ilist(natoms), numneigh(natoms); + std::vector firstneigh(natoms); + deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + dp.compute(ener, force, force_mag, virial, coord, spin, atype, box, 0, inlist, + 0); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(force_mag.size(), natoms * 3); + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + EXPECT_LT(fabs(force_mag[ii] - expected_fm[ii]), EPSILON); + } + EXPECT_FALSE(virial.empty()) << "Virial should not be empty"; + EXPECT_EQ(virial.size(), 9); + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + +TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, cpu_lmp_nlist_atomic) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_e = this->expected_e; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + std::vector& expected_atom_v = this->expected_atom_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial, atom_ener, atom_vir; + + std::vector > nlist_data = { + {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5}, + {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}}; + std::vector ilist(natoms), numneigh(natoms); + std::vector firstneigh(natoms); + deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + dp.compute(ener, force, force_mag, virial, atom_ener, atom_vir, coord, spin, + atype, box, 0, inlist, 0); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(force_mag.size(), natoms * 3); + EXPECT_EQ(atom_ener.size(), natoms); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + EXPECT_LT(fabs(force_mag[ii] - expected_fm[ii]), EPSILON); + } + EXPECT_FALSE(virial.empty()) << "Virial should not be empty"; + EXPECT_EQ(virial.size(), 9); + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } + for (int ii = 0; ii < natoms; ++ii) { + EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON); + } + EXPECT_FALSE(atom_vir.empty()) << "Atomic virial should not be empty"; + EXPECT_EQ(atom_vir.size(), natoms * 9); + for (int ii = 0; ii < natoms * 9; ++ii) { + EXPECT_LT(fabs(atom_vir[ii] - expected_atom_v[ii]), EPSILON); + } +} diff --git a/source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc b/source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc new file mode 100644 index 0000000000..e58d9c0f78 --- /dev/null +++ b/source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// Test C++ model deviation for pt_expt (.pt2) backend with DPA1 spin model. +#include + +#include +#include +#include +#include + +#include "DeepSpin.h" +#include "neighbor_list.h" +#include "test_utils.h" + +// Spin models need relaxed epsilon +#undef EPSILON +#define EPSILON (std::is_same::value ? 1e-6 : 1e-1) + +template +class TestInferDeepSpinModeDeviPtExpt : public ::testing::Test { + protected: + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector spin = {0.13, 0.02, 0.03, 0., 0., 0., 0., 0., 0., + 0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + int natoms; + + deepmd::DeepSpin dp0; + deepmd::DeepSpin dp1; + deepmd::DeepSpinModelDevi dp_md; + + void SetUp() override { + std::string model0_path = "../../tests/infer/deeppot_dpa_spin_md0.pt2"; + std::string model1_path = "../../tests/infer/deeppot_dpa_spin_md1.pt2"; + { + std::ifstream f(model0_path); + if (!f.good()) { + GTEST_SKIP() << "Skipping: " << model0_path << " not found."; + } + } + { + std::ifstream f(model1_path); + if (!f.good()) { + GTEST_SKIP() << "Skipping: " << model1_path << " not found."; + } + } +#ifndef BUILD_PYTORCH + GTEST_SKIP() << "Skip because PyTorch support is not enabled."; +#endif + dp0.init(model0_path); + dp1.init(model1_path); + dp_md.init({model0_path, model1_path}); + natoms = atype.size(); + }; + + void TearDown() override {}; +}; + +TYPED_TEST_SUITE(TestInferDeepSpinModeDeviPtExpt, ValueTypes); + +TYPED_TEST(TestInferDeepSpinModeDeviPtExpt, attrs) { + using VALUETYPE = TypeParam; + deepmd::DeepSpin& dp0 = this->dp0; + deepmd::DeepSpin& dp1 = this->dp1; + deepmd::DeepSpinModelDevi& dp_md = this->dp_md; + EXPECT_EQ(dp0.cutoff(), dp_md.cutoff()); + EXPECT_EQ(dp0.numb_types(), dp_md.numb_types()); + EXPECT_EQ(dp0.dim_fparam(), dp_md.dim_fparam()); + EXPECT_EQ(dp0.dim_aparam(), dp_md.dim_aparam()); + EXPECT_EQ(dp1.cutoff(), dp_md.cutoff()); + EXPECT_EQ(dp1.numb_types(), dp_md.numb_types()); + EXPECT_EQ(dp1.dim_fparam(), dp_md.dim_fparam()); + EXPECT_EQ(dp1.dim_aparam(), dp_md.dim_aparam()); +} + +TYPED_TEST(TestInferDeepSpinModeDeviPtExpt, cpu_build_nlist) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + int& natoms = this->natoms; + deepmd::DeepSpin& dp0 = this->dp0; + deepmd::DeepSpin& dp1 = this->dp1; + deepmd::DeepSpinModelDevi& dp_md = this->dp_md; + + int nmodel = 2; + std::vector edir(nmodel), emd; + std::vector > fdir(nmodel), fmagdir(nmodel), + vdir(nmodel), fmd(nmodel), fmmagd(nmodel), vmd; + dp0.compute(edir[0], fdir[0], fmagdir[0], vdir[0], coord, spin, atype, box); + dp1.compute(edir[1], fdir[1], fmagdir[1], vdir[1], coord, spin, atype, box); + dp_md.compute(emd, fmd, fmmagd, vmd, coord, spin, atype, box); + + EXPECT_EQ(edir.size(), emd.size()); + EXPECT_EQ(fdir.size(), fmd.size()); + EXPECT_EQ(fmagdir.size(), fmmagd.size()); + EXPECT_EQ(vdir.size(), vmd.size()); + for (int kk = 0; kk < nmodel; ++kk) { + EXPECT_EQ(fdir[kk].size(), fmd[kk].size()); + EXPECT_EQ(fmagdir[kk].size(), fmmagd[kk].size()); + EXPECT_EQ(vdir[kk].size(), vmd[kk].size()); + } + for (int kk = 0; kk < nmodel; ++kk) { + EXPECT_LT(fabs(edir[kk] - emd[kk]), EPSILON); + for (size_t ii = 0; ii < fdir[0].size(); ++ii) { + EXPECT_LT(fabs(fdir[kk][ii] - fmd[kk][ii]), EPSILON); + } + for (size_t ii = 0; ii < fmagdir[0].size(); ++ii) { + EXPECT_LT(fabs(fmagdir[kk][ii] - fmmagd[kk][ii]), EPSILON); + } + for (size_t ii = 0; ii < vdir[0].size(); ++ii) { + EXPECT_LT(fabs(vdir[kk][ii] - vmd[kk][ii]), EPSILON); + } + } +} + +TYPED_TEST(TestInferDeepSpinModeDeviPtExpt, cpu_build_nlist_atomic) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + int& natoms = this->natoms; + deepmd::DeepSpin& dp0 = this->dp0; + deepmd::DeepSpin& dp1 = this->dp1; + deepmd::DeepSpinModelDevi& dp_md = this->dp_md; + + int nmodel = 2; + std::vector edir(nmodel), emd; + std::vector > fdir(nmodel), fmagdir(nmodel), + vdir(nmodel), fmd(nmodel), fmmagd(nmodel), vmd, aedir(nmodel), aemd, + avdir(nmodel), avmd(nmodel); + dp0.compute(edir[0], fdir[0], fmagdir[0], vdir[0], aedir[0], avdir[0], coord, + spin, atype, box); + dp1.compute(edir[1], fdir[1], fmagdir[1], vdir[1], aedir[1], avdir[1], coord, + spin, atype, box); + dp_md.compute(emd, fmd, fmmagd, vmd, aemd, avmd, coord, spin, atype, box); + + EXPECT_EQ(edir.size(), emd.size()); + EXPECT_EQ(fdir.size(), fmd.size()); + EXPECT_EQ(fmagdir.size(), fmmagd.size()); + EXPECT_EQ(vdir.size(), vmd.size()); + EXPECT_EQ(aedir.size(), aemd.size()); + EXPECT_EQ(avdir.size(), avmd.size()); + for (int kk = 0; kk < nmodel; ++kk) { + EXPECT_EQ(fdir[kk].size(), fmd[kk].size()); + EXPECT_EQ(fmagdir[kk].size(), fmmagd[kk].size()); + EXPECT_EQ(vdir[kk].size(), vmd[kk].size()); + EXPECT_EQ(aedir[kk].size(), aemd[kk].size()); + EXPECT_EQ(avdir[kk].size(), avmd[kk].size()); + } + for (int kk = 0; kk < nmodel; ++kk) { + EXPECT_LT(fabs(edir[kk] - emd[kk]), EPSILON); + for (size_t ii = 0; ii < fdir[0].size(); ++ii) { + EXPECT_LT(fabs(fdir[kk][ii] - fmd[kk][ii]), EPSILON); + } + for (size_t ii = 0; ii < fmagdir[0].size(); ++ii) { + EXPECT_LT(fabs(fmagdir[kk][ii] - fmmagd[kk][ii]), EPSILON); + } + for (size_t ii = 0; ii < vdir[0].size(); ++ii) { + EXPECT_LT(fabs(vdir[kk][ii] - vmd[kk][ii]), EPSILON); + } + for (size_t ii = 0; ii < aedir[0].size(); ++ii) { + EXPECT_LT(fabs(aedir[kk][ii] - aemd[kk][ii]), EPSILON); + } + for (size_t ii = 0; ii < avdir[0].size(); ++ii) { + EXPECT_LT(fabs(avdir[kk][ii] - avmd[kk][ii]), EPSILON); + } + } +} diff --git a/source/install/test_cc_local.sh b/source/install/test_cc_local.sh index b49858daff..5ddbf0eecc 100755 --- a/source/install/test_cc_local.sh +++ b/source/install/test_cc_local.sh @@ -69,7 +69,7 @@ else: _GEN_ENV="LD_PRELOAD=${_LSAN_LIB} LSAN_OPTIONS=detect_leaks=0" fi fi - # Run gen scripts in parallel (2 groups of 3) for faster model generation. + # Run gen scripts in parallel for faster model generation. # Wait on each PID separately so any failure is caught by set -e. env ${_GEN_ENV} python ${INFER_SCRIPT_PATH}/gen_sea.py & PID1=$! @@ -90,6 +90,13 @@ else: wait $PID4 wait $PID5 wait $PID6 + + env ${_GEN_ENV} python ${INFER_SCRIPT_PATH}/gen_spin.py & + PID7=$! + env ${_GEN_ENV} python ${INFER_SCRIPT_PATH}/gen_spin_model_devi.py & + PID8=$! + wait $PID7 + wait $PID8 fi if [ "${ENABLE_PADDLE:-TRUE}" == "TRUE" ]; then PADDLE_INFERENCE_DIR=${BUILD_TMP_DIR}/paddle_inference_install_dir diff --git a/source/lmp/tests/test_lammps_spin_nopbc_pt.py b/source/lmp/tests/test_lammps_spin_nopbc_pt.py index c05a00afce..486538fe8b 100644 --- a/source/lmp/tests/test_lammps_spin_nopbc_pt.py +++ b/source/lmp/tests/test_lammps_spin_nopbc_pt.py @@ -31,24 +31,21 @@ data_type_map_file = Path(__file__).parent / "data_type_map.lmp" md_file = Path(__file__).parent / "md.out" -expected_ae = np.array( - [-2.337313880002, -2.339828637443377, -2.358478126000974, -2.358478126000974] -) -expected_e = np.sum(expected_ae) +expected_e = 3.5101080091096860e-01 expected_f = np.array( [ - [0.036908169450058, -0.0154615304452946, -0.0277072310206975], - [-0.036908169450058, 0.0154615304452946, 0.0277072310206975], - [-0.001114392839443, -0.0010410775210586, 0.0015249586223957], - [0.001114392839443, 0.0010410775210586, -0.0015249586223957], + [3.9007324220254663e-03, -1.6340906092268837e-03, 2.4784543132550553e-03], + [-3.9007324220254660e-03, 1.6340906092268837e-03, -2.4784543132550550e-03], + [1.0879565176984952e-04, 1.0163804310078055e-04, -1.4887826031663627e-04], + [-1.0879565176984952e-04, -1.0163804310078055e-04, 1.4887826031663627e-04], ] ) expected_fm = np.array( [ - [0.0075469514215227, -0.0031615607306379, 0.0204654183352036], - [-0.0074172317569893, 0.003107218709009, 0.020927678503653], - [0.0000000000000000, 0.00000000000000000, 0.00000000000000000], - [0.0000000000000000, 0.00000000000000000, 0.00000000000000000], + [3.4589594972289518e-03, -1.4490235731634794e-03, 2.3561281037953720e-03], + [-3.0400436796538990e-04, 1.2735318117469008e-04, -1.4949786028183132e-03], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], ] ) diff --git a/source/lmp/tests/test_lammps_spin_nopbc_pt2.py b/source/lmp/tests/test_lammps_spin_nopbc_pt2.py new file mode 100644 index 0000000000..56952fdc0d --- /dev/null +++ b/source/lmp/tests/test_lammps_spin_nopbc_pt2.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +from pathlib import ( + Path, +) + +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data_spin, +) + +pb_file = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa_spin.pt2" +) +pb_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa_spin_md1.pt2" +) +data_file = Path(__file__).parent / "data.lmp" +md_file = Path(__file__).parent / "md.out" + +# Reference values from the seed=1 .pt2 model (NoPBC, Model 0) +expected_e = 3.5101080091096860e-01 +expected_f = np.array( + [ + [3.9007324220254663e-03, -1.6340906092268837e-03, 2.4784543132550553e-03], + [-3.9007324220254660e-03, 1.6340906092268837e-03, -2.4784543132550550e-03], + [1.0879565176984952e-04, 1.0163804310078055e-04, -1.4887826031663627e-04], + [-1.0879565176984952e-04, -1.0163804310078055e-04, 1.4887826031663627e-04], + ] +) +expected_fm = np.array( + [ + [3.4589594972289518e-03, -1.4490235731634794e-03, 2.3561281037953720e-03], + [-3.0400436796538990e-04, 1.2735318117469008e-04, -1.4949786028183132e-03], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + ] +) + +# Reference values from the seed=2 .pt2 model (NoPBC, Model 1) +expected_f2 = np.array( + [ + [-3.0870239868329980e-02, 1.2932127512408504e-02, 2.7561357633479750e-02], + [3.0870239868329978e-02, -1.2932127512408506e-02, -2.7561357633479742e-02], + [4.5712656471395960e-04, 4.2705244861435744e-04, -6.2554161487173430e-04], + [-4.5712656471395960e-04, -4.2705244861435744e-04, 6.2554161487173430e-04], + ] +) + +expected_fm2 = np.array( + [ + [-7.2838456252868870e-03, 3.0513407349174793e-03, -1.9672009896273334e-02], + [9.7240358761389140e-03, -4.0735825967608950e-03, -1.7012161861151297e-02], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + ] +) + +box = np.array([0, 100, 0, 100, 0, 100, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +spin = np.array( + [ + [0, 0, 1.2737], + [0, 0, 1.2737], + [0, 0, 0], + [0, 0, 0], + ] +) +type_NiO = np.array([1, 1, 2, 2]) + + +def setup_module() -> None: + if os.environ.get("ENABLE_PYTORCH", "1") != "1": + pytest.skip( + "Skip test because PyTorch support is not enabled.", + ) + write_lmp_data_spin(box, coord, spin, type_NiO, data_file) + + +def teardown_module() -> None: + os.remove(data_file) + if md_file.exists(): + os.remove(md_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("f f f") + lammps.atom_style("spin") + if units == "metal": + lammps.neighbor("2.0 bin") + else: + raise ValueError("units for spin should be metal") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal": + lammps.mass("1 58") + lammps.mass("2 16") + else: + raise ValueError("units for spin should be metal") + if units == "metal": + lammps.timestep(0.0005) + else: + raise ValueError("units for spin should be metal") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps) -> None: + lammps.pair_style(f"deepspin {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(4): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_model_devi(lammps) -> None: + lammps.pair_style( + f"deepspin {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(4): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_fm = np.linalg.norm(np.std([expected_fm, expected_fm2], axis=0), axis=1) + # rel=1e-4: md.out is written with default scientific format (~6 significant digits) + assert md[4] == pytest.approx(np.max(expected_md_f), rel=1e-4) + assert md[5] == pytest.approx(np.min(expected_md_f), rel=1e-4) + assert md[6] == pytest.approx(np.mean(expected_md_f), rel=1e-4) + assert md[7] == pytest.approx(np.max(expected_md_fm), rel=1e-4) + assert md[8] == pytest.approx(np.min(expected_md_fm), rel=1e-4) + assert md[9] == pytest.approx(np.mean(expected_md_fm), rel=1e-4) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps) -> None: + relative = 1.0 + lammps.pair_style( + f"deepspin {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(4): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + norm_spin = np.linalg.norm(np.mean([expected_fm, expected_fm2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + expected_md_fm = np.linalg.norm(np.std([expected_fm, expected_fm2], axis=0), axis=1) + expected_md_fm /= norm_spin + relative + # rel=1e-4: md.out is written with default scientific format (~6 significant digits) + assert md[4] == pytest.approx(np.max(expected_md_f), rel=1e-4) + assert md[5] == pytest.approx(np.min(expected_md_f), rel=1e-4) + assert md[6] == pytest.approx(np.mean(expected_md_f), rel=1e-4) + assert md[7] == pytest.approx(np.max(expected_md_fm), rel=1e-4) + assert md[8] == pytest.approx(np.min(expected_md_fm), rel=1e-4) + assert md[9] == pytest.approx(np.mean(expected_md_fm), rel=1e-4) diff --git a/source/lmp/tests/test_lammps_spin_pt.py b/source/lmp/tests/test_lammps_spin_pt.py index 034ab1f431..8a21f27710 100644 --- a/source/lmp/tests/test_lammps_spin_pt.py +++ b/source/lmp/tests/test_lammps_spin_pt.py @@ -32,24 +32,29 @@ data_type_map_file = Path(__file__).parent / "data_type_map.lmp" md_file = Path(__file__).parent / "md.out" +expected_e = 3.5053686886040974e-01 expected_ae = np.array( - [-2.33730603846356, -2.339828637443377, -2.3584765990764933, -2.358478126000974] + [ + 6.8203336981159180e-02, + 6.4899945717305470e-02, + 1.0867432727951952e-01, + 1.0875925888242556e-01, + ] ) -expected_e = np.sum(expected_ae) expected_f = np.array( [ - [0.036819000183374, -0.0154603124989284, -0.0277136918031471], - [-0.0369115932121166, 0.0154614940830129, 0.0277067438704936], - [-0.0010240778189108, -0.0010425850123752, 0.0015323196618039], - [0.0011166708476534, 0.0010414034282908, -0.0015253717291505], + [3.9965859059739960e-03, -1.5714685255928385e-03, 2.5489246986054267e-03], + [-3.2538293505478910e-03, 1.5372892024705638e-03, -2.6183962915675120e-03], + [-5.8572951970559290e-04, 1.2789559354576532e-04, -8.4978782279560860e-05], + [-1.5702703572051152e-04, -9.3716270423490900e-05, 1.5445037524164582e-04], ] ) expected_fm = np.array( [ - [0.007540380021158, -0.0031615447712641, 0.0204706018052022], - [-0.0074177167392878, 0.0031072528813168, 0.0209277147341756], - [0.0000000000000000, 0.00000000000000000, 0.00000000000000000], - [0.0000000000000000, 0.00000000000000000, 0.00000000000000000], + [3.5882891345325510e-03, -1.4332293643181068e-03, 2.1309512069449844e-03], + [-2.2686566316524877e-04, 1.1563585411337640e-04, -1.5585998391414227e-03], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], ] ) @@ -73,42 +78,42 @@ expected_v = -np.array( [ - 0.0138536891649799, - -0.0057815832940349, - -0.0104366273910430, - -0.0057802135977019, - 0.0024216972469495, - 0.0043747666241247, - -0.0120159787305366, - 0.0050342035124280, - 0.0090942101965059, - 0.0135151396517160, - -0.0056617476919350, - -0.0102276732499471, - -0.0056606594176084, - 0.0023713573235927, - 0.0042837422619739, - -0.0084858208754591, - 0.0035548709072868, - 0.0064217022841311, - 0.0007099617850315, - 0.0003917168967788, - -0.0005467867622337, - 0.0003906286224523, - 0.0003696501943719, - -0.0005419287758774, - -0.0005551067425154, - -0.0005416915274450, - 0.0007957607021995, - 0.0004252005652282, - 0.0003972268438316, - -0.0005818534050492, - 0.0003958571474987, - 0.0003698139141107, - -0.0005416992544720, - -0.0005797982376440, - -0.0005416536167464, - 0.0007934081146707, + -6.5442162870715910e-04, + 3.3133477241365930e-04, + 3.9883281200020510e-04, + 3.5002780951132103e-04, + -1.1589573685630256e-04, + -2.2265549504892014e-04, + 3.4055402098896265e-04, + -1.9139376762381335e-04, + -3.3133994148957030e-04, + 3.1589094549377213e-03, + -1.3741877857714818e-03, + -2.5227682918535005e-03, + -1.4131573474734480e-03, + 6.0225037103514220e-04, + 1.0902107277722482e-03, + 1.4702420402848647e-03, + -6.1977156785934680e-04, + -1.1144364605899786e-03, + -2.1325917337790670e-03, + 7.0758382747778870e-05, + -1.2848658004242357e-05, + 1.0290158078156647e-04, + -2.7842261103620382e-05, + 3.2953845020209020e-05, + 2.2234591227388966e-04, + 2.2930239123112260e-05, + -3.1997128170916170e-05, + -8.1674704416169310e-04, + -1.1094383565005458e-04, + 1.6599006542109280e-04, + -1.2281050907954750e-04, + -5.0200238466137690e-05, + 7.2193112591762430e-05, + 2.7750902949636390e-04, + 8.2718432747583720e-05, + -1.1997871423457964e-04, ] ).reshape(4, 9) diff --git a/source/lmp/tests/test_lammps_spin_pt2.py b/source/lmp/tests/test_lammps_spin_pt2.py new file mode 100644 index 0000000000..6a6dc50933 --- /dev/null +++ b/source/lmp/tests/test_lammps_spin_pt2.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +from pathlib import ( + Path, +) + +import constants +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data_spin, +) + +pb_file = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa_spin.pt2" +) +pb_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa_spin_md1.pt2" +) +data_file = Path(__file__).parent / "data.lmp" +md_file = Path(__file__).parent / "md.out" + +# Reference values from the seed=1 .pt2 model (PBC, Model 0) +expected_e = 3.5053686886040974e-01 +expected_ae = np.array( + [ + 6.8203336981159180e-02, + 6.4899945717305470e-02, + 1.0867432727951952e-01, + 1.0875925888242556e-01, + ] +) +expected_f = np.array( + [ + [3.9965859059739960e-03, -1.5714685255928385e-03, 2.5489246986054267e-03], + [-3.2538293505478910e-03, 1.5372892024705638e-03, -2.6183962915675120e-03], + [-5.8572951970559290e-04, 1.2789559354576532e-04, -8.4978782279560860e-05], + [-1.5702703572051152e-04, -9.3716270423490900e-05, 1.5445037524164582e-04], + ] +) +expected_fm = np.array( + [ + [3.5882891345325510e-03, -1.4332293643181068e-03, 2.1309512069449844e-03], + [-2.2686566316524877e-04, 1.1563585411337640e-04, -1.5585998391414227e-03], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + ] +) + +# Reference values from the seed=2 .pt2 model (PBC, Model 1) +expected_f2 = np.array( + [ + [-3.0262814573947916e-02, 1.3107838688513508e-02, 2.7456572609621283e-02], + [3.2096308278552410e-02, -1.3110841000027918e-02, -2.7788457969008994e-02], + [-8.0682873449542990e-04, 5.2993401057113383e-04, -4.9511667664325060e-04], + [-1.0266649701090702e-03, -5.2693779419693080e-04, 8.2700203603093960e-04], + ] +) + +expected_fm2 = np.array( + [ + [-6.7580350367321040e-03, 3.1210297243763260e-03, -2.0248402789380206e-02], + [1.0009759711324267e-02, -4.1032469729948700e-03, -1.7144003722544775e-02], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + [0.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + ] +) + +expected_v = -np.array( + [ + -6.5442162870715910e-04, + 3.3133477241365930e-04, + 3.9883281200020510e-04, + 3.5002780951132103e-04, + -1.1589573685630256e-04, + -2.2265549504892014e-04, + 3.4055402098896265e-04, + -1.9139376762381335e-04, + -3.3133994148957030e-04, + 3.1589094549377213e-03, + -1.3741877857714818e-03, + -2.5227682918535005e-03, + -1.4131573474734480e-03, + 6.0225037103514220e-04, + 1.0902107277722482e-03, + 1.4702420402848647e-03, + -6.1977156785934680e-04, + -1.1144364605899786e-03, + -2.1325917337790670e-03, + 7.0758382747778870e-05, + -1.2848658004242357e-05, + 1.0290158078156647e-04, + -2.7842261103620382e-05, + 3.2953845020209020e-05, + 2.2234591227388966e-04, + 2.2930239123112260e-05, + -3.1997128170916170e-05, + -8.1674704416169310e-04, + -1.1094383565005458e-04, + 1.6599006542109280e-04, + -1.2281050907954750e-04, + -5.0200238466137690e-05, + 7.2193112591762430e-05, + 2.7750902949636390e-04, + 8.2718432747583720e-05, + -1.1997871423457964e-04, + ] +).reshape(4, 9) + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +spin = np.array( + [ + [0, 0, 1.2737], + [0, 0, 1.2737], + [0, 0, 0], + [0, 0, 0], + ] +) +type_NiO = np.array([1, 1, 2, 2]) + + +def setup_module() -> None: + if os.environ.get("ENABLE_PYTORCH", "1") != "1": + pytest.skip( + "Skip test because PyTorch support is not enabled.", + ) + write_lmp_data_spin(box, coord, spin, type_NiO, data_file) + + +def teardown_module() -> None: + os.remove(data_file) + if md_file.exists(): + os.remove(md_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("p p p") + lammps.atom_style("spin") + if units == "metal": + lammps.neighbor("2.0 bin") + else: + raise ValueError("units for spin should be metal") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal": + lammps.mass("1 58") + lammps.mass("2 16") + else: + raise ValueError("units for spin should be metal") + if units == "metal": + lammps.timestep(0.0005) + else: + raise ValueError("units for spin should be metal") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps) -> None: + lammps.pair_style(f"deepspin {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(4): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_virial(lammps) -> None: + lammps.pair_style(f"deepspin {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.compute("peatom all pe/atom pair") + lammps.compute("pressure all pressure NULL pair") + lammps.compute("virial all centroid/stress/atom NULL pair") + lammps.variable("eatom atom c_peatom") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"pressure{jj} equal c_pressure[{ii + 1}]") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii + 1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(4): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id")[: coord.shape[0]] - 1 + assert np.array(lammps.variables["eatom"].value) == pytest.approx( + expected_ae[idx_map] + ) + vol = box[1] * box[3] * box[5] + for ii in range(6): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + assert np.array( + lammps.variables[f"pressure{jj}"].value + ) / constants.nktv2p == pytest.approx( + -expected_v[idx_map, jj].sum(axis=0) / vol + ) + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + assert np.array( + lammps.variables[f"virial{jj}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, jj]) + + +def test_pair_deepmd_model_devi(lammps) -> None: + lammps.pair_style( + f"deepspin {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(4): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_fm = np.linalg.norm(np.std([expected_fm, expected_fm2], axis=0), axis=1) + # rel=1e-4: md.out is written with default scientific format (~6 significant digits) + assert md[4] == pytest.approx(np.max(expected_md_f), rel=1e-4) + assert md[5] == pytest.approx(np.min(expected_md_f), rel=1e-4) + assert md[6] == pytest.approx(np.mean(expected_md_f), rel=1e-4) + assert md[7] == pytest.approx(np.max(expected_md_fm), rel=1e-4) + assert md[8] == pytest.approx(np.min(expected_md_fm), rel=1e-4) + assert md[9] == pytest.approx(np.mean(expected_md_fm), rel=1e-4) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps) -> None: + relative = 1.0 + lammps.pair_style( + f"deepspin {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(4): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + norm_spin = np.linalg.norm(np.mean([expected_fm, expected_fm2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + expected_md_fm = np.linalg.norm(np.std([expected_fm, expected_fm2], axis=0), axis=1) + expected_md_fm /= norm_spin + relative + # rel=1e-4: md.out is written with default scientific format (~6 significant digits) + assert md[4] == pytest.approx(np.max(expected_md_f), rel=1e-4) + assert md[5] == pytest.approx(np.min(expected_md_f), rel=1e-4) + assert md[6] == pytest.approx(np.mean(expected_md_f), rel=1e-4) + assert md[7] == pytest.approx(np.max(expected_md_fm), rel=1e-4) + assert md[8] == pytest.approx(np.min(expected_md_fm), rel=1e-4) + assert md[9] == pytest.approx(np.mean(expected_md_fm), rel=1e-4) diff --git a/source/tests/infer/deeppot_dpa_spin.pth b/source/tests/infer/deeppot_dpa_spin.pth deleted file mode 100644 index 4b11aaf61b..0000000000 Binary files a/source/tests/infer/deeppot_dpa_spin.pth and /dev/null differ diff --git a/source/tests/infer/deeppot_dpa_spin.yaml b/source/tests/infer/deeppot_dpa_spin.yaml new file mode 100644 index 0000000000..b3088ef8b7 --- /dev/null +++ b/source/tests/infer/deeppot_dpa_spin.yaml @@ -0,0 +1,2921 @@ +backend: dpmodel +model: + backbone_model: + "@class": Model + "@variables": + out_bias: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + out_std: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + "@version": 2 + atom_exclude_types: &id002 + - 3 + - 4 + - 5 + descriptor: + "@class": Descriptor + "@variables": + davg: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + dstd: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 2 + activation_function: tanh + attention_layers: + "@class": NeighborGatedAttention + "@version": 1 + attention_layers: + - attention_layer: + bias: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + in_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.0773000670322781 + - -0.11285163747198214 + - 0.09198909919347824 + - 0.0015564608025799428 + - 0.15993211721468997 + - -0.035699099756999086 + - 0.18711493296570436 + - -0.3680327413169358 + - 0.3146711889101303 + - -0.32196784941870205 + - 0.33080166106245973 + - -0.12427897663338351 + - 0.1971349848817013 + - -0.25328479452025787 + - 0.05619010862452457 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.5453136153500983 + - 0.11646421136015596 + - 0.2728998160990953 + - 0.19211628323790922 + - 0.5121592586576245 + - 0.14015660310197084 + - -0.21283653571712532 + - 0.19270563868027393 + - -0.2869005655415182 + - -0.2805799563301185 + - -0.14688711556497824 + - -0.004936387806237014 + - 0.26891251777915626 + - 0.15015452892352035 + - -0.07209428351171145 + - - -0.2060930999925437 + - -0.11668702412699229 + - 0.03240261480877325 + - -0.12456108006262341 + - 0.28163990564667885 + - -0.32450675882278157 + - 0.14710907975796425 + - 0.0671544495392095 + - -0.007634575731316958 + - -0.37838289250073254 + - 0.13263368708165285 + - -0.1273788347128481 + - 0.2354650725899286 + - -0.26961937359743937 + - -0.1435463708204026 + - - 0.2971343040732761 + - 0.07312988992907121 + - -0.1747043346951614 + - 0.08715563315955775 + - 0.13322347767678513 + - -0.3891940345628173 + - 0.1166655194884039 + - 0.21413329963118571 + - 0.05413909843388683 + - -0.045229184330072024 + - 0.11979871279396558 + - -0.12072731953778283 + - -0.08723066293101692 + - 0.18005965452317124 + - -0.21828380200177724 + - - 0.31671405071433606 + - -0.2716965170499195 + - -0.028361254616657144 + - -0.2617245750818101 + - -0.09699398154549994 + - 0.18481031199205442 + - -0.040465631029007854 + - 0.14108925561333863 + - 0.09429489296773111 + - 0.004655528693258085 + - 0.14898310008646545 + - -0.0857640356727124 + - 0.08773909782891427 + - -0.04149990595119324 + - 0.20601404664335912 + - - 0.21391505516854736 + - -0.09852256414914695 + - -0.1362372794021844 + - 0.22405812208425802 + - 0.049318368740786864 + - 0.009037517617669527 + - 0.048770806587957634 + - 0.20453428339492002 + - -0.06278337870025578 + - 0.13590408268075893 + - 0.16733094069618198 + - -0.2558069441971945 + - 0.2539692845486608 + - -0.3822830942851515 + - -0.01320077168697971 + - - -0.23618346749100957 + - -0.16088435397543652 + - 0.0012095524450221038 + - 0.3434566733111669 + - -0.10412101115096369 + - -0.41354077587368426 + - -0.15301108452052156 + - -0.19472850640297268 + - -0.04784752055915662 + - 0.309354831872237 + - 0.03900287287097172 + - 0.3515719847777949 + - 0.07311373265713761 + - 0.21008558400100064 + - -0.1422281223027184 + - - -0.2501972762466357 + - -0.3947710066769227 + - -0.22159627517439104 + - -0.23717546465166323 + - 0.20361999068539405 + - 0.15834318996454627 + - -0.30369339295139164 + - -0.04594643602162514 + - -0.47003800284266484 + - 0.251313081682966 + - -0.017108155336677568 + - 0.015343488107173805 + - 0.4568124119251754 + - 0.36168954038818313 + - 0.15217782345376488 + - - -0.2661915245836493 + - 0.02323691834885091 + - -0.08549345931548089 + - -0.01576948173458242 + - -0.28637250692974264 + - -0.21920098175766897 + - 0.12766631250579263 + - 0.11029243899731755 + - -0.2570174719864179 + - -0.15426636992711557 + - -0.021417667228726223 + - -0.2669338979262442 + - 0.028601378532179773 + - -0.09310383237144683 + - 0.08623985160238709 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + nnei: 30 + normalize: true + num_heads: 1 + out_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.012014674155435012 + - 0.3077665676691103 + - 0.0604726509386909 + - 0.1842684839388698 + - 0.6156330880385837 + - -0.32920196777558125 + - 0.25783671659080115 + - 0.33358312773786414 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.49949533257566686 + - 0.286547598925902 + - 0.4361493766595648 + - -0.1103280544393254 + - -0.5487598333772651 + - 0.24837784934669413 + - 0.21787651038518507 + - -0.3913940989346758 + - - -0.06463305179286359 + - -0.18513109011840434 + - 0.14079395290035332 + - -0.3234597846998149 + - 0.02730837414103255 + - 0.045234328486521896 + - -0.3465555485376244 + - 0.20186257177624717 + - - 0.4567961848887919 + - -0.521634029517498 + - -0.05827057440920121 + - 0.1202454889788886 + - -0.4149197758780397 + - 0.008473645410200539 + - 0.7860728226637158 + - -0.18512102597129287 + - - 0.02591992461101441 + - -0.19205447789337096 + - -0.37336035406244 + - -0.4396988201616634 + - -0.27280012147716604 + - 0.17196769473503715 + - -0.40497154915115735 + - -0.041913386087695355 + - - 0.15915541120069288 + - -0.1380975416704592 + - -0.218734837591988 + - 0.09092626183841235 + - -0.5174668044557068 + - -0.21180678070629808 + - -0.37390035436391533 + - 0.04945675263633716 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + precision: float64 + scaling_factor: 1.0 + smooth: true + temperature: 1.0 + attn_layer_norm: + "@class": LayerNorm + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 1 + eps: 1.0e-05 + precision: float64 + trainable: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + - attention_layer: + bias: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + in_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.10962922619214345 + - 0.31595207859196284 + - 0.1823135348000062 + - -0.47792369353249897 + - 0.4364181173781324 + - -0.3836680647022722 + - -0.23463686987556562 + - 0.15404135306221578 + - -0.034120822158277005 + - 0.3183953470379359 + - -0.21392442825478622 + - 0.012841805580442784 + - -0.14095292079989144 + - -0.25141701907332686 + - 0.1621740147854236 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.04432651454819303 + - 0.07074892196496135 + - -0.19255895137584625 + - -0.19981566952132743 + - 0.10148299481604649 + - 0.2160756004145394 + - 0.16060596072069733 + - 0.17111513606428155 + - 0.4109950675037482 + - 0.26635356350735206 + - -0.18156367732081702 + - 0.0895314214571873 + - 0.057816100603675764 + - 0.016503413619334724 + - 0.08760411333791569 + - - -0.13766990184830316 + - 0.06014881349046379 + - -0.21980672183305355 + - -0.30839871665039914 + - 0.20953147387373508 + - -0.19082927049353332 + - -0.23181550452962923 + - -0.08936046171408488 + - -0.078280225924176 + - -0.1325851587956408 + - -0.25378443361135683 + - 0.2948729948552392 + - -0.09325140297939892 + - -0.14574024680698858 + - -0.22099090633027849 + - - 0.045429709573574166 + - -0.1263395702542033 + - -0.1300203926721312 + - 0.17373924393475976 + - 0.24047252246649992 + - 0.2004896569348133 + - 0.20631458219195267 + - 0.37839714562334414 + - 0.11135669898146744 + - -0.11463405996357577 + - -0.03808978768230989 + - 0.20798353705347675 + - -0.14121267899214504 + - 0.13967453150311643 + - 0.10088049506078772 + - - 0.11892678366002153 + - 0.2864895972170223 + - 0.29319963315715897 + - 0.06058700578532366 + - 0.004055425018954525 + - -0.14393427105036072 + - 0.16887631470301387 + - 0.28358557330493 + - -0.2595140635431548 + - -0.02930562910993671 + - 0.26503383245499845 + - 0.41983297531600466 + - -0.1124681545867535 + - -0.3398171732971656 + - -0.2741135582533697 + - - 0.13451906128088217 + - 0.010888133480175035 + - 0.0900078518183752 + - -0.06409830878187231 + - -0.10502100821502645 + - 0.023079499499871807 + - 0.07414436589853633 + - 0.2629215171254904 + - 0.04642960150718239 + - 0.1988755806789641 + - -0.07629100170815732 + - 0.09921524655438335 + - -0.2289820670174585 + - 0.2549805511788397 + - -0.45039023498387787 + - - -0.24464144170777996 + - 0.1967455257626083 + - 0.06764093107202199 + - 0.17382272995208178 + - 0.027916207219719564 + - 0.11966027564666307 + - 0.21745098879410188 + - 0.05692010383329585 + - 0.05971014511869595 + - 0.2977723229828997 + - 0.24776831698997398 + - -0.1902685339637046 + - -0.004666749646972437 + - 0.14935028239849343 + - 0.11645608666956703 + - - -0.4230216490733608 + - -0.05426691735793483 + - -0.023390428143421422 + - -0.5907755934689719 + - 0.37172730079511573 + - -0.21505335264933229 + - 0.17965771743929285 + - -0.035102260175337095 + - -0.0541735937355792 + - -0.47337366894672506 + - -0.10268872287553306 + - -0.15453291293456548 + - -0.04828618892238899 + - 0.14269963343173614 + - 0.07558848779475102 + - - -0.2613291714419873 + - 0.2311123903147909 + - 0.08476759425156882 + - 0.08210712764205921 + - -0.08585522877843707 + - -0.006389859842742848 + - -0.08540677949494811 + - -0.09499113487917653 + - 0.09788200135453207 + - 0.014852707999100719 + - -0.44929890038313824 + - 0.1175298080148149 + - -0.04776284449064017 + - -0.1359122467400183 + - 0.049414904555106914 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + nnei: 30 + normalize: true + num_heads: 1 + out_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.2565939634618157 + - 0.004620349221562676 + - 0.1883062163815952 + - -0.2239550357877486 + - 0.08983838195638551 + - 0.4820038902199821 + - -0.07270401390483885 + - -0.00404319693648555 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.16071798198404572 + - 0.10051996815693054 + - -0.5288655173097807 + - 0.130257474762195 + - 0.05581257096931307 + - 0.04277739706439883 + - -0.21694408078379357 + - 0.6915428323422272 + - - -0.2476547960349563 + - -0.03666478233510395 + - -0.06331623928804718 + - 0.24746908133520484 + - 0.2397629738591262 + - 0.4487859941811591 + - 0.15822993154370696 + - 0.07930115473978999 + - - 0.3274517903181917 + - 0.2301993971941225 + - -0.32163844585913093 + - 0.29445028785203947 + - -0.04638118953081447 + - -0.1381120700898825 + - 0.22376538676031252 + - -0.022521968311180376 + - - 0.04317334481920565 + - -0.1428718603586542 + - 0.3017126550044615 + - 0.4738575664985006 + - -0.27823139510825246 + - 0.007480026084941853 + - 0.038555120292235824 + - 0.5244825249079282 + - - 0.020736666414089052 + - 0.13216568109786306 + - -0.15650475768196248 + - 0.2089877109201956 + - 0.09322183733960938 + - -0.1896625990112233 + - 0.1445804222738735 + - 0.39042411800724985 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + precision: float64 + scaling_factor: 1.0 + smooth: true + temperature: 1.0 + attn_layer_norm: + "@class": LayerNorm + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 1 + eps: 1.0e-05 + precision: float64 + trainable: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + layer_num: 2 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + attn: 5 + attn_dotr: true + attn_layer: 2 + attn_mask: false + axis_neuron: 4 + concat_output_tebd: true + embeddings: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: embedding_network + networks: + - "@class": EmbeddingNetwork + "@version": 2 + activation_function: tanh + bias: true + in_dim: 9 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.17372268332923477 + - 0.1871378134563067 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.18591758215712717 + - 0.27813472999079786 + - - 0.5413037502928204 + - -0.25070102149188617 + - - -0.6360424547704534 + - 0.21677859619201034 + - - -0.2596212453230471 + - -0.29375077070892397 + - - -0.04877331526782445 + - 0.19762680522737003 + - - 0.11098846129818017 + - -0.055797782097484434 + - - -0.006391287979703734 + - 0.41521313707463775 + - - 0.32771579096607567 + - -0.11341705155929026 + - - -0.1738235817859108 + - 0.05042848544602488 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.21199227364358092 + - 0.26084600040864775 + - -0.2672977629563241 + - -0.22294488782038938 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.2715638127686675 + - -0.39555091561915756 + - -0.49463070013277755 + - 0.2580401519962647 + - - -0.7238582261018045 + - -0.37337716297850404 + - -0.048321156004129756 + - 0.6516314685406116 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.04236615203535423 + - -0.39313642621575284 + - -0.14982659771560616 + - -0.05366082048750501 + - 0.1290929957744145 + - 0.32893913896132687 + - 0.4513568730490151 + - 0.10592804387643008 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.12373698637791548 + - -0.011845682303013905 + - -0.37854504841542985 + - -0.7733963434131446 + - 0.06982795430877005 + - -0.2911322653453856 + - -0.15962141672258884 + - 0.018716446641289086 + - - -0.45643928308288767 + - -0.4074922621785349 + - -0.10585596487381858 + - -0.4140870301183928 + - 0.15056708929600188 + - -0.24913213047723315 + - 0.1472722234467295 + - 0.207039979486021 + - - -0.1546276045097144 + - 0.45975631498836944 + - -0.2713617539469805 + - -0.4556419819708189 + - 0.46278476734589424 + - -0.16024061348091864 + - 0.27456475454336476 + - 0.41313594391240294 + - - -0.14985175126180378 + - 0.009489070195105672 + - 0.03606979879882141 + - 0.02529752572854437 + - 0.0772136131584972 + - -0.5502279296867952 + - -0.2340975341607699 + - 0.08633307850417127 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + neuron: + - 2 + - 4 + - 8 + precision: float64 + resnet_dt: false + ntypes: 6 + env_mat: + protection: 1.0e-06 + rcut: 6.0 + rcut_smth: 2.0 + use_exp_switch: false + env_protection: 1.0e-06 + exclude_types: &id003 + - - 4 + - 0 + - - 4 + - 1 + - - 4 + - 2 + - - 4 + - 3 + - - 4 + - 4 + - - 4 + - 5 + - - 5 + - 0 + - - 5 + - 1 + - - 5 + - 2 + - - 5 + - 3 + - - 5 + - 4 + - - 5 + - 5 + ln_eps: 1.0e-05 + neuron: + - 2 + - 4 + - 8 + normalize: true + ntypes: 6 + precision: float64 + rcut: 6.0 + rcut_smth: 2.0 + resnet_dt: false + scaling_factor: 1.0 + sel: + - 30 + set_davg_zero: false + smooth_type_embedding: true + spin: null + tebd_dim: 8 + tebd_input_mode: concat + temperature: 1.0 + trainable: true + trainable_ln: true + type: dpa1 + type_embedding: + "@class": TypeEmbedNet + "@version": 2 + activation_function: Linear + embedding: + "@class": EmbeddingNetwork + "@version": 2 + activation_function: Linear + bias: false + in_dim: 6 + layers: + - "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.2927024092533758 + - -0.014162808051698143 + - 0.5337552806512624 + - 0.13124835611145563 + - -0.08736911906845456 + - -0.4469387822854708 + - -0.24092323623711465 + - -0.07236303858793405 + - - -0.029810578815809525 + - 0.13448218535369424 + - 0.3857245197302343 + - -0.1093625950883868 + - 0.18143197800870767 + - 0.1378970486893381 + - -0.4620301342837659 + - 0.11080395823154059 + - - -0.5453628627663579 + - -0.2515028548605599 + - 0.4755067239838334 + - -0.11492162761040413 + - 0.20288201321554736 + - 0.15184955393967264 + - 0.2549514898804775 + - 0.11458108875975528 + - - 0.4444838344094473 + - 0.0866527130137522 + - -0.05811512799228326 + - 0.2784336797539196 + - -0.16902635778153863 + - -0.10090527455466318 + - 0.09257632035486411 + - -0.08804146618056256 + - - -0.5361517594910684 + - -0.5304567668606835 + - 0.013353255324733446 + - 0.09254196348234821 + - 0.020815206464360567 + - 0.2831940642833442 + - -0.0032802268194216514 + - 0.04518720923263511 + - - -0.19075762886775724 + - 0.1670238834313649 + - 0.31687760687930183 + - -0.023313802928756507 + - -0.08946247104463122 + - -0.08622297317524152 + - -0.14368061835507168 + - 0.2766849049027015 + "@version": 2 + activation_function: Linear + bias: false + precision: float64 + resnet: true + trainable: true + use_timestep: false + neuron: + - 8 + precision: float64 + resnet_dt: false + neuron: + - 8 + ntypes: 6 + padding: true + precision: float64 + resnet_dt: false + trainable: true + type_map: &id001 + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + use_econf_tebd: false + use_tebd_bias: false + type_map: *id001 + type_one_side: true + use_econf_tebd: false + use_tebd_bias: false + fitting: + "@class": Fitting + "@variables": + aparam_avg: null + aparam_inv_std: null + bias_atom_e: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + case_embd: null + fparam_avg: null + fparam_inv_std: null + "@version": 4 + activation_function: tanh + atom_ener: null + default_fparam: null + dim_case_embd: 0 + dim_descrpt: 40 + dim_out: 1 + exclude_types: *id002 + layer_name: null + mixed_types: true + nets: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: fitting_network + networks: + - "@class": FittingNetwork + "@version": 1 + activation_function: tanh + bias_out: true + in_dim: 40 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.19213583785141547 + - 0.01913350809462838 + - 0.061740233764521854 + - -0.2267647790849183 + - 0.12812766437356768 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.10722944375625902 + - 0.030285954989448978 + - -0.07751213660639378 + - -0.05835862674862787 + - 0.06079658252374028 + - - -0.14917809664011292 + - -0.1917000836760898 + - -0.18217651694274126 + - -0.03027373632656806 + - -0.2626097492205049 + - - -0.01587052426328515 + - -0.10141758102785384 + - 0.2884812641234704 + - 0.04070913632457048 + - -0.22952775106142526 + - - -0.005339896387597923 + - 0.2026717424567498 + - -0.23570070258569312 + - 0.02517187837708613 + - 0.006776247990408211 + - - 0.05869629681843979 + - -0.04536473825638193 + - -0.3331387257924115 + - 0.10410925750982085 + - -0.06226317817610261 + - - -0.34375470875866937 + - -0.21165266854067688 + - -0.07367160504875508 + - -0.04451329922224857 + - -0.07669864755078881 + - - 0.0019798728717504406 + - -0.05723324006026723 + - 0.08134147353272701 + - -0.20416051853341133 + - -0.1489448413235881 + - - -0.1297286685333724 + - -0.10254001818713723 + - -0.06485973011482828 + - -0.030747888978288476 + - -0.02400913589645549 + - - -0.094789713046158 + - -0.19167746643213587 + - -0.2503250569998187 + - 0.26434536457709407 + - 0.017547558090668915 + - - -0.29183772248369166 + - -0.036914405343361524 + - 0.03792963273035814 + - 0.17868452483507827 + - 0.00045001937022070665 + - - 0.17243079119989835 + - -0.23834705710954138 + - -0.09220221952867924 + - 0.018321049065233186 + - 0.03436188061915956 + - - 0.11077420998763597 + - -0.049847014250950866 + - -0.013865251301544756 + - 0.06342369857555755 + - -0.06594080637876534 + - - 0.07959185404659529 + - 0.15264265616361125 + - 0.10610192766737493 + - 0.02332598396345151 + - 0.11973629760425605 + - - -0.004068585904498259 + - 0.022424747708734243 + - 0.0904892803921496 + - 0.05915257009060455 + - -0.16743070962990553 + - - -0.1344159533436917 + - 0.06890292754397642 + - 0.040088028773263784 + - 0.06200012741623026 + - 0.06572901068559592 + - - -0.11089129384208427 + - -0.0248775999154777 + - 0.22639251490169005 + - 0.08926907460675748 + - 0.007766677806386451 + - - 0.04732793883524216 + - 0.3566550898924153 + - 0.053342420260193917 + - 0.10532169496785464 + - 0.05480515327344935 + - - -0.14795892356277907 + - -0.033583652192236545 + - 0.09467702932906724 + - -0.003445970791527109 + - -0.02970262899767938 + - - 0.01798891613889279 + - -0.2134710507671308 + - -0.03858114609602234 + - 0.15849887117268416 + - -0.011587330107593236 + - - 0.11836949735270877 + - 0.18151152582621652 + - 0.07797861333058898 + - 0.03771976731550859 + - 0.03860339300326443 + - - -0.24233917378226116 + - 0.11523548951310575 + - 0.11274022558623038 + - -0.13891602078102852 + - 0.038197951979408194 + - - 0.14748023102987462 + - 0.03809724924008725 + - -0.24004342053841574 + - -0.11674077556127527 + - 0.015631624270000557 + - - 0.1225811593516911 + - 0.17781553932560712 + - 0.2019153965211837 + - -0.03911653241832874 + - -0.04308105521791055 + - - -0.22890975503121846 + - 0.2631927308266052 + - 0.3396422256608267 + - 0.0060877665369990334 + - -0.16396998978274016 + - - -0.09015288252228824 + - -0.1462415933363514 + - -0.15500878428971507 + - 0.07258538808048835 + - 0.06972580948759297 + - - -0.013411147005090846 + - -0.11250472915278566 + - 0.14883490023381404 + - -0.11578579710847242 + - 0.14989570157529938 + - - 0.007217452542531879 + - -0.006294431390605557 + - -0.06761263100067849 + - 0.1024456071698635 + - -0.11305751449457453 + - - 0.04974042427247295 + - -0.10506914162544219 + - 0.25137476343128823 + - -0.18546350417307744 + - 0.2169319752796324 + - - -0.36953714727489606 + - -0.010694988244169208 + - 0.10310978529691683 + - 0.40241055217283583 + - 0.130304630596996 + - - -0.1300988477379937 + - 0.13898156842052345 + - -0.29616754521862765 + - 0.05477159336390151 + - -0.07430155564400093 + - - -0.23757959434444867 + - 0.18864334118518866 + - -0.10919537505938344 + - 0.03221667898269365 + - 0.06298671553338492 + - - -0.13366418607892308 + - -0.2137613131165465 + - -0.049788669977315965 + - -0.14137793021258133 + - -0.2713394719317361 + - - 0.10569214804393999 + - 0.2510285905202212 + - -0.27513277377115314 + - 0.04690770942719883 + - -0.34809118312878895 + - - 0.05890615039488246 + - 0.24348639943784472 + - 0.2405967304273802 + - -0.0033626925137259346 + - 0.30761898255775194 + - - -0.17185651083822667 + - -0.10734343952635343 + - -0.06568054347536262 + - 0.14655475328776377 + - -0.0776926672052371 + - - 0.031163930588449826 + - 0.23226835010332045 + - -0.03809535974683978 + - 0.13765491602806237 + - 0.06396617651310808 + - - -0.08473934132044068 + - -0.2682902010461224 + - 0.09550142632352028 + - 0.09194416152248457 + - -0.022502798179882343 + - - 0.1799453451027606 + - 0.053161530662990036 + - 0.2540666599422615 + - -0.07199521485924587 + - 0.18012605784803465 + - - -0.1474306683388992 + - 0.01800791937891814 + - -0.24916194237740008 + - -0.08249392187563226 + - -0.039169320735019185 + - - -0.07061780826603485 + - -0.05039829819305093 + - -0.3158384285468677 + - 0.0026965883039741095 + - 0.19109727147593047 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.09501004658257778 + - 0.1663807327224991 + - -0.5185313341630086 + - -0.7740662908662731 + - -0.18752579321547022 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.2188085499554977 + - -0.4014642473754725 + - 0.032489550654357095 + - 0.06343911616091243 + - -0.00407617112574573 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.3141891418069332 + - 0.30132598326837057 + - -0.1868614701027005 + - -0.1853536726835805 + - -0.14904917553209618 + - - -0.4993776326714626 + - 0.2929711950476154 + - -0.3300253064210836 + - -0.4799775188835898 + - -0.12327559985245252 + - - 0.16627900477763782 + - 0.18281489789715116 + - -0.0796215789550366 + - 0.11637836794519682 + - 0.019126199990905587 + - - 0.47193798042526686 + - 0.3935489978037474 + - 0.1926588188573466 + - 0.11685532990383077 + - -0.3143759410105157 + - - 0.2619509948079511 + - 0.17134734041574828 + - 0.16467987243470003 + - -0.17768942725372738 + - 0.17196893072212313 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.1498329073500072 + - -0.10390305511196503 + - -0.7262688617464856 + - -0.14980303343140125 + - -0.3578894004618838 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.3290381873321775 + - 0.23103250534551598 + - -0.6940851206117438 + - -0.19335307745332778 + - -0.9240817753801489 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.3092290441156226 + - -0.496367611501348 + - -0.052492949379292775 + - 0.06663748312823926 + - 0.027714401468510886 + - - -0.10433141997317527 + - -0.323901631855259 + - -0.24739439873488192 + - 0.3076895568713741 + - 0.1593814472209255 + - - -0.07111829721069259 + - -0.27598680250101504 + - 0.16632764307325093 + - 0.1801382402999823 + - 0.3107523993064097 + - - -0.012140157566561928 + - 0.07469305237763302 + - 0.26428018852282276 + - -0.11500213881655802 + - -0.2731498304335624 + - - 0.29941998505510775 + - 0.39267279762211 + - 0.06586779164332648 + - 0.10010820203885952 + - -0.04143485413490972 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.044426160812178636 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.26432565710368733 + - - 0.17264367113482967 + - - -0.04729186377886323 + - - -0.08841444813809296 + - - 0.2969145415081517 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + neuron: + - 5 + - 5 + - 5 + out_dim: 1 + precision: float64 + resnet_dt: true + ntypes: 6 + neuron: + - 5 + - 5 + - 5 + ntypes: 6 + numb_aparam: 0 + numb_fparam: 0 + precision: float64 + rcond: null + resnet_dt: true + spin: null + tot_ener_zero: false + trainable: + - true + - true + - true + - true + type: ener + type_map: + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + use_aparam_as_mask: false + var_name: energy + pair_exclude_types: *id003 + preset_out_bias: null + rcond: null + type: standard + type_map: + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + spin: + use_spin: + - true + - false + - false + virtual_scale: + - 0.314 + - 0.0 + - 0.0 + type: spin_ener +model_def_script: + descriptor: + activation_function: tanh + attn: 5 + attn_dotr: true + attn_layer: 2 + attn_mask: false + axis_neuron: 4 + neuron: + - 2 + - 4 + - 8 + normalize: true + rcut: 6.0 + rcut_smth: 2.0 + scaling_factor: 1.0 + seed: 1 + sel: 30 + temperature: 1.0 + type: se_atten + type_one_side: true + fitting_net: + neuron: + - 5 + - 5 + - 5 + resnet_dt: true + seed: 1 + spin: + use_spin: + - true + - false + - false + virtual_scale: + - 0.314 + - 0.0 + - 0.0 + type_map: + - Ni + - O + - H +software: deepmd-kit +time: "2026-04-04 15:06:22.248682+00:00" +version: 3.0.0 diff --git a/source/tests/infer/deeppot_dpa_spin_md0.yaml b/source/tests/infer/deeppot_dpa_spin_md0.yaml new file mode 100644 index 0000000000..32f719ba39 --- /dev/null +++ b/source/tests/infer/deeppot_dpa_spin_md0.yaml @@ -0,0 +1,2921 @@ +backend: dpmodel +model: + backbone_model: + "@class": Model + "@variables": + out_bias: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + out_std: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + "@version": 2 + atom_exclude_types: &id002 + - 3 + - 4 + - 5 + descriptor: + "@class": Descriptor + "@variables": + davg: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + dstd: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 2 + activation_function: tanh + attention_layers: + "@class": NeighborGatedAttention + "@version": 1 + attention_layers: + - attention_layer: + bias: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + in_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.0773000670322781 + - -0.11285163747198214 + - 0.09198909919347824 + - 0.0015564608025799428 + - 0.15993211721468997 + - -0.035699099756999086 + - 0.18711493296570436 + - -0.3680327413169358 + - 0.3146711889101303 + - -0.32196784941870205 + - 0.33080166106245973 + - -0.12427897663338351 + - 0.1971349848817013 + - -0.25328479452025787 + - 0.05619010862452457 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.5453136153500983 + - 0.11646421136015596 + - 0.2728998160990953 + - 0.19211628323790922 + - 0.5121592586576245 + - 0.14015660310197084 + - -0.21283653571712532 + - 0.19270563868027393 + - -0.2869005655415182 + - -0.2805799563301185 + - -0.14688711556497824 + - -0.004936387806237014 + - 0.26891251777915626 + - 0.15015452892352035 + - -0.07209428351171145 + - - -0.2060930999925437 + - -0.11668702412699229 + - 0.03240261480877325 + - -0.12456108006262341 + - 0.28163990564667885 + - -0.32450675882278157 + - 0.14710907975796425 + - 0.0671544495392095 + - -0.007634575731316958 + - -0.37838289250073254 + - 0.13263368708165285 + - -0.1273788347128481 + - 0.2354650725899286 + - -0.26961937359743937 + - -0.1435463708204026 + - - 0.2971343040732761 + - 0.07312988992907121 + - -0.1747043346951614 + - 0.08715563315955775 + - 0.13322347767678513 + - -0.3891940345628173 + - 0.1166655194884039 + - 0.21413329963118571 + - 0.05413909843388683 + - -0.045229184330072024 + - 0.11979871279396558 + - -0.12072731953778283 + - -0.08723066293101692 + - 0.18005965452317124 + - -0.21828380200177724 + - - 0.31671405071433606 + - -0.2716965170499195 + - -0.028361254616657144 + - -0.2617245750818101 + - -0.09699398154549994 + - 0.18481031199205442 + - -0.040465631029007854 + - 0.14108925561333863 + - 0.09429489296773111 + - 0.004655528693258085 + - 0.14898310008646545 + - -0.0857640356727124 + - 0.08773909782891427 + - -0.04149990595119324 + - 0.20601404664335912 + - - 0.21391505516854736 + - -0.09852256414914695 + - -0.1362372794021844 + - 0.22405812208425802 + - 0.049318368740786864 + - 0.009037517617669527 + - 0.048770806587957634 + - 0.20453428339492002 + - -0.06278337870025578 + - 0.13590408268075893 + - 0.16733094069618198 + - -0.2558069441971945 + - 0.2539692845486608 + - -0.3822830942851515 + - -0.01320077168697971 + - - -0.23618346749100957 + - -0.16088435397543652 + - 0.0012095524450221038 + - 0.3434566733111669 + - -0.10412101115096369 + - -0.41354077587368426 + - -0.15301108452052156 + - -0.19472850640297268 + - -0.04784752055915662 + - 0.309354831872237 + - 0.03900287287097172 + - 0.3515719847777949 + - 0.07311373265713761 + - 0.21008558400100064 + - -0.1422281223027184 + - - -0.2501972762466357 + - -0.3947710066769227 + - -0.22159627517439104 + - -0.23717546465166323 + - 0.20361999068539405 + - 0.15834318996454627 + - -0.30369339295139164 + - -0.04594643602162514 + - -0.47003800284266484 + - 0.251313081682966 + - -0.017108155336677568 + - 0.015343488107173805 + - 0.4568124119251754 + - 0.36168954038818313 + - 0.15217782345376488 + - - -0.2661915245836493 + - 0.02323691834885091 + - -0.08549345931548089 + - -0.01576948173458242 + - -0.28637250692974264 + - -0.21920098175766897 + - 0.12766631250579263 + - 0.11029243899731755 + - -0.2570174719864179 + - -0.15426636992711557 + - -0.021417667228726223 + - -0.2669338979262442 + - 0.028601378532179773 + - -0.09310383237144683 + - 0.08623985160238709 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + nnei: 30 + normalize: true + num_heads: 1 + out_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.012014674155435012 + - 0.3077665676691103 + - 0.0604726509386909 + - 0.1842684839388698 + - 0.6156330880385837 + - -0.32920196777558125 + - 0.25783671659080115 + - 0.33358312773786414 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.49949533257566686 + - 0.286547598925902 + - 0.4361493766595648 + - -0.1103280544393254 + - -0.5487598333772651 + - 0.24837784934669413 + - 0.21787651038518507 + - -0.3913940989346758 + - - -0.06463305179286359 + - -0.18513109011840434 + - 0.14079395290035332 + - -0.3234597846998149 + - 0.02730837414103255 + - 0.045234328486521896 + - -0.3465555485376244 + - 0.20186257177624717 + - - 0.4567961848887919 + - -0.521634029517498 + - -0.05827057440920121 + - 0.1202454889788886 + - -0.4149197758780397 + - 0.008473645410200539 + - 0.7860728226637158 + - -0.18512102597129287 + - - 0.02591992461101441 + - -0.19205447789337096 + - -0.37336035406244 + - -0.4396988201616634 + - -0.27280012147716604 + - 0.17196769473503715 + - -0.40497154915115735 + - -0.041913386087695355 + - - 0.15915541120069288 + - -0.1380975416704592 + - -0.218734837591988 + - 0.09092626183841235 + - -0.5174668044557068 + - -0.21180678070629808 + - -0.37390035436391533 + - 0.04945675263633716 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + precision: float64 + scaling_factor: 1.0 + smooth: true + temperature: 1.0 + attn_layer_norm: + "@class": LayerNorm + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 1 + eps: 1.0e-05 + precision: float64 + trainable: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + - attention_layer: + bias: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + in_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.10962922619214345 + - 0.31595207859196284 + - 0.1823135348000062 + - -0.47792369353249897 + - 0.4364181173781324 + - -0.3836680647022722 + - -0.23463686987556562 + - 0.15404135306221578 + - -0.034120822158277005 + - 0.3183953470379359 + - -0.21392442825478622 + - 0.012841805580442784 + - -0.14095292079989144 + - -0.25141701907332686 + - 0.1621740147854236 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.04432651454819303 + - 0.07074892196496135 + - -0.19255895137584625 + - -0.19981566952132743 + - 0.10148299481604649 + - 0.2160756004145394 + - 0.16060596072069733 + - 0.17111513606428155 + - 0.4109950675037482 + - 0.26635356350735206 + - -0.18156367732081702 + - 0.0895314214571873 + - 0.057816100603675764 + - 0.016503413619334724 + - 0.08760411333791569 + - - -0.13766990184830316 + - 0.06014881349046379 + - -0.21980672183305355 + - -0.30839871665039914 + - 0.20953147387373508 + - -0.19082927049353332 + - -0.23181550452962923 + - -0.08936046171408488 + - -0.078280225924176 + - -0.1325851587956408 + - -0.25378443361135683 + - 0.2948729948552392 + - -0.09325140297939892 + - -0.14574024680698858 + - -0.22099090633027849 + - - 0.045429709573574166 + - -0.1263395702542033 + - -0.1300203926721312 + - 0.17373924393475976 + - 0.24047252246649992 + - 0.2004896569348133 + - 0.20631458219195267 + - 0.37839714562334414 + - 0.11135669898146744 + - -0.11463405996357577 + - -0.03808978768230989 + - 0.20798353705347675 + - -0.14121267899214504 + - 0.13967453150311643 + - 0.10088049506078772 + - - 0.11892678366002153 + - 0.2864895972170223 + - 0.29319963315715897 + - 0.06058700578532366 + - 0.004055425018954525 + - -0.14393427105036072 + - 0.16887631470301387 + - 0.28358557330493 + - -0.2595140635431548 + - -0.02930562910993671 + - 0.26503383245499845 + - 0.41983297531600466 + - -0.1124681545867535 + - -0.3398171732971656 + - -0.2741135582533697 + - - 0.13451906128088217 + - 0.010888133480175035 + - 0.0900078518183752 + - -0.06409830878187231 + - -0.10502100821502645 + - 0.023079499499871807 + - 0.07414436589853633 + - 0.2629215171254904 + - 0.04642960150718239 + - 0.1988755806789641 + - -0.07629100170815732 + - 0.09921524655438335 + - -0.2289820670174585 + - 0.2549805511788397 + - -0.45039023498387787 + - - -0.24464144170777996 + - 0.1967455257626083 + - 0.06764093107202199 + - 0.17382272995208178 + - 0.027916207219719564 + - 0.11966027564666307 + - 0.21745098879410188 + - 0.05692010383329585 + - 0.05971014511869595 + - 0.2977723229828997 + - 0.24776831698997398 + - -0.1902685339637046 + - -0.004666749646972437 + - 0.14935028239849343 + - 0.11645608666956703 + - - -0.4230216490733608 + - -0.05426691735793483 + - -0.023390428143421422 + - -0.5907755934689719 + - 0.37172730079511573 + - -0.21505335264933229 + - 0.17965771743929285 + - -0.035102260175337095 + - -0.0541735937355792 + - -0.47337366894672506 + - -0.10268872287553306 + - -0.15453291293456548 + - -0.04828618892238899 + - 0.14269963343173614 + - 0.07558848779475102 + - - -0.2613291714419873 + - 0.2311123903147909 + - 0.08476759425156882 + - 0.08210712764205921 + - -0.08585522877843707 + - -0.006389859842742848 + - -0.08540677949494811 + - -0.09499113487917653 + - 0.09788200135453207 + - 0.014852707999100719 + - -0.44929890038313824 + - 0.1175298080148149 + - -0.04776284449064017 + - -0.1359122467400183 + - 0.049414904555106914 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + nnei: 30 + normalize: true + num_heads: 1 + out_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.2565939634618157 + - 0.004620349221562676 + - 0.1883062163815952 + - -0.2239550357877486 + - 0.08983838195638551 + - 0.4820038902199821 + - -0.07270401390483885 + - -0.00404319693648555 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.16071798198404572 + - 0.10051996815693054 + - -0.5288655173097807 + - 0.130257474762195 + - 0.05581257096931307 + - 0.04277739706439883 + - -0.21694408078379357 + - 0.6915428323422272 + - - -0.2476547960349563 + - -0.03666478233510395 + - -0.06331623928804718 + - 0.24746908133520484 + - 0.2397629738591262 + - 0.4487859941811591 + - 0.15822993154370696 + - 0.07930115473978999 + - - 0.3274517903181917 + - 0.2301993971941225 + - -0.32163844585913093 + - 0.29445028785203947 + - -0.04638118953081447 + - -0.1381120700898825 + - 0.22376538676031252 + - -0.022521968311180376 + - - 0.04317334481920565 + - -0.1428718603586542 + - 0.3017126550044615 + - 0.4738575664985006 + - -0.27823139510825246 + - 0.007480026084941853 + - 0.038555120292235824 + - 0.5244825249079282 + - - 0.020736666414089052 + - 0.13216568109786306 + - -0.15650475768196248 + - 0.2089877109201956 + - 0.09322183733960938 + - -0.1896625990112233 + - 0.1445804222738735 + - 0.39042411800724985 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + precision: float64 + scaling_factor: 1.0 + smooth: true + temperature: 1.0 + attn_layer_norm: + "@class": LayerNorm + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 1 + eps: 1.0e-05 + precision: float64 + trainable: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + layer_num: 2 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + attn: 5 + attn_dotr: true + attn_layer: 2 + attn_mask: false + axis_neuron: 4 + concat_output_tebd: true + embeddings: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: embedding_network + networks: + - "@class": EmbeddingNetwork + "@version": 2 + activation_function: tanh + bias: true + in_dim: 9 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.17372268332923477 + - 0.1871378134563067 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.18591758215712717 + - 0.27813472999079786 + - - 0.5413037502928204 + - -0.25070102149188617 + - - -0.6360424547704534 + - 0.21677859619201034 + - - -0.2596212453230471 + - -0.29375077070892397 + - - -0.04877331526782445 + - 0.19762680522737003 + - - 0.11098846129818017 + - -0.055797782097484434 + - - -0.006391287979703734 + - 0.41521313707463775 + - - 0.32771579096607567 + - -0.11341705155929026 + - - -0.1738235817859108 + - 0.05042848544602488 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.21199227364358092 + - 0.26084600040864775 + - -0.2672977629563241 + - -0.22294488782038938 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.2715638127686675 + - -0.39555091561915756 + - -0.49463070013277755 + - 0.2580401519962647 + - - -0.7238582261018045 + - -0.37337716297850404 + - -0.048321156004129756 + - 0.6516314685406116 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.04236615203535423 + - -0.39313642621575284 + - -0.14982659771560616 + - -0.05366082048750501 + - 0.1290929957744145 + - 0.32893913896132687 + - 0.4513568730490151 + - 0.10592804387643008 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.12373698637791548 + - -0.011845682303013905 + - -0.37854504841542985 + - -0.7733963434131446 + - 0.06982795430877005 + - -0.2911322653453856 + - -0.15962141672258884 + - 0.018716446641289086 + - - -0.45643928308288767 + - -0.4074922621785349 + - -0.10585596487381858 + - -0.4140870301183928 + - 0.15056708929600188 + - -0.24913213047723315 + - 0.1472722234467295 + - 0.207039979486021 + - - -0.1546276045097144 + - 0.45975631498836944 + - -0.2713617539469805 + - -0.4556419819708189 + - 0.46278476734589424 + - -0.16024061348091864 + - 0.27456475454336476 + - 0.41313594391240294 + - - -0.14985175126180378 + - 0.009489070195105672 + - 0.03606979879882141 + - 0.02529752572854437 + - 0.0772136131584972 + - -0.5502279296867952 + - -0.2340975341607699 + - 0.08633307850417127 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + neuron: + - 2 + - 4 + - 8 + precision: float64 + resnet_dt: false + ntypes: 6 + env_mat: + protection: 1.0e-06 + rcut: 6.0 + rcut_smth: 2.0 + use_exp_switch: false + env_protection: 1.0e-06 + exclude_types: &id003 + - - 4 + - 0 + - - 4 + - 1 + - - 4 + - 2 + - - 4 + - 3 + - - 4 + - 4 + - - 4 + - 5 + - - 5 + - 0 + - - 5 + - 1 + - - 5 + - 2 + - - 5 + - 3 + - - 5 + - 4 + - - 5 + - 5 + ln_eps: 1.0e-05 + neuron: + - 2 + - 4 + - 8 + normalize: true + ntypes: 6 + precision: float64 + rcut: 6.0 + rcut_smth: 2.0 + resnet_dt: false + scaling_factor: 1.0 + sel: + - 30 + set_davg_zero: false + smooth_type_embedding: true + spin: null + tebd_dim: 8 + tebd_input_mode: concat + temperature: 1.0 + trainable: true + trainable_ln: true + type: dpa1 + type_embedding: + "@class": TypeEmbedNet + "@version": 2 + activation_function: Linear + embedding: + "@class": EmbeddingNetwork + "@version": 2 + activation_function: Linear + bias: false + in_dim: 6 + layers: + - "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.2927024092533758 + - -0.014162808051698143 + - 0.5337552806512624 + - 0.13124835611145563 + - -0.08736911906845456 + - -0.4469387822854708 + - -0.24092323623711465 + - -0.07236303858793405 + - - -0.029810578815809525 + - 0.13448218535369424 + - 0.3857245197302343 + - -0.1093625950883868 + - 0.18143197800870767 + - 0.1378970486893381 + - -0.4620301342837659 + - 0.11080395823154059 + - - -0.5453628627663579 + - -0.2515028548605599 + - 0.4755067239838334 + - -0.11492162761040413 + - 0.20288201321554736 + - 0.15184955393967264 + - 0.2549514898804775 + - 0.11458108875975528 + - - 0.4444838344094473 + - 0.0866527130137522 + - -0.05811512799228326 + - 0.2784336797539196 + - -0.16902635778153863 + - -0.10090527455466318 + - 0.09257632035486411 + - -0.08804146618056256 + - - -0.5361517594910684 + - -0.5304567668606835 + - 0.013353255324733446 + - 0.09254196348234821 + - 0.020815206464360567 + - 0.2831940642833442 + - -0.0032802268194216514 + - 0.04518720923263511 + - - -0.19075762886775724 + - 0.1670238834313649 + - 0.31687760687930183 + - -0.023313802928756507 + - -0.08946247104463122 + - -0.08622297317524152 + - -0.14368061835507168 + - 0.2766849049027015 + "@version": 2 + activation_function: Linear + bias: false + precision: float64 + resnet: true + trainable: true + use_timestep: false + neuron: + - 8 + precision: float64 + resnet_dt: false + neuron: + - 8 + ntypes: 6 + padding: true + precision: float64 + resnet_dt: false + trainable: true + type_map: &id001 + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + use_econf_tebd: false + use_tebd_bias: false + type_map: *id001 + type_one_side: true + use_econf_tebd: false + use_tebd_bias: false + fitting: + "@class": Fitting + "@variables": + aparam_avg: null + aparam_inv_std: null + bias_atom_e: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + case_embd: null + fparam_avg: null + fparam_inv_std: null + "@version": 4 + activation_function: tanh + atom_ener: null + default_fparam: null + dim_case_embd: 0 + dim_descrpt: 40 + dim_out: 1 + exclude_types: *id002 + layer_name: null + mixed_types: true + nets: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: fitting_network + networks: + - "@class": FittingNetwork + "@version": 1 + activation_function: tanh + bias_out: true + in_dim: 40 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.19213583785141547 + - 0.01913350809462838 + - 0.061740233764521854 + - -0.2267647790849183 + - 0.12812766437356768 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.10722944375625902 + - 0.030285954989448978 + - -0.07751213660639378 + - -0.05835862674862787 + - 0.06079658252374028 + - - -0.14917809664011292 + - -0.1917000836760898 + - -0.18217651694274126 + - -0.03027373632656806 + - -0.2626097492205049 + - - -0.01587052426328515 + - -0.10141758102785384 + - 0.2884812641234704 + - 0.04070913632457048 + - -0.22952775106142526 + - - -0.005339896387597923 + - 0.2026717424567498 + - -0.23570070258569312 + - 0.02517187837708613 + - 0.006776247990408211 + - - 0.05869629681843979 + - -0.04536473825638193 + - -0.3331387257924115 + - 0.10410925750982085 + - -0.06226317817610261 + - - -0.34375470875866937 + - -0.21165266854067688 + - -0.07367160504875508 + - -0.04451329922224857 + - -0.07669864755078881 + - - 0.0019798728717504406 + - -0.05723324006026723 + - 0.08134147353272701 + - -0.20416051853341133 + - -0.1489448413235881 + - - -0.1297286685333724 + - -0.10254001818713723 + - -0.06485973011482828 + - -0.030747888978288476 + - -0.02400913589645549 + - - -0.094789713046158 + - -0.19167746643213587 + - -0.2503250569998187 + - 0.26434536457709407 + - 0.017547558090668915 + - - -0.29183772248369166 + - -0.036914405343361524 + - 0.03792963273035814 + - 0.17868452483507827 + - 0.00045001937022070665 + - - 0.17243079119989835 + - -0.23834705710954138 + - -0.09220221952867924 + - 0.018321049065233186 + - 0.03436188061915956 + - - 0.11077420998763597 + - -0.049847014250950866 + - -0.013865251301544756 + - 0.06342369857555755 + - -0.06594080637876534 + - - 0.07959185404659529 + - 0.15264265616361125 + - 0.10610192766737493 + - 0.02332598396345151 + - 0.11973629760425605 + - - -0.004068585904498259 + - 0.022424747708734243 + - 0.0904892803921496 + - 0.05915257009060455 + - -0.16743070962990553 + - - -0.1344159533436917 + - 0.06890292754397642 + - 0.040088028773263784 + - 0.06200012741623026 + - 0.06572901068559592 + - - -0.11089129384208427 + - -0.0248775999154777 + - 0.22639251490169005 + - 0.08926907460675748 + - 0.007766677806386451 + - - 0.04732793883524216 + - 0.3566550898924153 + - 0.053342420260193917 + - 0.10532169496785464 + - 0.05480515327344935 + - - -0.14795892356277907 + - -0.033583652192236545 + - 0.09467702932906724 + - -0.003445970791527109 + - -0.02970262899767938 + - - 0.01798891613889279 + - -0.2134710507671308 + - -0.03858114609602234 + - 0.15849887117268416 + - -0.011587330107593236 + - - 0.11836949735270877 + - 0.18151152582621652 + - 0.07797861333058898 + - 0.03771976731550859 + - 0.03860339300326443 + - - -0.24233917378226116 + - 0.11523548951310575 + - 0.11274022558623038 + - -0.13891602078102852 + - 0.038197951979408194 + - - 0.14748023102987462 + - 0.03809724924008725 + - -0.24004342053841574 + - -0.11674077556127527 + - 0.015631624270000557 + - - 0.1225811593516911 + - 0.17781553932560712 + - 0.2019153965211837 + - -0.03911653241832874 + - -0.04308105521791055 + - - -0.22890975503121846 + - 0.2631927308266052 + - 0.3396422256608267 + - 0.0060877665369990334 + - -0.16396998978274016 + - - -0.09015288252228824 + - -0.1462415933363514 + - -0.15500878428971507 + - 0.07258538808048835 + - 0.06972580948759297 + - - -0.013411147005090846 + - -0.11250472915278566 + - 0.14883490023381404 + - -0.11578579710847242 + - 0.14989570157529938 + - - 0.007217452542531879 + - -0.006294431390605557 + - -0.06761263100067849 + - 0.1024456071698635 + - -0.11305751449457453 + - - 0.04974042427247295 + - -0.10506914162544219 + - 0.25137476343128823 + - -0.18546350417307744 + - 0.2169319752796324 + - - -0.36953714727489606 + - -0.010694988244169208 + - 0.10310978529691683 + - 0.40241055217283583 + - 0.130304630596996 + - - -0.1300988477379937 + - 0.13898156842052345 + - -0.29616754521862765 + - 0.05477159336390151 + - -0.07430155564400093 + - - -0.23757959434444867 + - 0.18864334118518866 + - -0.10919537505938344 + - 0.03221667898269365 + - 0.06298671553338492 + - - -0.13366418607892308 + - -0.2137613131165465 + - -0.049788669977315965 + - -0.14137793021258133 + - -0.2713394719317361 + - - 0.10569214804393999 + - 0.2510285905202212 + - -0.27513277377115314 + - 0.04690770942719883 + - -0.34809118312878895 + - - 0.05890615039488246 + - 0.24348639943784472 + - 0.2405967304273802 + - -0.0033626925137259346 + - 0.30761898255775194 + - - -0.17185651083822667 + - -0.10734343952635343 + - -0.06568054347536262 + - 0.14655475328776377 + - -0.0776926672052371 + - - 0.031163930588449826 + - 0.23226835010332045 + - -0.03809535974683978 + - 0.13765491602806237 + - 0.06396617651310808 + - - -0.08473934132044068 + - -0.2682902010461224 + - 0.09550142632352028 + - 0.09194416152248457 + - -0.022502798179882343 + - - 0.1799453451027606 + - 0.053161530662990036 + - 0.2540666599422615 + - -0.07199521485924587 + - 0.18012605784803465 + - - -0.1474306683388992 + - 0.01800791937891814 + - -0.24916194237740008 + - -0.08249392187563226 + - -0.039169320735019185 + - - -0.07061780826603485 + - -0.05039829819305093 + - -0.3158384285468677 + - 0.0026965883039741095 + - 0.19109727147593047 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.09501004658257778 + - 0.1663807327224991 + - -0.5185313341630086 + - -0.7740662908662731 + - -0.18752579321547022 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.2188085499554977 + - -0.4014642473754725 + - 0.032489550654357095 + - 0.06343911616091243 + - -0.00407617112574573 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.3141891418069332 + - 0.30132598326837057 + - -0.1868614701027005 + - -0.1853536726835805 + - -0.14904917553209618 + - - -0.4993776326714626 + - 0.2929711950476154 + - -0.3300253064210836 + - -0.4799775188835898 + - -0.12327559985245252 + - - 0.16627900477763782 + - 0.18281489789715116 + - -0.0796215789550366 + - 0.11637836794519682 + - 0.019126199990905587 + - - 0.47193798042526686 + - 0.3935489978037474 + - 0.1926588188573466 + - 0.11685532990383077 + - -0.3143759410105157 + - - 0.2619509948079511 + - 0.17134734041574828 + - 0.16467987243470003 + - -0.17768942725372738 + - 0.17196893072212313 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.1498329073500072 + - -0.10390305511196503 + - -0.7262688617464856 + - -0.14980303343140125 + - -0.3578894004618838 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.3290381873321775 + - 0.23103250534551598 + - -0.6940851206117438 + - -0.19335307745332778 + - -0.9240817753801489 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.3092290441156226 + - -0.496367611501348 + - -0.052492949379292775 + - 0.06663748312823926 + - 0.027714401468510886 + - - -0.10433141997317527 + - -0.323901631855259 + - -0.24739439873488192 + - 0.3076895568713741 + - 0.1593814472209255 + - - -0.07111829721069259 + - -0.27598680250101504 + - 0.16632764307325093 + - 0.1801382402999823 + - 0.3107523993064097 + - - -0.012140157566561928 + - 0.07469305237763302 + - 0.26428018852282276 + - -0.11500213881655802 + - -0.2731498304335624 + - - 0.29941998505510775 + - 0.39267279762211 + - 0.06586779164332648 + - 0.10010820203885952 + - -0.04143485413490972 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.044426160812178636 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.26432565710368733 + - - 0.17264367113482967 + - - -0.04729186377886323 + - - -0.08841444813809296 + - - 0.2969145415081517 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + neuron: + - 5 + - 5 + - 5 + out_dim: 1 + precision: float64 + resnet_dt: true + ntypes: 6 + neuron: + - 5 + - 5 + - 5 + ntypes: 6 + numb_aparam: 0 + numb_fparam: 0 + precision: float64 + rcond: null + resnet_dt: true + spin: null + tot_ener_zero: false + trainable: + - true + - true + - true + - true + type: ener + type_map: + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + use_aparam_as_mask: false + var_name: energy + pair_exclude_types: *id003 + preset_out_bias: null + rcond: null + type: standard + type_map: + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + spin: + use_spin: + - true + - false + - false + virtual_scale: + - 0.314 + - 0.0 + - 0.0 + type: spin_ener +model_def_script: + descriptor: + activation_function: tanh + attn: 5 + attn_dotr: true + attn_layer: 2 + attn_mask: false + axis_neuron: 4 + neuron: + - 2 + - 4 + - 8 + normalize: true + rcut: 6.0 + rcut_smth: 2.0 + scaling_factor: 1.0 + seed: 1 + sel: 30 + temperature: 1.0 + type: se_atten + type_one_side: true + fitting_net: + neuron: + - 5 + - 5 + - 5 + resnet_dt: true + seed: 1 + spin: + use_spin: + - true + - false + - false + virtual_scale: + - 0.314 + - 0.0 + - 0.0 + type_map: + - Ni + - O + - H +software: deepmd-kit +time: "2026-04-04 15:08:00.212056+00:00" +version: 3.0.0 diff --git a/source/tests/infer/deeppot_dpa_spin_md1.yaml b/source/tests/infer/deeppot_dpa_spin_md1.yaml new file mode 100644 index 0000000000..bebf070303 --- /dev/null +++ b/source/tests/infer/deeppot_dpa_spin_md1.yaml @@ -0,0 +1,2921 @@ +backend: dpmodel +model: + backbone_model: + "@class": Model + "@variables": + out_bias: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + out_std: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + - - 1.0 + "@version": 2 + atom_exclude_types: &id002 + - 3 + - 4 + - 5 + descriptor: + "@class": Descriptor + "@variables": + davg: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + - - 0.0 + - 0.0 + - 0.0 + - 0.0 + dstd: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + - - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 2 + activation_function: tanh + attention_layers: + "@class": NeighborGatedAttention + "@version": 1 + attention_layers: + - attention_layer: + bias: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + in_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.18351022625104238 + - -0.04563840943844596 + - 0.13260332418881654 + - -0.0691278425455087 + - -0.06945926323289434 + - 0.43604143150778013 + - -0.08368462141427374 + - -0.1387981791306445 + - 0.005993298283937215 + - -0.5151420582401797 + - 0.19472845492587332 + - 0.19416952313404776 + - -0.010708458231856462 + - -0.18754060846205406 + - 0.07338291320712505 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.24841318089090828 + - 0.1038571583323193 + - -0.2091532332904543 + - 0.4136790735317951 + - -0.27303583663189573 + - -0.12466396686270144 + - 0.20857122043870152 + - 0.2276345670041446 + - -0.24270398224650405 + - 0.007281618781037467 + - -0.07665038914769577 + - -0.18006026900601504 + - -0.08609172045185601 + - -0.34723410497230056 + - 0.07922223598844129 + - - 0.20351217415988776 + - 0.22465419151464255 + - -0.33559874499082115 + - 0.13592429249065166 + - -0.1648509449996543 + - -0.23203528984387572 + - -0.0516465438807872 + - 0.0686370523996727 + - 0.41471124141728716 + - 0.0586914485741541 + - -0.1689459931450781 + - 0.0018059434210212844 + - -0.4147053758686676 + - -0.07048107144267868 + - 0.40259566794347623 + - - -0.11391575567260738 + - 0.23863818963619923 + - 0.18250695785065604 + - 0.19281299906723257 + - 0.247621826822558 + - -0.0985542576783352 + - -0.035082160489491136 + - 0.4792148584108256 + - 0.14167746853937516 + - -0.06814603614410701 + - -0.22665968961826935 + - 0.3242244773260558 + - -0.066079527329557 + - -0.1307307931893784 + - 0.051956143640524394 + - - -0.2056090732642919 + - 0.07357885580248055 + - 0.08681029379913797 + - 0.07247821343078562 + - -0.06712483720853006 + - 0.39719167006621764 + - 0.11149451147606776 + - -0.12278081601771884 + - 0.045171486309587526 + - -0.0755253250897645 + - -0.21640855235984527 + - 0.08547559318182094 + - 0.1653564252470324 + - 0.0952574448166531 + - -0.1377746815508281 + - - -0.022519359735599515 + - 0.05382946104963937 + - 0.005065503738307829 + - -0.11001639078999409 + - 0.2462411205677168 + - -0.2095220212541597 + - 0.1858887564410953 + - -0.1867738843692884 + - 0.016117042040479493 + - 0.28249082907284107 + - 0.22436922074999183 + - -0.22791154747547826 + - 0.29699526367685036 + - 0.3058975240069586 + - -0.1263063786358465 + - - 0.228630022209369 + - 0.03600580997662669 + - -0.30041096458994776 + - 0.05215707817418025 + - -0.11590856918439549 + - 0.04613204765185727 + - -0.15013167647158984 + - 0.03175216690660104 + - 0.12894979742295995 + - -0.14577555003932247 + - -0.20045529796863254 + - 0.03682374995063141 + - 0.07380272734778177 + - 0.13093301266802376 + - -0.10845218099118197 + - - 0.16340966236541793 + - 0.07891342207116378 + - -0.24332865990588795 + - -0.14554580109435383 + - 0.12531904655031273 + - -0.03349136271276537 + - 0.2971042679187557 + - 0.1386000661991385 + - 0.20111089109063607 + - 0.3347783227216307 + - -0.09930032441066812 + - 0.09254914049778377 + - 0.1523072534711225 + - -0.24890506615160968 + - -0.28977438340158623 + - - 0.101287925304405 + - -0.2563848385489577 + - 0.03647978373709844 + - 0.329129480771459 + - 0.3005480604309731 + - 0.07018058650703372 + - -0.08448764525029563 + - 0.2670352098254057 + - -0.039653866987632895 + - -0.06155552979117676 + - 0.0668749042632758 + - -0.1316809083091721 + - -0.21948380976866877 + - 0.06412295486182476 + - 0.1618133144682503 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + nnei: 30 + normalize: true + num_heads: 1 + out_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.20045376999128472 + - 0.2987805408490743 + - 0.24721299447247055 + - 0.09900927157798681 + - -0.0004981500361246952 + - -0.22867694046198447 + - 0.30251015671967146 + - 0.3634748074255848 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.32306219736034275 + - 0.19545694221853752 + - 0.27273162436938836 + - -0.4955183264027905 + - 0.2654826760807126 + - -0.2674345947352222 + - 0.2973724985053365 + - 0.2928827308944446 + - - 0.22558964338136114 + - 0.09428873844711233 + - 0.05575617130102639 + - 0.3111520828678394 + - 0.18699860063303472 + - 0.000685428165526647 + - -0.6174343148136284 + - -0.5491849734438004 + - - -0.0240211843189102 + - 0.08977892740110542 + - 0.12310787770210622 + - 0.12534700620727446 + - 0.5211921965444312 + - -0.18388986716415426 + - -0.49576670250576726 + - 0.2945145753065993 + - - 0.14386152783626194 + - -0.3834460473225381 + - 0.02287490748002227 + - -0.41445745412513074 + - -0.07366625228908175 + - -0.24473690111079943 + - 0.02841491531440257 + - -0.010521319178505545 + - - -0.43664598371558466 + - 0.18723071901975422 + - 0.1339203457643558 + - -0.2622484428593279 + - 0.2984576814138297 + - -0.1217610036415616 + - 0.29227045340860613 + - 0.12766039268491938 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + precision: float64 + scaling_factor: 1.0 + smooth: true + temperature: 1.0 + attn_layer_norm: + "@class": LayerNorm + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 1 + eps: 1.0e-05 + precision: float64 + trainable: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + - attention_layer: + bias: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + in_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.46320425201766247 + - -0.0861544352136468 + - 0.2393678931194865 + - 0.11087545718505754 + - 0.2593361178266803 + - 0.2540074002087115 + - -0.5740614513048666 + - -0.007829221227449888 + - 0.03834047013818848 + - -0.004317709010731196 + - -0.0027632633376340767 + - -0.129202963065761 + - 0.19962765144963399 + - -0.31915441658681964 + - 0.013767837288653264 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.15278845751301479 + - -0.015117760586652125 + - 0.17105527645119087 + - -0.017690331684815156 + - 0.013004523084949588 + - 0.03592592325706536 + - -0.0015929443118154578 + - -0.18017603272628052 + - 0.2493718052519214 + - -0.01851252273831001 + - -0.36840239789481716 + - 0.07573562654747307 + - -0.09496153609588609 + - -0.217116034374993 + - -0.19437900375981282 + - - -0.20243734196131918 + - -0.19577994014285202 + - 0.13487904643057716 + - -0.09618343487121536 + - 0.1517371615159787 + - 0.2317393068058875 + - 0.12723189060826812 + - -0.21010011890877345 + - 0.02813198590045214 + - -0.09258860620336339 + - -0.1280896223225971 + - -0.11917063105802392 + - 0.04521303815511686 + - 0.20654944804245792 + - -0.15865221800841317 + - - 0.04613230757350663 + - 0.17291133888651458 + - -0.06930093162860351 + - 0.045228625938355055 + - -0.37356656687631423 + - -0.07925353159386202 + - 0.09488811507966938 + - 0.06244160229000192 + - -0.054835413565586515 + - 0.05835005467493042 + - -0.04564624157953147 + - -0.1867159393752215 + - 0.2151980108526414 + - 0.15604528498387527 + - -0.18474772218847696 + - - 0.3827994645389345 + - -0.010237764127357605 + - 0.12381495811086156 + - -0.07022891369526123 + - -0.13735614874469032 + - -0.017164374355980064 + - -0.4274370541318091 + - -0.24458731129502673 + - -0.18366825340135526 + - 0.09729400054763976 + - 0.09411515716806347 + - 0.11506950894475473 + - 0.10351955618709224 + - -0.39118396340839356 + - -0.19621858665487787 + - - -0.08753885247231129 + - -0.35366717552496174 + - -0.1877105395511413 + - -0.2703071905795911 + - -0.02658343547240909 + - -0.10019617555427808 + - -0.3917356859889574 + - -0.06819182655971401 + - -0.24428520138922613 + - 0.19318452395679075 + - 0.16284272315545056 + - 0.09549566838851437 + - -0.15404568857961887 + - -0.13201596319089784 + - 0.3598798855846083 + - - 0.014716388594522288 + - -0.17735110477484514 + - -0.061661526868765325 + - -0.18203671283898065 + - 0.012721645489698783 + - 0.107397816199318 + - 0.1811682413719007 + - -0.06356458406340935 + - -0.017564904489886752 + - -0.0696406040356576 + - 0.0149169373033483 + - -0.15882262096263317 + - 0.09525848781932615 + - 0.17964040252444105 + - -0.03624892800234665 + - - 0.1786610101163796 + - 0.13715969809908898 + - -0.1880461321004054 + - -0.3896933856757151 + - 0.041440821351723604 + - 0.041072593013709066 + - 0.24596001303001944 + - -0.40664783633033785 + - -0.1478716346596206 + - 0.071411819949603 + - 0.2325127968268579 + - -0.07342839058831764 + - 0.20895961640443458 + - 0.09181799202495412 + - 0.4926847223074976 + - - -0.04882295096710013 + - 0.07765052182971377 + - -0.34038490086778184 + - -0.1252117642283293 + - 0.021757235933785517 + - -0.011391859589372822 + - 0.13342940246719198 + - -0.3379964480138069 + - 0.21966774375378947 + - -0.03690015099190904 + - 0.0405116207259872 + - -0.20848626800286674 + - 0.31758221044797136 + - 0.05514498908692302 + - -0.09901035445881806 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + nnei: 30 + normalize: true + num_heads: 1 + out_proj: + "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.26027960344854445 + - -0.005037455883166034 + - -0.2357556083114248 + - 0.2502376952585402 + - -0.3861181832501032 + - 0.2251846251079561 + - -0.3787536933759104 + - -0.282026008969001 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.35287515934198155 + - 0.17037046689819482 + - 0.0674792027058034 + - 0.2359224430302736 + - -0.09908749500161877 + - 0.2413708908545201 + - 0.4516522336877069 + - -0.027166513316912725 + - - -0.2284947341144437 + - 0.0178700011547884 + - 0.3373485274023878 + - 0.12615306499365203 + - 0.38397500977326865 + - 0.3077434699570989 + - -0.21271028169402129 + - -0.12331826386489855 + - - 0.03490948953064142 + - 0.21388052092612841 + - 0.2976002080672586 + - -0.6828361893512437 + - -0.20327324806999922 + - 0.1679422439393503 + - -0.25029763669746585 + - 0.13859516541933636 + - - 0.22538603715419983 + - -0.04571464689496278 + - 0.1477185373369375 + - -0.715304012437922 + - 0.0036840425700703826 + - 0.006742848992578719 + - -0.2365182585758341 + - 0.14823007775766914 + - - 0.1174073817177984 + - 0.29161887216630056 + - -0.24494657592766272 + - 0.18743887078811955 + - -0.01507541506930442 + - -0.4834961103470373 + - -0.10159225497958217 + - 0.13148867132103184 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + precision: float64 + scaling_factor: 1.0 + smooth: true + temperature: 1.0 + attn_layer_norm: + "@class": LayerNorm + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + - 0.0 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + - 1.0 + "@version": 1 + eps: 1.0e-05 + precision: float64 + trainable: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + do_mask: false + dotr: true + embed_dim: 8 + hidden_dim: 5 + layer_num: 2 + ln_eps: 1.0e-05 + nnei: 30 + normalize: true + precision: float64 + scaling_factor: 1.0 + temperature: 1.0 + trainable_ln: true + attn: 5 + attn_dotr: true + attn_layer: 2 + attn_mask: false + axis_neuron: 4 + concat_output_tebd: true + embeddings: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: embedding_network + networks: + - "@class": EmbeddingNetwork + "@version": 2 + activation_function: tanh + bias: true + in_dim: 9 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.04113537876982132 + - 0.2135993577649116 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.19681532110926553 + - 0.16334979105313088 + - - 0.21075497962711434 + - 0.18561262316289895 + - - -0.2025368467015397 + - 0.30497479132960376 + - - 0.394623773693339 + - -0.5064923928042572 + - - -0.3114866097061708 + - 0.25480611739095393 + - - 0.3247253367745367 + - -0.12224922439692641 + - - -0.1454743994136753 + - -0.19056163028525014 + - - -0.4207890615015054 + - -0.04282640777790463 + - - -0.18224083572453986 + - 0.49440056045729425 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.9267767543746736 + - -0.7598444290615695 + - -0.472759170368349 + - -0.4522753423489502 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.5042381620446951 + - 0.2938947201880637 + - -0.08711840691805456 + - -0.11001290236600435 + - - -0.39743457833825646 + - 0.7121257502480001 + - -0.810708871828223 + - 0.5106226403219409 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.44005799868203765 + - 0.29655099423195885 + - -0.004267534184854465 + - -0.12538685355150112 + - -0.0633395411335549 + - 0.10241734942934799 + - 0.3702843665149265 + - -0.2000640846067108 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.047193055845247425 + - -0.19439221133072598 + - 0.09798453309569832 + - 0.6185785175412979 + - 0.3060557021582276 + - 0.0706692687554129 + - 0.035152222724231366 + - -0.014285605716690993 + - - -0.04599636222082389 + - 0.5238775079615289 + - -0.18638606304480843 + - -0.10150429650568861 + - 0.38656771075032204 + - -0.26565468801345316 + - -0.09795208618845697 + - -0.08426203696033274 + - - -0.31834471891577343 + - -0.42270207790248504 + - -0.5246839894162881 + - -0.2888396394992532 + - -0.35845154406297464 + - 0.2511035839965956 + - 0.31637016338538465 + - 0.08809323262145466 + - - 0.37578567353387815 + - 0.6808813311521732 + - 0.3026687574841046 + - -0.35554258162485597 + - -0.10377188866537626 + - 0.0093236771028058 + - 0.3334834453439019 + - 0.3598841898253352 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + neuron: + - 2 + - 4 + - 8 + precision: float64 + resnet_dt: false + ntypes: 6 + env_mat: + protection: 1.0e-06 + rcut: 6.0 + rcut_smth: 2.0 + use_exp_switch: false + env_protection: 1.0e-06 + exclude_types: &id003 + - - 4 + - 0 + - - 4 + - 1 + - - 4 + - 2 + - - 4 + - 3 + - - 4 + - 4 + - - 4 + - 5 + - - 5 + - 0 + - - 5 + - 1 + - - 5 + - 2 + - - 5 + - 3 + - - 5 + - 4 + - - 5 + - 5 + ln_eps: 1.0e-05 + neuron: + - 2 + - 4 + - 8 + normalize: true + ntypes: 6 + precision: float64 + rcut: 6.0 + rcut_smth: 2.0 + resnet_dt: false + scaling_factor: 1.0 + sel: + - 30 + set_davg_zero: false + smooth_type_embedding: true + spin: null + tebd_dim: 8 + tebd_input_mode: concat + temperature: 1.0 + trainable: true + trainable_ln: true + type: dpa1 + type_embedding: + "@class": TypeEmbedNet + "@version": 2 + activation_function: Linear + embedding: + "@class": EmbeddingNetwork + "@version": 2 + activation_function: Linear + bias: false + in_dim: 6 + layers: + - "@class": Layer + "@variables": + b: null + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.019121319536727743 + - 0.1622099217526562 + - 0.52422642629629 + - 0.0003481780854710199 + - -0.14932320832132 + - 0.16953289687994497 + - 0.13023669030537668 + - -0.06786447293424418 + - - -0.5367377279555972 + - -0.21692165642289568 + - 0.09028222516014006 + - -0.11806823801341935 + - 0.09672577514298564 + - 0.2459902879241782 + - 0.17198353347306597 + - 0.18377751881319745 + - - -0.0014766159902972935 + - 0.1000410886260073 + - -0.2907377242281664 + - 0.3162120364881382 + - 0.4271105257486104 + - 0.22058291598931898 + - 0.01568277454483321 + - -0.18428977602205465 + - - -0.3204655779290196 + - -0.5142684007913663 + - 0.05135966371525627 + - -0.21086568980349768 + - -0.18322818467115387 + - -0.06284540244831963 + - 0.012685450080645493 + - 0.7357771957158443 + - - 0.11857625691428302 + - 0.267816433264396 + - 0.45158171239301803 + - -0.2714794645199128 + - 0.37071886022802414 + - 0.12196019501846228 + - 0.08852737862760002 + - 0.2852190944801097 + - - 0.06018288011343084 + - -0.39027247236445695 + - -0.12053609546119343 + - -0.36430465334314277 + - -0.03307079734770386 + - 0.03239794425444641 + - 0.18978829125202099 + - -0.05638835726844023 + "@version": 2 + activation_function: Linear + bias: false + precision: float64 + resnet: true + trainable: true + use_timestep: false + neuron: + - 8 + precision: float64 + resnet_dt: false + neuron: + - 8 + ntypes: 6 + padding: true + precision: float64 + resnet_dt: false + trainable: true + type_map: &id001 + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + use_econf_tebd: false + use_tebd_bias: false + type_map: *id001 + type_one_side: true + use_econf_tebd: false + use_tebd_bias: false + fitting: + "@class": Fitting + "@variables": + aparam_avg: null + aparam_inv_std: null + bias_atom_e: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + - - 0.0 + case_embd: null + fparam_avg: null + fparam_inv_std: null + "@version": 4 + activation_function: tanh + atom_ener: null + default_fparam: null + dim_case_embd: 0 + dim_descrpt: 40 + dim_out: 1 + exclude_types: *id002 + layer_name: null + mixed_types: true + nets: + "@class": NetworkCollection + "@version": 1 + ndim: 0 + network_type: fitting_network + networks: + - "@class": FittingNetwork + "@version": 1 + activation_function: tanh + bias_out: true + in_dim: 40 + layers: + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.009036579787368115 + - -0.12043526618489708 + - -0.04022044991652253 + - -0.07825004058183097 + - -0.03264115403416456 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.5674676307341504 + - 0.023782010403849002 + - -0.05443603388599015 + - -0.15118658175337193 + - 0.08335549354993747 + - - 0.01687856484895976 + - -0.13669933269916298 + - -0.2583871114562561 + - -0.04336961815137456 + - -0.026271419761158384 + - - 0.09928924983316419 + - -0.030823465543671564 + - -0.07236236499507996 + - -0.22291471951912412 + - -0.21573877326789404 + - - -0.09278605203002248 + - 0.08377078773775477 + - -0.12961456608914707 + - -0.016839851437396955 + - 0.13415295214606296 + - - 0.012335469675728256 + - 0.05184437952876085 + - -0.006423649584056193 + - 0.0791277707651413 + - -0.30098066435096815 + - - 0.23047299924070833 + - 0.011873576382302296 + - 0.22526527610353067 + - -0.06817268037111866 + - 0.11189913048627279 + - - 0.014153472220860053 + - 0.08806719945921511 + - 0.3664242447098106 + - 0.04321403733928968 + - 0.1858804281169388 + - - 0.1548207212141085 + - 0.2621489519409535 + - -0.11858615560353057 + - -0.12486224876756297 + - 0.1168018815762926 + - - -0.0644985068316256 + - -0.08224138961544007 + - -0.03860328395788887 + - -0.046369682409357724 + - 0.1947426895691133 + - - -0.04919127545001311 + - -0.03737675099438821 + - -0.13411235417788264 + - 0.08224130773840631 + - 0.09273243988485456 + - - 0.22884824021908912 + - -0.03432459474452931 + - 0.17747812279031197 + - -0.12259533332014795 + - -0.0671810862789741 + - - -0.19586402439269535 + - 0.0644752438559895 + - -0.08847227949125148 + - -0.25926144778729854 + - 0.009478352142050912 + - - -0.020722424919005625 + - -0.08223270799818304 + - -0.15521951221770625 + - -0.0756580525932391 + - -0.056335758172860115 + - - 0.06378860044971417 + - 0.14929019221442877 + - -0.07871463028411804 + - 0.16852607421450178 + - -0.0447403453422747 + - - -0.06359553037472998 + - -0.010390961964384993 + - -0.14393618071816852 + - -0.21213322102177767 + - 0.0897102042442112 + - - 0.01804669930433539 + - 0.14674879958354925 + - -0.017738920787244743 + - -0.016183756496653465 + - -0.29793276333150465 + - - 0.18210173304243182 + - -0.13730027353137575 + - -0.07346146478717386 + - -0.031191012311731544 + - 0.12537243686809474 + - - -0.11221715474289964 + - 0.05614844128429898 + - 0.2940867170363763 + - -0.04008343210251607 + - 0.015773474566291266 + - - -0.22386665106548173 + - -0.14287116672063896 + - 0.03033590052843445 + - -0.0021294415782683856 + - -0.1600639422299111 + - - 0.16566164576662784 + - 0.14021528100045616 + - -0.005984056874067132 + - -0.22380037058326285 + - 0.24795710844174867 + - - -0.16244535887812667 + - -0.011359963616265218 + - 0.10884484432008126 + - -0.11031271551782482 + - -0.11652383232883412 + - - 0.0789223130568038 + - -0.19447458197651937 + - 0.05028051160586014 + - -0.12502497281504069 + - 0.17874575133706475 + - - -0.04332541874413034 + - -0.09993328082148574 + - -0.18411659529765192 + - 0.0935944320660105 + - -0.33895082223337825 + - - -0.10838342367642219 + - -0.04524333046162522 + - 0.018546594982928204 + - 0.3063305262005577 + - 0.11113911791179494 + - - 0.0473308698137412 + - -0.058165912497369444 + - 0.00473311818493111 + - -0.059462573376019685 + - 0.1006336845979852 + - - 0.12908464929081007 + - -0.13313986313838141 + - -0.20928277607170023 + - 0.29836459384363634 + - 0.09928401850939896 + - - -0.24244062130712524 + - 0.1721702535661161 + - 0.02843127684651192 + - -0.394476087897458 + - 0.33465902429328837 + - - 0.015147371330177416 + - 0.23211333058570127 + - -0.04257089989510875 + - 0.13395758446604872 + - 0.12976115854406703 + - - -0.1374936613829365 + - 0.046660173610344934 + - 0.16708061294983348 + - -0.30135250793158086 + - 0.09318435731037113 + - - -0.03632629541960811 + - -0.14877276597072533 + - -0.1102234695591099 + - 0.1863501585429565 + - 0.07549035140825774 + - - -0.03997042105254072 + - 0.15656325179297784 + - -0.0065638339229475125 + - -0.12944465676032332 + - 0.05411008410228317 + - - 0.010669872823749345 + - -0.12935249288214815 + - -0.1385873222444772 + - -0.2070331017648454 + - -0.02494997965825636 + - - -0.3542082175568733 + - -0.0029689418300270264 + - -0.07077892590307655 + - -0.2394118031138418 + - -0.10368891108719153 + - - -0.09148078381449777 + - 0.1532323355014287 + - 0.06754382410937022 + - -0.09826612014397533 + - 0.10145241597381519 + - - -0.06796982244136958 + - -0.21017464073201012 + - 0.21892110823154312 + - -0.2642604118410783 + - -0.044278605467948276 + - - -0.15383256844109483 + - -0.07503131599790384 + - 0.044603233877947554 + - 0.1409223114210914 + - -0.11099125566007907 + - - -0.2597948642623824 + - -0.10479807608530682 + - 0.30379811203169355 + - 0.11779849062588124 + - 0.2654714226553592 + - - 0.11349409442040773 + - -0.08763039088292908 + - -0.07340049416285838 + - -0.10576665782169115 + - 0.09672263326368545 + - - -0.1369501584766069 + - -0.016660386267711477 + - 0.016662926615926787 + - -0.046144852317728635 + - -0.013642370032357963 + - - 0.032093404136984964 + - 0.18132681764897873 + - -0.19317412976507092 + - -0.1040142623744361 + - -0.13801410477127313 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: false + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.14983067341331888 + - 0.12371841881354884 + - -0.22986822312961094 + - -0.25875449872749423 + - -0.03397589432487195 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.2386111924902191 + - 0.07479470127640106 + - -0.024271667699068782 + - 0.33407859647648697 + - 0.2155095839218714 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.05337260536385372 + - 0.7893095165846483 + - 0.10146750020477836 + - -0.07382069876495147 + - -0.049906208789190866 + - - 0.13920694268407685 + - 0.17288646920139744 + - 0.19417177418429363 + - -0.24360246237209826 + - 0.2751854797025152 + - - 0.4201478494578695 + - -0.5180712872213497 + - -0.06046142567704595 + - -0.4688142500317172 + - -0.8738568809019925 + - - -0.16989341023698443 + - -0.24304566774226685 + - -0.1426724141105585 + - -0.3334006993671281 + - 0.2956326142653607 + - - -0.29078211029181533 + - 0.38314560365399564 + - -0.4314685894513315 + - 0.26153727274597066 + - 0.23244982168063325 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - 0.1320183473911962 + - -0.4571722822420062 + - 0.010185901995519605 + - -0.16274352453991572 + - 0.11544273543463468 + idt: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.3011321904860118 + - -0.026753189393662627 + - -0.27834555279220746 + - -0.22656052647725025 + - 0.3535284718490197 + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - 0.488455469755746 + - -0.4501682102794183 + - 0.4473504692939388 + - -0.4013567727531333 + - -0.14017839613594146 + - - 0.11170356672506891 + - -0.19534689826061877 + - -0.19463584842339451 + - -0.32895518006240027 + - 0.4787047201183065 + - - 0.30562395194923836 + - 0.05207863515608017 + - 0.2811214296141219 + - 0.4060824734214178 + - -0.11467282115018355 + - - -0.4336905376571399 + - 0.1257782559999088 + - -0.22904030891735902 + - -0.13508914819534892 + - 0.7275596041551536 + - - -0.3862074614822706 + - 0.026603801148868712 + - 0.17689513330383833 + - 0.13340030783705445 + - -0.13151994511392598 + "@version": 2 + activation_function: tanh + bias: true + precision: float64 + resnet: true + trainable: true + use_timestep: true + - "@class": Layer + "@variables": + b: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - -0.47944902606795525 + idt: null + w: + "@class": np.ndarray + "@is_variable": true + "@version": 1 + dtype: float64 + value: + - - -0.3447980243495368 + - - 0.7370552717974026 + - - 0.38464511744143104 + - - 0.24069092502948974 + - - -0.03192455313254351 + "@version": 2 + activation_function: none + bias: true + precision: float64 + resnet: false + trainable: true + use_timestep: false + neuron: + - 5 + - 5 + - 5 + out_dim: 1 + precision: float64 + resnet_dt: true + ntypes: 6 + neuron: + - 5 + - 5 + - 5 + ntypes: 6 + numb_aparam: 0 + numb_fparam: 0 + precision: float64 + rcond: null + resnet_dt: true + spin: null + tot_ener_zero: false + trainable: + - true + - true + - true + - true + type: ener + type_map: + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + use_aparam_as_mask: false + var_name: energy + pair_exclude_types: *id003 + preset_out_bias: null + rcond: null + type: standard + type_map: + - Ni + - O + - H + - Ni_spin + - O_spin + - H_spin + spin: + use_spin: + - true + - false + - false + virtual_scale: + - 0.314 + - 0.0 + - 0.0 + type: spin_ener +model_def_script: + descriptor: + activation_function: tanh + attn: 5 + attn_dotr: true + attn_layer: 2 + attn_mask: false + axis_neuron: 4 + neuron: + - 2 + - 4 + - 8 + normalize: true + rcut: 6.0 + rcut_smth: 2.0 + scaling_factor: 1.0 + seed: 2 + sel: 30 + temperature: 1.0 + type: se_atten + type_one_side: true + fitting_net: + neuron: + - 5 + - 5 + - 5 + resnet_dt: true + seed: 2 + spin: + use_spin: + - true + - false + - false + virtual_scale: + - 0.314 + - 0.0 + - 0.0 + type_map: + - Ni + - O + - H +software: deepmd-kit +time: "2026-04-04 15:08:00.247584+00:00" +version: 3.0.0 diff --git a/source/tests/infer/gen_common.py b/source/tests/infer/gen_common.py index 11b732c513..4b542e43f6 100644 --- a/source/tests/infer/gen_common.py +++ b/source/tests/infer/gen_common.py @@ -79,3 +79,42 @@ def print_cpp_values(label, ae, f, av): comma = "," if ii < len(virial_flat) - 1 else "" print(f" {v:.18e}{comma}") # noqa: T201 print(" };") # noqa: T201 + + +def print_cpp_spin_values(label, ae, f, fm, tot_v, av): + """Print C++ reference arrays for spin models (energy, force, force_mag, virial).""" + print(f"\n// ---- {label} ----") # noqa: T201 + atom_energy = ae[0, :, 0] + print(" std::vector expected_e = {") # noqa: T201 + for ii, e in enumerate(atom_energy): + comma = "," if ii < len(atom_energy) - 1 else "" + print(f" {e:.18e}{comma}") # noqa: T201 + print(" };") # noqa: T201 + + print(" std::vector expected_f = {") # noqa: T201 + force_flat = f[0].flatten() + for ii, fv in enumerate(force_flat): + comma = "," if ii < len(force_flat) - 1 else "" + print(f" {fv:.18e}{comma}") # noqa: T201 + print(" };") # noqa: T201 + + print(" std::vector expected_fm = {") # noqa: T201 + fm_flat = fm[0].flatten() + for ii, fv in enumerate(fm_flat): + comma = "," if ii < len(fm_flat) - 1 else "" + print(f" {fv:.18e}{comma}") # noqa: T201 + print(" };") # noqa: T201 + + print(" std::vector expected_tot_v = {") # noqa: T201 + tot_v_flat = tot_v[0].flatten() + for ii, v in enumerate(tot_v_flat): + comma = "," if ii < len(tot_v_flat) - 1 else "" + print(f" {v:.18e}{comma}") # noqa: T201 + print(" };") # noqa: T201 + + print(" std::vector expected_atom_v = {") # noqa: T201 + av_flat = av[0].flatten() + for ii, v in enumerate(av_flat): + comma = "," if ii < len(av_flat) - 1 else "" + print(f" {v:.18e}{comma}") # noqa: T201 + print(" };") # noqa: T201 diff --git a/source/tests/infer/gen_spin.py b/source/tests/infer/gen_spin.py new file mode 100644 index 0000000000..3053e0ad4f --- /dev/null +++ b/source/tests/infer/gen_spin.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Generate deeppot_dpa_spin.pth and deeppot_dpa_spin.pt2 test models. + +The canonical model weights are stored in ``deeppot_dpa_spin.yaml`` (dpmodel +serialization, committed to git). This script converts the .yaml to both +.pth (torch.jit) and .pt2 (torch.export) formats. + +If the .yaml does not yet exist, it is created from a dpmodel built with +a deterministic config+seed — but this should only be done once (the .yaml +is then committed). + +Also prints reference values for C++ tests (PBC and NoPbc). +""" + +import copy +import os +import sys + +import numpy as np + +# Ensure the source tree is on the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) + +from gen_common import ( + ensure_inductor_compiler, + load_custom_ops, + print_cpp_spin_values, +) + + +def _build_yaml(yaml_path: str) -> None: + """Build the dpmodel from config+seed and save as .yaml.""" + from deepmd.dpmodel.model.model import ( + get_model, + ) + from deepmd.dpmodel.utils.serialization import ( + save_dp_model, + ) + + config = { + "type_map": ["Ni", "O", "H"], + "descriptor": { + "type": "se_atten", + "sel": 30, + "rcut_smth": 2.0, + "rcut": 6.0, + "neuron": [2, 4, 8], + "axis_neuron": 4, + "attn": 5, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": True, + "temperature": 1.0, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5, 5], + "resnet_dt": True, + "seed": 1, + }, + "spin": { + "use_spin": [True, False, False], + "virtual_scale": [0.3140, 0.0, 0.0], + }, + } + + model = get_model(copy.deepcopy(config)) + model_dict = model.serialize() + + data = { + "model": model_dict, + "model_def_script": config, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + + print(f"Building dpmodel and saving to {yaml_path} ...") # noqa: T201 + save_dp_model(yaml_path, data) + + +def main(): + from deepmd.entrypoints.convert_backend import ( + convert_backend, + ) + + ensure_inductor_compiler() + + base_dir = os.path.dirname(__file__) + yaml_path = os.path.join(base_dir, "deeppot_dpa_spin.yaml") + pth_path = os.path.join(base_dir, "deeppot_dpa_spin.pth") + pt2_path = os.path.join(base_dir, "deeppot_dpa_spin.pt2") + + # ---- 1. Build .yaml if it doesn't exist ---- + if not os.path.exists(yaml_path): + _build_yaml(yaml_path) + else: + print(f"Using existing {yaml_path}") # noqa: T201 + + # ---- 2. Convert .yaml -> .pth and .yaml -> .pt2 ---- + # Import deepmd.pt to register the backend (needed for convert_backend) + import deepmd.pt # noqa: F401 + + load_custom_ops() + + print(f"Converting to {pth_path} ...") # noqa: T201 + convert_backend(INPUT=yaml_path, OUTPUT=pth_path) + + print(f"Converting to {pt2_path} ...") # noqa: T201 + convert_backend(INPUT=yaml_path, OUTPUT=pt2_path) + + print("Export done.") # noqa: T201 + + # ---- 3. Run inference for PBC test ---- + from deepmd.infer import ( + DeepPot, + ) + + dp = DeepPot(pt2_path) + + coord = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 0.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=np.float64, + ) + spin = np.array( + [ + 0.13, + 0.02, + 0.03, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.14, + 0.10, + 0.12, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + dtype=np.float64, + ) + atype = [0, 1, 1, 0, 1, 1] + box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=np.float64) + + e1, f1, v1, ae1, av1, fm1, _ = dp.eval(coord, box, atype, atomic=True, spin=spin) + print(f"\n// PBC total energy: {e1[0, 0]:.18e}") # noqa: T201 + print_cpp_spin_values("PBC reference values", ae1, f1, fm1, v1, av1) + + # ---- 4. Run inference for NoPbc test ---- + e_np, f_np, v_np, ae_np, av_np, fm_np, _ = dp.eval( + coord, None, atype, atomic=True, spin=spin + ) + print(f"\n// NoPbc total energy: {e_np[0, 0]:.18e}") # noqa: T201 + print_cpp_spin_values("NoPbc reference values", ae_np, f_np, fm_np, v_np, av_np) + + # ---- 5. Verify .pth gives same results ---- + if os.path.exists(pth_path): + dp_pth = DeepPot(pth_path) + e_pth, f_pth, v_pth, ae_pth, av_pth, fm_pth, _ = dp_pth.eval( + coord, box, atype, atomic=True, spin=spin + ) + print(f"\n// .pth PBC total energy: {e_pth[0, 0]:.18e}") # noqa: T201 + print(f"// .pth vs .pt2 energy diff: {abs(e1[0, 0] - e_pth[0, 0]):.2e}") # noqa: T201 + print(f"// .pth vs .pt2 force max diff: {np.max(np.abs(f1 - f_pth)):.2e}") # noqa: T201 + print(f"// .pth vs .pt2 force_mag max diff: {np.max(np.abs(fm1 - fm_pth)):.2e}") # noqa: T201 + + print("\nDone!") # noqa: T201 + + +if __name__ == "__main__": + main() diff --git a/source/tests/infer/gen_spin_model_devi.py b/source/tests/infer/gen_spin_model_devi.py new file mode 100644 index 0000000000..3dc5240e14 --- /dev/null +++ b/source/tests/infer/gen_spin_model_devi.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Generate deeppot_dpa_spin_md0.pt2 and deeppot_dpa_spin_md1.pt2 test models. + +The canonical model weights are stored in .yaml files (dpmodel serialization, +committed to git). This script converts them to .pt2 format. + +If the .yaml files do not yet exist, they are created from dpmodel with +different seeds — but this should only be done once (the .yaml files are +then committed). + +Prints reference values for C++ tests. +""" + +import copy +import os +import sys + +import numpy as np + +# Ensure the source tree is on the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) + +from gen_common import ( + ensure_inductor_compiler, + load_custom_ops, + print_cpp_spin_values, +) + +# Model config (same architecture as gen_spin.py, different seeds) +_BASE_CONFIG = { + "type_map": ["Ni", "O", "H"], + "descriptor": { + "type": "se_atten", + "sel": 30, + "rcut_smth": 2.0, + "rcut": 6.0, + "neuron": [2, 4, 8], + "axis_neuron": 4, + "attn": 5, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": True, + "temperature": 1.0, + "type_one_side": True, + }, + "fitting_net": { + "neuron": [5, 5, 5], + "resnet_dt": True, + }, + "spin": { + "use_spin": [True, False, False], + "virtual_scale": [0.3140, 0.0, 0.0], + }, +} + + +def _build_yaml(yaml_path: str, seed: int) -> None: + """Build a dpmodel with given seed and save as .yaml.""" + from deepmd.dpmodel.model.model import ( + get_model, + ) + from deepmd.dpmodel.utils.serialization import ( + save_dp_model, + ) + + cfg = copy.deepcopy(_BASE_CONFIG) + cfg["descriptor"]["seed"] = seed + cfg["fitting_net"]["seed"] = seed + model = get_model(cfg) + model_dict = model.serialize() + + data = { + "model": model_dict, + "model_def_script": cfg, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + + print(f"Building dpmodel (seed={seed}) and saving to {yaml_path} ...") # noqa: T201 + save_dp_model(yaml_path, data) + + +def main(): + from deepmd.entrypoints.convert_backend import ( + convert_backend, + ) + + ensure_inductor_compiler() + + base_dir = os.path.dirname(__file__) + + # ---- 1. Ensure .yaml files exist ---- + seeds = [1, 2] + yaml_paths = [] + pt2_paths = [] + for idx, seed in enumerate(seeds): + yaml_path = os.path.join(base_dir, f"deeppot_dpa_spin_md{idx}.yaml") + yaml_paths.append(yaml_path) + if not os.path.exists(yaml_path): + _build_yaml(yaml_path, seed) + else: + print(f"Using existing {yaml_path}") # noqa: T201 + + # ---- 2. Convert .yaml -> .pt2 ---- + # Import deepmd.pt to register the backend + import deepmd.pt # noqa: F401 + + load_custom_ops() + + for idx, yaml_path in enumerate(yaml_paths): + pt2_path = os.path.join(base_dir, f"deeppot_dpa_spin_md{idx}.pt2") + pt2_paths.append(pt2_path) + print(f"Converting to {pt2_path} ...") # noqa: T201 + convert_backend(INPUT=yaml_path, OUTPUT=pt2_path) + + print("Export done.") # noqa: T201 + + # ---- 3. Run inference for both models ---- + from deepmd.infer import ( + DeepPot, + ) + + coord = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 0.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=np.float64, + ) + spin = np.array( + [ + 0.13, + 0.02, + 0.03, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.14, + 0.10, + 0.12, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + dtype=np.float64, + ) + atype = [0, 1, 1, 0, 1, 1] + box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=np.float64) + + for idx, pt2_path in enumerate(pt2_paths): + dp = DeepPot(pt2_path) + e, f, v, ae, av, fm, _ = dp.eval(coord, box, atype, atomic=True, spin=spin) + print(f"\n// Model {idx} total energy: {e[0, 0]:.18e}") # noqa: T201 + print_cpp_spin_values(f"Model {idx} reference values", ae, f, fm, v, av) + + # ---- 4. Also print LAMMPS 4-atom system reference values ---- + lmp_coord = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=np.float64, + ) + lmp_spin = np.array( + [ + 0, + 0, + 1.2737, + 0, + 0, + 1.2737, + 0, + 0, + 0, + 0, + 0, + 0, + ], + dtype=np.float64, + ) + lmp_atype = [0, 0, 1, 1] + lmp_box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=np.float64 + ) + + print("\n// ---- LAMMPS 4-atom system (PBC) ----") # noqa: T201 + for idx, pt2_path in enumerate(pt2_paths): + dp = DeepPot(pt2_path) + e, f, v, ae, av, fm, _ = dp.eval( + lmp_coord, lmp_box, lmp_atype, atomic=True, spin=lmp_spin + ) + print(f"\n// LAMMPS Model {idx} total energy: {e[0, 0]:.18e}") # noqa: T201 + print(f"// LAMMPS Model {idx} force:") # noqa: T201 + for ii in range(4): + print(f"// [{f[0, ii, 0]:.16e}, {f[0, ii, 1]:.16e}, {f[0, ii, 2]:.16e}]") # noqa: T201 + print(f"// LAMMPS Model {idx} force_mag:") # noqa: T201 + for ii in range(4): + msg = ( + f"// [{fm[0, ii, 0]:.16e}, {fm[0, ii, 1]:.16e}, {fm[0, ii, 2]:.16e}]" + ) + print(msg) # noqa: T201 + + # NoPBC for LAMMPS + print("\n// ---- LAMMPS 4-atom system (NoPBC) ----") # noqa: T201 + for idx, pt2_path in enumerate(pt2_paths): + dp = DeepPot(pt2_path) + e, f, v, ae, av, fm, _ = dp.eval( + lmp_coord, None, lmp_atype, atomic=True, spin=lmp_spin + ) + print(f"\n// LAMMPS NoPBC Model {idx} total energy: {e[0, 0]:.18e}") # noqa: T201 + print(f"// LAMMPS NoPBC Model {idx} force:") # noqa: T201 + for ii in range(4): + print(f"// [{f[0, ii, 0]:.16e}, {f[0, ii, 1]:.16e}, {f[0, ii, 2]:.16e}]") # noqa: T201 + print(f"// LAMMPS NoPBC Model {idx} force_mag:") # noqa: T201 + for ii in range(4): + msg = ( + f"// [{fm[0, ii, 0]:.16e}, {fm[0, ii, 1]:.16e}, {fm[0, ii, 2]:.16e}]" + ) + print(msg) # noqa: T201 + + print("\nDone!") # noqa: T201 + + +if __name__ == "__main__": + main() diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 112ada5dc7..f0bef34cf8 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -96,6 +96,10 @@ def test_get_sel_type(self) -> None: sel_type = self.dp.deep_eval.get_sel_type() self.assertEqual(sel_type, self.model.get_sel_type()) + def test_use_spin_non_spin_model(self) -> None: + self.assertFalse(self.dp.has_spin) + self.assertEqual(self.dp.use_spin, []) + def test_model_type(self) -> None: self.assertIs(self.dp.deep_eval.model_type, DeepPot) @@ -582,6 +586,10 @@ def test_get_sel_type(self) -> None: sel_type = self.dp.deep_eval.get_sel_type() self.assertEqual(sel_type, self.model.get_sel_type()) + def test_use_spin_non_spin_model(self) -> None: + self.assertFalse(self.dp.has_spin) + self.assertEqual(self.dp.use_spin, []) + def test_model_type(self) -> None: self.assertIs(self.dp.deep_eval.model_type, DeepPot) @@ -855,5 +863,441 @@ def test_pt2_vs_pte_consistency(self) -> None: ) +class TestDeepEvalEnerDefaultFparam(unittest.TestCase): + """Test .pte inference with default fparam (non-spin model).""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + cls.numb_fparam = 1 + cls.default_fparam = [0.5] + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + numb_fparam=cls.numb_fparam, + default_fparam=cls.default_fparam, + seed=GLOBAL_SEED, + ) + cls.model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = cls.model.to(torch.float64) + cls.model.eval() + + cls.model_data = {"model": cls.model.serialize()} + cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) + cls.tmpfile.close() + deserialize_to_file(cls.tmpfile.name, cls.model_data) + + cls.dp = DeepPot(cls.tmpfile.name) + + @classmethod + def tearDownClass(cls) -> None: + import os + + os.unlink(cls.tmpfile.name) + + def test_get_dim_fparam(self) -> None: + self.assertEqual(self.dp.deep_eval.get_dim_fparam(), self.numb_fparam) + + def test_eval_without_fparam_matches_explicit(self) -> None: + """Eval without fparam should use default and match explicit fparam.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # Eval WITHOUT fparam — should use default_fparam=[0.5] + e_no, f_no, v_no = self.dp.eval(coords, cells, atom_types) + # Eval WITH explicit fparam=[0.5] + e_ex, f_ex, v_ex = self.dp.eval( + coords, cells, atom_types, fparam=self.default_fparam + ) + + np.testing.assert_allclose(e_no, e_ex, atol=1e-10) + np.testing.assert_allclose(f_no, f_ex, atol=1e-10) + np.testing.assert_allclose(v_no, v_ex, atol=1e-10) + + def test_fparam_takes_effect(self) -> None: + """Different fparam values must produce different outputs.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + e0, f0, v0 = self.dp.eval(coords, cells, atom_types, fparam=[0.0]) + e1, f1, v1 = self.dp.eval(coords, cells, atom_types, fparam=[1.0]) + + assert not np.allclose(e0, e1), ( + "Changing fparam did not change output — fparam may be ignored" + ) + + +class TestDeepEvalEnerDefaultFparamPt2(unittest.TestCase): + """Test .pt2 inference with default fparam (non-spin model).""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + cls.numb_fparam = 1 + cls.default_fparam = [0.5] + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + numb_fparam=cls.numb_fparam, + default_fparam=cls.default_fparam, + seed=GLOBAL_SEED, + ) + cls.model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = cls.model.to(torch.float64) + cls.model.eval() + + cls.model_data = {"model": cls.model.serialize()} + cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) + cls.tmpfile.close() + torch.set_default_device(None) + try: + deserialize_to_file(cls.tmpfile.name, cls.model_data) + finally: + torch.set_default_device("cuda:9999999") + + # Also save .pte for cross-format comparison + cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) + cls.pte_tmpfile.close() + deserialize_to_file(cls.pte_tmpfile.name, cls.model_data) + + cls.dp = DeepPot(cls.tmpfile.name) + cls.dp_pte = DeepPot(cls.pte_tmpfile.name) + + @classmethod + def tearDownClass(cls) -> None: + import os + + os.unlink(cls.tmpfile.name) + os.unlink(cls.pte_tmpfile.name) + + def test_get_dim_fparam(self) -> None: + self.assertEqual(self.dp.deep_eval.get_dim_fparam(), self.numb_fparam) + + def test_eval_without_fparam_matches_explicit(self) -> None: + """Eval without fparam should use default and match explicit fparam.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + e_no, f_no, v_no = self.dp.eval(coords, cells, atom_types) + e_ex, f_ex, v_ex = self.dp.eval( + coords, cells, atom_types, fparam=self.default_fparam + ) + + np.testing.assert_allclose(e_no, e_ex, atol=1e-10) + np.testing.assert_allclose(f_no, f_ex, atol=1e-10) + np.testing.assert_allclose(v_no, v_ex, atol=1e-10) + + def test_fparam_takes_effect(self) -> None: + """Different fparam values must produce different outputs.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + e0, f0, v0 = self.dp.eval(coords, cells, atom_types, fparam=[0.0]) + e1, f1, v1 = self.dp.eval(coords, cells, atom_types, fparam=[1.0]) + + assert not np.allclose(e0, e1), ( + "Changing fparam did not change output — fparam may be ignored" + ) + + def test_pt2_vs_pte_consistency(self) -> None: + """Outputs from .pt2 with default fparam should match .pte.""" + rng = np.random.default_rng(GLOBAL_SEED + 19) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # Both use default fparam (no explicit fparam) + e1, f1, v1 = self.dp.eval(coords, cells, atom_types) + e2, f2, v2 = self.dp_pte.eval(coords, cells, atom_types) + + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10, err_msg="energy") + np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10, err_msg="force") + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10, err_msg="virial") + + +class TestDeepEvalEnerAparam(unittest.TestCase): + """Test .pte inference with aparam (non-spin model).""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + cls.numb_aparam = 2 + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + numb_aparam=cls.numb_aparam, + seed=GLOBAL_SEED, + ) + cls.model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = cls.model.to(torch.float64) + cls.model.eval() + + cls.model_data = {"model": cls.model.serialize()} + cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) + cls.tmpfile.close() + deserialize_to_file(cls.tmpfile.name, cls.model_data) + + cls.dp = DeepPot(cls.tmpfile.name) + + @classmethod + def tearDownClass(cls) -> None: + import os + + os.unlink(cls.tmpfile.name) + + def test_get_dim_aparam(self) -> None: + self.assertEqual(self.dp.deep_eval.get_dim_aparam(), self.numb_aparam) + + def test_aparam_takes_effect(self) -> None: + """Different aparam values must produce different outputs.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + aparam_zero = np.zeros(natoms * self.numb_aparam, dtype=np.float64) + aparam_nonzero = np.full(natoms * self.numb_aparam, 0.5, dtype=np.float64) + + e0, f0, v0 = self.dp.eval(coords, cells, atom_types, aparam=aparam_zero) + e1, f1, v1 = self.dp.eval(coords, cells, atom_types, aparam=aparam_nonzero) + + assert not np.allclose(e0, e1), ( + "Changing aparam did not change output — aparam may be ignored" + ) + + def test_eval_without_aparam_raises(self) -> None: + """Model with dim_aparam > 0 must raise when aparam not provided.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + with self.assertRaises(ValueError): + self.dp.eval(coords, cells, atom_types) + + def test_eval_consistency(self) -> None: + """Test that DeepPot.eval with aparam matches direct model forward.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + aparam = rng.random(natoms * self.numb_aparam) + + e, f, v, ae, av = self.dp.eval( + coords, cells, atom_types, atomic=True, aparam=aparam + ) + + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype_t = torch.tensor( + atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE + ) + cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE) + aparam_t = torch.tensor( + aparam.reshape(1, natoms, self.numb_aparam), + dtype=torch.float64, + device=DEVICE, + ) + ref = self.model.forward( + coord_t, atype_t, cell_t, aparam=aparam_t, do_atomic_virial=True + ) + + np.testing.assert_allclose( + e, ref["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f, ref["force"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + v, ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + ae, ref["atom_energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + av, ref["atom_virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + + +class TestDeepEvalEnerAparamPt2(unittest.TestCase): + """Test .pt2 inference with aparam (non-spin model).""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + cls.numb_aparam = 2 + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + numb_aparam=cls.numb_aparam, + seed=GLOBAL_SEED, + ) + cls.model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = cls.model.to(torch.float64) + cls.model.eval() + + cls.model_data = {"model": cls.model.serialize()} + cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) + cls.tmpfile.close() + torch.set_default_device(None) + try: + deserialize_to_file(cls.tmpfile.name, cls.model_data) + finally: + torch.set_default_device("cuda:9999999") + + # Also save .pte for cross-format comparison + cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) + cls.pte_tmpfile.close() + deserialize_to_file(cls.pte_tmpfile.name, cls.model_data) + + cls.dp = DeepPot(cls.tmpfile.name) + cls.dp_pte = DeepPot(cls.pte_tmpfile.name) + + @classmethod + def tearDownClass(cls) -> None: + import os + + os.unlink(cls.tmpfile.name) + os.unlink(cls.pte_tmpfile.name) + + def test_get_dim_aparam(self) -> None: + self.assertEqual(self.dp.deep_eval.get_dim_aparam(), self.numb_aparam) + + def test_aparam_takes_effect(self) -> None: + """Different aparam values must produce different outputs.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + aparam_zero = np.zeros(natoms * self.numb_aparam, dtype=np.float64) + aparam_nonzero = np.full(natoms * self.numb_aparam, 0.5, dtype=np.float64) + + e0, f0, v0 = self.dp.eval(coords, cells, atom_types, aparam=aparam_zero) + e1, f1, v1 = self.dp.eval(coords, cells, atom_types, aparam=aparam_nonzero) + + assert not np.allclose(e0, e1), ( + "Changing aparam did not change output — aparam may be ignored" + ) + + def test_eval_without_aparam_raises(self) -> None: + """Model with dim_aparam > 0 must raise when aparam not provided.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + with self.assertRaises(ValueError): + self.dp.eval(coords, cells, atom_types) + + def test_eval_consistency(self) -> None: + """Test that .pt2 DeepPot.eval with aparam matches direct model forward.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + aparam = rng.random(natoms * self.numb_aparam) + + e, f, v, ae, av = self.dp.eval( + coords, cells, atom_types, atomic=True, aparam=aparam + ) + + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype_t = torch.tensor( + atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE + ) + cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE) + aparam_t = torch.tensor( + aparam.reshape(1, natoms, self.numb_aparam), + dtype=torch.float64, + device=DEVICE, + ) + ref = self.model.forward( + coord_t, atype_t, cell_t, aparam=aparam_t, do_atomic_virial=True + ) + + np.testing.assert_allclose( + e, ref["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f, ref["force"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + v, ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + ae, ref["atom_energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + av, ref["atom_virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + + def test_pt2_vs_pte_consistency(self) -> None: + """Outputs from .pt2 with aparam should match .pte.""" + rng = np.random.default_rng(GLOBAL_SEED + 19) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + aparam = rng.random(natoms * self.numb_aparam) + + e1, f1, v1 = self.dp.eval(coords, cells, atom_types, aparam=aparam) + e2, f2, v2 = self.dp_pte.eval(coords, cells, atom_types, aparam=aparam) + + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10, err_msg="energy") + np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10, err_msg="force") + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10, err_msg="virial") + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt_expt/infer/test_deep_eval_spin.py b/source/tests/pt_expt/infer/test_deep_eval_spin.py new file mode 100644 index 0000000000..829b1f5666 --- /dev/null +++ b/source/tests/pt_expt/infer/test_deep_eval_spin.py @@ -0,0 +1,482 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for spin model inference via the DeepPot high-level API. + +Verifies that .pt2 and .pte spin models produce correct results when loaded +through the pt_expt inference backend (DeepEval → DeepPot). +""" + +import copy +import os +import tempfile + +import numpy as np +import pytest +import torch + +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.pt_expt.model.spin_ener_model import ( + SpinEnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.serialization import ( + deserialize_to_file, +) + +SPIN_CONFIG = { + "type_map": ["Ni", "O"], + "descriptor": { + "type": "se_atten", + "sel": 30, + "rcut_smth": 2.0, + "rcut": 6.0, + "neuron": [2, 4, 8], + "axis_neuron": 4, + "attn": 5, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": True, + "temperature": 1.0, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5, 5], + "resnet_dt": True, + "seed": 1, + }, + "spin": { + "use_spin": [True, False], + "virtual_scale": [0.3140, 0.0], + }, +} + +COORD = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 0.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=np.float64, +) +SPIN = np.array( + [ + 0.13, + 0.02, + 0.03, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.14, + 0.10, + 0.12, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + dtype=np.float64, +) +ATYPE = [0, 1, 1, 0, 1, 1] +BOX = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=np.float64) + + +def _build_reference(): + """Build pt_expt model and run eager reference inference. + + Returns data dict, and reference dicts for PBC and NoPBC. + """ + dp_model = get_model_dp(copy.deepcopy(SPIN_CONFIG)) + model_dict = dp_model.serialize() + data = { + "model": model_dict, + "model_def_script": SPIN_CONFIG, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + + # Build pt_expt model for eager reference + pt_model = SpinEnergyModel.deserialize(dp_model.serialize()).to(env.DEVICE) + pt_model.eval() + + natoms = len(ATYPE) + coord_t = torch.tensor( + COORD.reshape(1, natoms, 3), dtype=torch.float64, device=env.DEVICE + ) + coord_t.requires_grad_(True) + atype_t = torch.tensor([ATYPE], dtype=torch.int64, device=env.DEVICE) + spin_t = torch.tensor( + SPIN.reshape(1, natoms, 3), dtype=torch.float64, device=env.DEVICE + ) + box_t = torch.tensor(BOX.reshape(1, 9), dtype=torch.float64, device=env.DEVICE) + + # PBC reference + ref_pbc = pt_model(coord_t, atype_t, spin_t, box_t) + ref_pbc = {k: v.detach().cpu().numpy() for k, v in ref_pbc.items()} + + # NoPBC reference + ref_nopbc = pt_model(coord_t, atype_t, spin_t, None) + ref_nopbc = {k: v.detach().cpu().numpy() for k, v in ref_nopbc.items()} + + return data, ref_pbc, ref_nopbc + + +@pytest.fixture(scope="module") +def spin_model_files(): + """Create .pt2 and .pte spin model files and compute reference values.""" + data, ref_pbc, ref_nopbc = _build_reference() + files = {} + tmpdir = tempfile.mkdtemp() + for ext in (".pt2", ".pte"): + path = os.path.join(tmpdir, f"spin_test{ext}") + # AOTInductor (.pt2) internally creates tensors using the PyTorch + # default device. Clear it so compilation stays on CPU. + prev = torch.get_default_device() + torch.set_default_device(None) + try: + deserialize_to_file(path, copy.deepcopy(data)) + finally: + torch.set_default_device(prev) + files[ext] = path + yield files, ref_pbc, ref_nopbc + for path in files.values(): + if os.path.exists(path): + os.unlink(path) + os.rmdir(tmpdir) + + +@pytest.mark.parametrize("ext", [".pt2", ".pte"]) # model format +class TestSpinInference: + """Test spin model inference through DeepPot high-level API.""" + + def test_get_has_spin(self, spin_model_files, ext) -> None: + """Test that get_has_spin returns True for spin models.""" + from deepmd.infer import ( + DeepPot, + ) + + files, _, _ = spin_model_files + dp = DeepPot(files[ext]) + assert dp.has_spin + + def test_get_use_spin(self, spin_model_files, ext) -> None: + """Test that use_spin returns per-type spin usage.""" + from deepmd.infer import ( + DeepPot, + ) + + files, _, _ = spin_model_files + dp = DeepPot(files[ext]) + use_spin = dp.use_spin + assert use_spin == [True, False] + + def test_get_ntypes_spin(self, spin_model_files, ext) -> None: + """Test that get_ntypes_spin returns 0 (new spin implementation).""" + from deepmd.infer import ( + DeepPot, + ) + + files, _, _ = spin_model_files + dp = DeepPot(files[ext]) + assert dp.get_ntypes_spin() == 0 + + def test_eval_spin_model_requires_spin(self, spin_model_files, ext) -> None: + """Spin model must raise ValueError when spin is not provided.""" + from deepmd.infer import ( + DeepPot, + ) + + files, _, _ = spin_model_files + dp = DeepPot(files[ext]) + with pytest.raises(ValueError, match="no `spin` argument was provided"): + dp.eval(COORD, BOX, ATYPE) + + def test_eval_pbc_atomic(self, spin_model_files, ext) -> None: + """Test PBC evaluation with atomic=True.""" + from deepmd.infer import ( + DeepPot, + ) + + files, ref, _ = spin_model_files + dp = DeepPot(files[ext]) + natoms = len(ATYPE) + + e, f, v, ae, av, fm, mm = dp.eval(COORD, BOX, ATYPE, atomic=True, spin=SPIN) + + np.testing.assert_allclose( + e.reshape(-1), ref["energy"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + ae.reshape(-1), + ref["atom_energy"].reshape(-1), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + f.reshape(-1), ref["force"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + fm.reshape(-1), + ref["force_mag"].reshape(-1), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + v.reshape(-1), ref["virial"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose(mm.reshape(-1), ref["mask_mag"].reshape(-1)) + # Shape checks + assert e.shape == (1, 1) + assert f.shape == (1, natoms, 3) + assert v.shape == (1, 9) + assert ae.shape == (1, natoms, 1) + assert av.shape == (1, natoms, 9) + assert fm.shape == (1, natoms, 3) + assert mm.shape == (1, natoms, 1) + + def test_eval_pbc_nonatomic(self, spin_model_files, ext) -> None: + """Test PBC evaluation with atomic=False.""" + from deepmd.infer import ( + DeepPot, + ) + + files, ref, _ = spin_model_files + dp = DeepPot(files[ext]) + + e, f, v, fm, mm = dp.eval(COORD, BOX, ATYPE, atomic=False, spin=SPIN) + + np.testing.assert_allclose( + e.reshape(-1), ref["energy"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f.reshape(-1), ref["force"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + fm.reshape(-1), + ref["force_mag"].reshape(-1), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + v.reshape(-1), ref["virial"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + + def test_eval_nopbc_atomic(self, spin_model_files, ext) -> None: + """Test NoPBC evaluation with atomic=True.""" + from deepmd.infer import ( + DeepPot, + ) + + files, _, ref = spin_model_files + dp = DeepPot(files[ext]) + + e, f, v, ae, av, fm, mm = dp.eval(COORD, None, ATYPE, atomic=True, spin=SPIN) + + np.testing.assert_allclose( + e.reshape(-1), ref["energy"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + ae.reshape(-1), + ref["atom_energy"].reshape(-1), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + f.reshape(-1), ref["force"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + fm.reshape(-1), + ref["force_mag"].reshape(-1), + rtol=1e-10, + atol=1e-10, + ) + + def test_eval_nopbc_nonatomic(self, spin_model_files, ext) -> None: + """Test NoPBC evaluation with atomic=False.""" + from deepmd.infer import ( + DeepPot, + ) + + files, _, ref = spin_model_files + dp = DeepPot(files[ext]) + + e, f, v, fm, mm = dp.eval(COORD, None, ATYPE, atomic=False, spin=SPIN) + + np.testing.assert_allclose( + e.reshape(-1), ref["energy"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f.reshape(-1), ref["force"].reshape(-1), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + fm.reshape(-1), + ref["force_mag"].reshape(-1), + rtol=1e-10, + atol=1e-10, + ) + + +SPIN_FPARAM_CONFIG = copy.deepcopy(SPIN_CONFIG) +SPIN_FPARAM_CONFIG["fitting_net"]["numb_fparam"] = 1 +SPIN_FPARAM_CONFIG["fitting_net"]["default_fparam"] = [0.5] + + +@pytest.fixture(scope="module") +def spin_fparam_model_files(): + """Create .pt2 and .pte spin model files with default fparam.""" + dp_model = get_model_dp(copy.deepcopy(SPIN_FPARAM_CONFIG)) + model_dict = dp_model.serialize() + data = { + "model": model_dict, + "model_def_script": SPIN_FPARAM_CONFIG, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + files = {} + tmpdir = tempfile.mkdtemp() + for ext in (".pt2", ".pte"): + path = os.path.join(tmpdir, f"spin_fparam_test{ext}") + prev = torch.get_default_device() + torch.set_default_device(None) + try: + deserialize_to_file(path, copy.deepcopy(data)) + finally: + torch.set_default_device(prev) + files[ext] = path + yield files + for path in files.values(): + if os.path.exists(path): + os.unlink(path) + os.rmdir(tmpdir) + + +@pytest.mark.parametrize("ext", [".pt2", ".pte"]) # model format +class TestSpinDefaultFparam: + """Test spin model with default_fparam via DeepPot API.""" + + def test_eval_without_fparam_matches_explicit( + self, spin_fparam_model_files, ext + ) -> None: + """Eval without fparam should use default and match explicit fparam.""" + from deepmd.infer import ( + DeepPot, + ) + + files = spin_fparam_model_files + dp = DeepPot(files[ext]) + + # Eval WITHOUT fparam — should use default_fparam=[0.5] + e_no, f_no, v_no, fm_no, mm_no = dp.eval( + COORD, BOX, ATYPE, atomic=False, spin=SPIN + ) + # Eval WITH explicit fparam=[0.5] + e_ex, f_ex, v_ex, fm_ex, mm_ex = dp.eval( + COORD, BOX, ATYPE, atomic=False, spin=SPIN, fparam=[0.5] + ) + + np.testing.assert_allclose(e_no, e_ex, atol=1e-10) + np.testing.assert_allclose(f_no, f_ex, atol=1e-10) + np.testing.assert_allclose(v_no, v_ex, atol=1e-10) + np.testing.assert_allclose(fm_no, fm_ex, atol=1e-10) + + +SPIN_APARAM_CONFIG = copy.deepcopy(SPIN_CONFIG) +SPIN_APARAM_CONFIG["fitting_net"]["numb_aparam"] = 2 + + +@pytest.fixture(scope="module") +def spin_aparam_model_files(): + """Create .pt2 and .pte spin model files with aparam.""" + dp_model = get_model_dp(copy.deepcopy(SPIN_APARAM_CONFIG)) + model_dict = dp_model.serialize() + data = { + "model": model_dict, + "model_def_script": SPIN_APARAM_CONFIG, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + files = {} + tmpdir = tempfile.mkdtemp() + for ext in (".pt2", ".pte"): + path = os.path.join(tmpdir, f"spin_aparam_test{ext}") + prev = torch.get_default_device() + torch.set_default_device(None) + try: + deserialize_to_file(path, copy.deepcopy(data)) + finally: + torch.set_default_device(prev) + files[ext] = path + yield files + for path in files.values(): + if os.path.exists(path): + os.unlink(path) + os.rmdir(tmpdir) + + +@pytest.mark.parametrize("ext", [".pt2", ".pte"]) # model format +class TestSpinAparam: + """Test spin model with aparam via DeepPot API (.pt2/.pte).""" + + def test_aparam_takes_effect(self, spin_aparam_model_files, ext) -> None: + """Verify that different aparam values produce different outputs.""" + from deepmd.infer import ( + DeepPot, + ) + + files = spin_aparam_model_files + dp = DeepPot(files[ext]) + natoms = len(ATYPE) + + aparam_zero = np.zeros(natoms * 2, dtype=np.float64) + aparam_nonzero = np.full(natoms * 2, 0.5, dtype=np.float64) + + e0, f0, v0, fm0, mm0 = dp.eval( + COORD, BOX, ATYPE, atomic=False, spin=SPIN, aparam=aparam_zero + ) + e1, f1, v1, fm1, mm1 = dp.eval( + COORD, BOX, ATYPE, atomic=False, spin=SPIN, aparam=aparam_nonzero + ) + + # Different aparam must produce different energy + assert not np.allclose(e0, e1), ( + "Changing aparam did not change output — aparam may be ignored" + ) + + def test_eval_without_aparam_raises(self, spin_aparam_model_files, ext) -> None: + """Model with dim_aparam > 0 must raise when aparam not provided.""" + from deepmd.infer import ( + DeepPot, + ) + + files = spin_aparam_model_files + dp = DeepPot(files[ext]) + + with pytest.raises(ValueError, match="aparam is required"): + dp.eval(COORD, BOX, ATYPE, atomic=False, spin=SPIN) diff --git a/source/tests/pt_expt/test_dp_freeze.py b/source/tests/pt_expt/test_dp_freeze.py index add8313752..7c33f0de81 100644 --- a/source/tests/pt_expt/test_dp_freeze.py +++ b/source/tests/pt_expt/test_dp_freeze.py @@ -188,36 +188,15 @@ def test_freeze_pt2_nopbc_negative_coords(self) -> None: np.testing.assert_allclose(f_pte, f_pt2, atol=1e-10) np.testing.assert_allclose(v_pte, v_pt2, atol=1e-10) - -class TestDPFreezePt2DefaultFparam(unittest.TestCase): - """Test .pt2 with default fparam — eval without providing fparam.""" - - @classmethod - def setUpClass(cls) -> None: - cls.tmpdir = tempfile.mkdtemp() - - model_params = deepcopy(model_se_e2_a) - model_params["fitting_net"]["numb_fparam"] = 1 - model_params["fitting_net"]["default_fparam"] = [0.5] - model = get_model(model_params) - wrapper = ModelWrapper(model, model_params=model_params) - state_dict = wrapper.state_dict() - cls.ckpt_file = os.path.join(cls.tmpdir, "model_dfp.pt") - torch.save({"model": state_dict}, cls.ckpt_file) - - @classmethod - def tearDownClass(cls) -> None: - shutil.rmtree(cls.tmpdir) - - def test_pt2_eval_default_fparam(self) -> None: - """Eval .pt2 without fparam should match eval with explicit default value.""" + def test_nonspin_model_rejects_spin(self) -> None: + """Non-spin model must raise ValueError when spin is provided.""" import numpy as np from deepmd.infer import ( DeepPot, ) - pt2_path = os.path.join(self.tmpdir, "dfp.pt2") + pt2_path = os.path.join(self.tmpdir, "nonspin_reject.pt2") freeze(model=self.ckpt_file, output=pt2_path) coord = np.array( @@ -226,17 +205,11 @@ def test_pt2_eval_default_fparam(self) -> None: ) box = np.array([5.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 5.0], dtype=np.float64) atype = [0, 1, 2] + spin = np.zeros(9, dtype=np.float64) dp = DeepPot(pt2_path) - - # Eval WITHOUT fparam — model should use default (0.5) - e_no, f_no, v_no = dp.eval(coord, box, atype) - # Eval WITH explicit default value - e_ex, f_ex, v_ex = dp.eval(coord, box, atype, fparam=[0.5]) - - np.testing.assert_allclose(e_no, e_ex, atol=1e-10) - np.testing.assert_allclose(f_no, f_ex, atol=1e-10) - np.testing.assert_allclose(v_no, v_ex, atol=1e-10) + with self.assertRaises(ValueError): + dp.eval(coord, box, atype, spin=spin) if __name__ == "__main__":