From a6c1dc254f2361233ad187a5451e0a0a42b3e51f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:04:30 +0800 Subject: [PATCH 1/6] feat(dp, pt): add charge spin embedding --- .../dpmodel/atomic_model/dp_atomic_model.py | 20 ++++++ deepmd/dpmodel/descriptor/dpa1.py | 1 + deepmd/dpmodel/descriptor/dpa2.py | 1 + deepmd/dpmodel/descriptor/dpa3.py | 70 +++++++++++++++++++ deepmd/dpmodel/descriptor/hybrid.py | 1 + .../descriptor/make_base_descriptor.py | 1 + deepmd/dpmodel/descriptor/se_e2_a.py | 1 + deepmd/dpmodel/descriptor/se_r.py | 1 + deepmd/dpmodel/descriptor/se_t.py | 1 + deepmd/dpmodel/descriptor/se_t_tebd.py | 1 + .../pt/model/atomic_model/dp_atomic_model.py | 15 ++++ deepmd/pt/model/descriptor/dpa1.py | 1 + deepmd/pt/model/descriptor/dpa2.py | 1 + deepmd/pt/model/descriptor/dpa3.py | 63 +++++++++++++++++ deepmd/pt/model/descriptor/hybrid.py | 1 + deepmd/pt/model/descriptor/se_a.py | 1 + deepmd/pt/model/descriptor/se_r.py | 1 + deepmd/pt/model/descriptor/se_t.py | 1 + deepmd/pt/model/descriptor/se_t_tebd.py | 1 + deepmd/utils/argcheck.py | 12 ++++ source/tests/consistent/descriptor/common.py | 45 ++++++++++-- .../tests/consistent/descriptor/test_dpa3.py | 26 ++++++- source/tests/pt/model/test_dpa3.py | 21 +++++- .../dpmodel/descriptor/test_descriptor.py | 5 +- .../universal/dpmodel/model/test_model.py | 8 +++ 25 files changed, 290 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 73447de955..752545626a 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -73,6 +73,9 @@ def __init__( if hasattr(self.fitting_net, "reinit_exclude"): self.fitting_net.reinit_exclude(self.atom_exclude_types) self.type_map = type_map + self.add_chg_spin_ebd: bool = getattr( + self.descriptor, "add_chg_spin_ebd", False + ) super().init_out_stat() def fitting_output_def(self) -> FittingOutputDef: @@ -179,11 +182,28 @@ def forward_atomic( """ nframes, nloc, nnei = nlist.shape atype = extended_atype[:, :nloc] + + if self.fitting_net.get_dim_fparam() > 0 and fparam is None: + # use default fparam + from deepmd.dpmodel.array_api import ( + array_api_compat, + ) + + default_fparam = self.fitting_net.get_default_fparam() + assert default_fparam is not None + xp = array_api_compat.array_namespace(extended_coord) + fparam_input_for_des = xp.tile( + xp.reshape(default_fparam, (1, -1)), (nframes, 1) + ) + else: + fparam_input_for_des = fparam + descriptor, rot_mat, g2, h2, sw = self.descriptor( extended_coord, extended_atype, nlist, mapping=mapping, + fparam=fparam_input_for_des if self.add_chg_spin_ebd else None, ) ret = self.fitting_net( descriptor, diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 34dcba6335..b93987bedf 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -491,6 +491,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> Array: """Compute the descriptor. diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 5ac636c37c..1db09258c6 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -826,6 +826,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> tuple[Array, Array, Array, Array, Array]: """Compute the descriptor. diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index e385ae5dda..ca6e50a1af 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -20,6 +20,7 @@ ) from deepmd.dpmodel.utils.network import ( NativeLayer, + get_activation_fn, ) from deepmd.dpmodel.utils.seed import ( child_seed, @@ -354,6 +355,7 @@ def __init__( use_tebd_bias: bool = False, use_loc_mapping: bool = True, type_map: list[str] | None = None, + add_chg_spin_ebd: bool = False, ) -> None: super().__init__() @@ -408,6 +410,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: ) self.use_econf_tebd = use_econf_tebd + self.add_chg_spin_ebd = add_chg_spin_ebd self.use_tebd_bias = use_tebd_bias self.use_loc_mapping = use_loc_mapping self.type_map = type_map @@ -426,6 +429,38 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: ) self.concat_output_tebd = concat_output_tebd self.precision = precision + + if self.add_chg_spin_ebd: + self.cs_activation_fn = get_activation_fn(activation_function) + # -100 ~ 100 is a conservative bound + self.chg_embedding = TypeEmbedNet( + ntypes=200, + neuron=[self.tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + seed=child_seed(seed, 3), + ) + # 100 is a conservative upper bound + self.spin_embedding = TypeEmbedNet( + ntypes=100, + neuron=[self.tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + seed=child_seed(seed, 4), + ) + self.mix_cs_mlp = NativeLayer( + 2 * self.tebd_dim, + self.tebd_dim, + precision=precision, + seed=child_seed(seed, 3), + ) + else: + self.chg_embedding = None + self.spin_embedding = None + self.mix_cs_mlp = None + self.exclude_types = exclude_types self.env_protection = env_protection self.trainable = trainable @@ -577,6 +612,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> tuple[Array, Array, Array, Array, Array]: """Compute the descriptor. @@ -623,6 +659,27 @@ def call( xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0), (nframes, nall, self.tebd_dim), ) + + if self.add_chg_spin_ebd: + assert fparam is not None + assert self.chg_embedding is not None + assert self.spin_embedding is not None + chg_tebd = self.chg_embedding.call() + spin_tebd = self.spin_embedding.call() + charge = xp.astype(fparam[:, 0], xp.int64) + 100 + spin = xp.astype(fparam[:, 1], xp.int64) + chg_ebd = xp.reshape( + xp.take(chg_tebd, xp.reshape(charge, (-1,)), axis=0), + (nframes, self.tebd_dim), + ) + spin_ebd = xp.reshape( + xp.take(spin_tebd, xp.reshape(spin, (-1,)), axis=0), + (nframes, self.tebd_dim), + ) + cs_cat = xp.concat([chg_ebd, spin_ebd], axis=-1) + sys_cs_embd = self.cs_activation_fn(self.mix_cs_mlp.call(cs_cat)) + node_ebd_ext = node_ebd_ext + xp.expand_dims(sys_cs_embd, axis=1) + node_ebd_inp = node_ebd_ext[:, :nloc, :] # repflows node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows( @@ -653,9 +710,14 @@ def serialize(self) -> dict: "use_econf_tebd": self.use_econf_tebd, "use_tebd_bias": self.use_tebd_bias, "use_loc_mapping": self.use_loc_mapping, + "add_chg_spin_ebd": self.add_chg_spin_ebd, "type_map": self.type_map, "type_embedding": self.type_embedding.serialize(), } + if self.add_chg_spin_ebd: + data["chg_embedding"] = self.chg_embedding.serialize() + data["spin_embedding"] = self.spin_embedding.serialize() + data["mix_cs_mlp"] = self.mix_cs_mlp.serialize() repflow_variable = { "edge_embd": repflows.edge_embd.serialize(), "angle_embd": repflows.angle_embd.serialize(), @@ -682,10 +744,18 @@ def deserialize(cls, data: dict) -> "DescrptDPA3": data.pop("type") repflow_variable = data.pop("repflow_variable").copy() type_embedding = data.pop("type_embedding") + chg_embedding = data.pop("chg_embedding", None) + spin_embedding = data.pop("spin_embedding", None) + mix_cs_mlp = data.pop("mix_cs_mlp", None) data["repflow"] = RepFlowArgs(**data.pop("repflow_args")) obj = cls(**data) obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + if obj.add_chg_spin_ebd and chg_embedding is not None: + obj.chg_embedding = TypeEmbedNet.deserialize(chg_embedding) + obj.spin_embedding = TypeEmbedNet.deserialize(spin_embedding) + obj.mix_cs_mlp = NativeLayer.deserialize(mix_cs_mlp) + # deserialize repflow statistic_repflows = repflow_variable.pop("@variables") env_mat = repflow_variable.pop("env_mat") diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 4279f0bfcd..1b0065a232 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -275,6 +275,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> tuple[ Array, Array | None, diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index f87ca2c5b6..c3a9e61d16 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -186,6 +186,7 @@ def fwd( extended_atype: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> Array: """Calculate descriptor.""" pass diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 4710987f54..abd9fc5c33 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -395,6 +395,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> Array: """Compute the descriptor. diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 5ea9ef525f..e978c6afde 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -367,6 +367,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> Array: """Compute the descriptor. diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 7877a1e9ab..ab6d7aefd0 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -342,6 +342,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> tuple[Array, Array]: """Compute the descriptor. diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 994fa63b30..a50c016049 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -350,6 +350,7 @@ def call( atype_ext: Array, nlist: Array, mapping: Array | None = None, + fparam: Array | None = None, ) -> tuple[Array, Array]: """Compute the descriptor. diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 78fa0c3cf7..3f841ad03d 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -67,6 +67,9 @@ def __init__( if hasattr(self.fitting_net, "reinit_exclude"): self.fitting_net.reinit_exclude(self.atom_exclude_types) super().init_out_stat() + self.add_chg_spin_ebd: bool = getattr( + self.descriptor, "add_chg_spin_ebd", False + ) self.enable_eval_descriptor_hook = False self.enable_eval_fitting_last_layer_hook = False self.eval_descriptor_list = [] @@ -270,12 +273,24 @@ def forward_atomic( atype = extended_atype[:, :nloc] if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) + + if self.fitting_net.get_dim_fparam() > 0 and fparam is None: + # use default fparam + default_fparam_tensor = self.fitting_net.get_default_fparam() + assert default_fparam_tensor is not None + fparam_input_for_des = torch.tile( + default_fparam_tensor.unsqueeze(0), [nframes, 1] + ) + else: + fparam_input_for_des = fparam + descriptor, rot_mat, g2, h2, sw = self.descriptor( extended_coord, extended_atype, nlist, mapping=mapping, comm_dict=comm_dict, + fparam=fparam_input_for_des if self.add_chg_spin_ebd else None, ) assert descriptor is not None if self.enable_eval_descriptor_hook: diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 359bf2f084..df5f8297b9 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -666,6 +666,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 30d1987b97..6516cff8bf 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -715,6 +715,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 136527123e..0ae41cd03c 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -32,6 +32,7 @@ UpdateSel, ) from deepmd.pt.utils.utils import ( + ActivationFn, to_numpy_array, ) from deepmd.utils.data_system import ( @@ -120,6 +121,7 @@ def __init__( use_tebd_bias: bool = False, use_loc_mapping: bool = True, type_map: list[str] | None = None, + add_chg_spin_ebd: bool = False, ) -> None: super().__init__() @@ -174,6 +176,7 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any: ) self.use_econf_tebd = use_econf_tebd + self.add_chg_spin_ebd = add_chg_spin_ebd self.use_loc_mapping = use_loc_mapping self.use_tebd_bias = use_tebd_bias self.type_map = type_map @@ -191,6 +194,34 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any: self.concat_output_tebd = concat_output_tebd self.precision = precision self.prec = PRECISION_DICT[self.precision] + + if self.add_chg_spin_ebd: + self.act = ActivationFn(activation_function) + # -100 ~ 100 is a conservative bound + self.chg_embedding = TypeEmbedNet( + 200, + self.tebd_dim, + precision=precision, + seed=child_seed(seed, 3), + ) + # 100 is a conservative upper bound + self.spin_embedding = TypeEmbedNet( + 100, + self.tebd_dim, + precision=precision, + seed=child_seed(seed, 4), + ) + self.mix_cs_mlp = MLPLayer( + 2 * self.tebd_dim, + self.tebd_dim, + precision=precision, + seed=child_seed(seed, 3), + ) + else: + self.chg_embedding = None + self.spin_embedding = None + self.mix_cs_mlp = None + self.exclude_types = exclude_types self.env_protection = env_protection self.trainable = trainable @@ -395,9 +426,14 @@ def serialize(self) -> dict: "use_econf_tebd": self.use_econf_tebd, "use_tebd_bias": self.use_tebd_bias, "use_loc_mapping": self.use_loc_mapping, + "add_chg_spin_ebd": self.add_chg_spin_ebd, "type_map": self.type_map, "type_embedding": self.type_embedding.embedding.serialize(), } + if self.add_chg_spin_ebd: + data["chg_embedding"] = self.chg_embedding.embedding.serialize() + data["spin_embedding"] = self.spin_embedding.embedding.serialize() + data["mix_cs_mlp"] = self.mix_cs_mlp.serialize() repflow_variable = { "edge_embd": repflows.edge_embd.serialize(), "angle_embd": repflows.angle_embd.serialize(), @@ -424,12 +460,24 @@ def deserialize(cls, data: dict) -> "DescrptDPA3": data.pop("type") repflow_variable = data.pop("repflow_variable").copy() type_embedding = data.pop("type_embedding") + chg_embedding = data.pop("chg_embedding", None) + spin_embedding = data.pop("spin_embedding", None) + mix_cs_mlp = data.pop("mix_cs_mlp", None) data["repflow"] = RepFlowArgs(**data.pop("repflow_args")) obj = cls(**data) obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( type_embedding ) + if obj.add_chg_spin_ebd and chg_embedding is not None: + obj.chg_embedding.embedding = TypeEmbedNetConsistent.deserialize( + chg_embedding + ) + obj.spin_embedding.embedding = TypeEmbedNetConsistent.deserialize( + spin_embedding + ) + obj.mix_cs_mlp = MLPLayer.deserialize(mix_cs_mlp) + def t_cvt(xx: Any) -> torch.Tensor: return torch.tensor(xx, dtype=obj.repflows.prec, device=env.DEVICE) @@ -455,6 +503,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -504,6 +553,20 @@ def forward( node_ebd_ext = self.type_embedding(extended_atype[:, :nloc]) else: node_ebd_ext = self.type_embedding(extended_atype) + + if self.add_chg_spin_ebd: + assert fparam is not None + assert self.chg_embedding is not None + assert self.spin_embedding is not None + charge = fparam[:, 0].to(dtype=torch.int64) + 100 + spin = fparam[:, 1].to(dtype=torch.int64) + chg_ebd = self.chg_embedding(charge) + spin_ebd = self.spin_embedding(spin) + sys_cs_embd = self.act( + self.mix_cs_mlp(torch.cat((chg_ebd, spin_ebd), dim=-1)) + ) + node_ebd_ext = node_ebd_ext + sys_cs_embd.unsqueeze(1) + node_ebd_inp = node_ebd_ext[:, :nloc, :] # repflows node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows( diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 0f001bc4c8..dc738563d3 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -268,6 +268,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 59c165ddb0..2d4aa40d05 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -308,6 +308,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 1d5a8fc1a8..0cf2e5b344 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -427,6 +427,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 9771d8fe6f..6f8533fb7c 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -343,6 +343,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index d639cc94bc..2cc3a08312 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -423,6 +423,7 @@ def forward( nlist: torch.Tensor, mapping: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d1331a711a..9ca53310a3 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1366,6 +1366,11 @@ def descrpt_dpa3_args() -> list[Argument]: doc_concat_output_tebd = ( "Whether to concat type embedding at the output of the descriptor." ) + doc_add_chg_spin_ebd = ( + "Whether to add charge and spin embedding to the descriptor. " + "When enabled, fparam is expected to have 2 values (charge, spin) " + "which are embedded and added to the type embedding." + ) doc_activation_function = f"The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." @@ -1389,6 +1394,13 @@ def descrpt_dpa3_args() -> list[Argument]: default=False, doc=doc_concat_output_tebd, ), + Argument( + "add_chg_spin_ebd", + bool, + optional=True, + default=False, + doc=doc_add_chg_spin_ebd, + ), Argument( "activation_function", str, diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index e82fb0dda8..73433420c7 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -102,6 +102,7 @@ def eval_dp_descriptor( atype: np.ndarray, box: np.ndarray, mixed_types: bool = False, + fparam: np.ndarray | None = None, ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts( coords.reshape(1, -1, 3), @@ -117,7 +118,9 @@ def eval_dp_descriptor( dp_obj.get_sel(), distinguish_types=(not mixed_types), ) - return dp_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + return dp_obj( + ext_coords, ext_atype, nlist=nlist, mapping=mapping, fparam=fparam + ) def eval_pt_descriptor( self, @@ -127,6 +130,7 @@ def eval_pt_descriptor( atype: np.ndarray, box: np.ndarray, mixed_types: bool = False, + fparam: np.ndarray | None = None, ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), @@ -142,9 +146,14 @@ def eval_pt_descriptor( pt_obj.get_sel(), distinguish_types=(not mixed_types), ) + fparam_pt = ( + torch.from_numpy(fparam).to(PT_DEVICE) if fparam is not None else None + ) return [ x.detach().cpu().numpy() if torch.is_tensor(x) else x - for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + for x in pt_obj( + ext_coords, ext_atype, nlist=nlist, mapping=mapping, fparam=fparam_pt + ) ] def eval_pt_expt_descriptor( @@ -155,6 +164,7 @@ def eval_pt_expt_descriptor( atype: np.ndarray, box: np.ndarray, mixed_types: bool = False, + fparam: np.ndarray | None = None, ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), @@ -170,9 +180,14 @@ def eval_pt_expt_descriptor( pt_expt_obj.get_sel(), distinguish_types=(not mixed_types), ) + fparam_pt = ( + torch.from_numpy(fparam).to(PT_DEVICE) if fparam is not None else None + ) return [ x.detach().cpu().numpy() if torch.is_tensor(x) else x - for x in pt_expt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + for x in pt_expt_obj( + ext_coords, ext_atype, nlist=nlist, mapping=mapping, fparam=fparam_pt + ) ] def eval_jax_descriptor( @@ -183,6 +198,7 @@ def eval_jax_descriptor( atype: np.ndarray, box: np.ndarray, mixed_types: bool = False, + fparam: np.ndarray | None = None, ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts( jnp.array(coords).reshape(1, -1, 3), @@ -198,9 +214,12 @@ def eval_jax_descriptor( jax_obj.get_sel(), distinguish_types=(not mixed_types), ) + fparam_jax = jnp.array(fparam) if fparam is not None else None return [ np.asarray(x) if isinstance(x, jnp.ndarray) else x - for x in jax_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + for x in jax_obj( + ext_coords, ext_atype, nlist=nlist, mapping=mapping, fparam=fparam_jax + ) ] def eval_pd_descriptor( @@ -211,6 +230,7 @@ def eval_pd_descriptor( atype: np.ndarray, box: np.ndarray, mixed_types: bool = False, + fparam: np.ndarray | None = None, ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pd( paddle.to_tensor(coords).to(PD_DEVICE).reshape([1, -1, 3]), @@ -226,9 +246,14 @@ def eval_pd_descriptor( pd_obj.get_sel(), distinguish_types=(not mixed_types), ) + fparam_pd = ( + paddle.to_tensor(fparam).to(PD_DEVICE) if fparam is not None else None + ) return [ x.detach().cpu().numpy() if paddle.is_tensor(x) else x - for x in pd_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + for x in pd_obj( + ext_coords, ext_atype, nlist=nlist, mapping=mapping, fparam=fparam_pd + ) ] def eval_array_api_strict_descriptor( @@ -239,6 +264,7 @@ def eval_array_api_strict_descriptor( atype: np.ndarray, box: np.ndarray, mixed_types: bool = False, + fparam: np.ndarray | None = None, ) -> Any: ext_coords, ext_atype, mapping = extend_coord_with_ghosts( array_api_strict.asarray(coords.reshape(1, -1, 3)), @@ -254,10 +280,17 @@ def eval_array_api_strict_descriptor( array_api_strict_obj.get_sel(), distinguish_types=(not mixed_types), ) + fparam_array_api = ( + array_api_strict.asarray(fparam) if fparam is not None else None + ) return [ to_numpy_array(x) if hasattr(x, "__array_namespace__") else x for x in array_api_strict_obj( - ext_coords, ext_atype, nlist=nlist, mapping=mapping + ext_coords, + ext_atype, + nlist=nlist, + mapping=mapping, + fparam=fparam_array_api, ) ] diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index 65d471fb99..bca0759f5c 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -75,9 +75,10 @@ (True, False), # use_exp_switch (True, False), # use_dynamic_sel (True, False), # use_loc_mapping - (0.3, 0.0), # fix_stat_std + (0.3,), # fix_stat_std (1,), # n_multi_edge_message ("float64",), # precision + (False, True), # add_chg_spin_ebd ) class TestDPA3(CommonTest, DescriptorTest, unittest.TestCase): @property @@ -97,6 +98,7 @@ def data(self) -> dict: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param return { "ntypes": self.ntypes, @@ -137,6 +139,7 @@ def data(self) -> dict: "env_protection": 0.0, "use_loc_mapping": use_loc_mapping, "trainable": False, + "add_chg_spin_ebd": add_chg_spin_ebd, } @property @@ -156,6 +159,7 @@ def skip_pt(self) -> bool: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param return CommonTest.skip_pt @@ -176,8 +180,9 @@ def skip_pd(self) -> bool: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param - return CommonTest.skip_pd + return True if add_chg_spin_ebd else CommonTest.skip_pd @property def skip_dp(self) -> bool: @@ -196,6 +201,7 @@ def skip_dp(self) -> bool: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param return CommonTest.skip_dp @@ -216,6 +222,7 @@ def skip_tf(self) -> bool: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param return True @@ -280,7 +287,14 @@ def setUp(self) -> None: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param + # fparam for charge=5, spin=1 when add_chg_spin_ebd is True + self.fparam = ( + np.array([[5, 1]], dtype=GLOBAL_NP_FLOAT_PRECISION) + if add_chg_spin_ebd + else None + ) def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: return self.build_tf_descriptor( @@ -300,6 +314,7 @@ def eval_dp(self, dp_obj: Any) -> Any: self.atype, self.box, mixed_types=True, + fparam=self.fparam, ) def eval_pt(self, pt_obj: Any) -> Any: @@ -310,6 +325,7 @@ def eval_pt(self, pt_obj: Any) -> Any: self.atype, self.box, mixed_types=True, + fparam=self.fparam, ) def eval_pd(self, pd_obj: Any) -> Any: @@ -320,6 +336,7 @@ def eval_pd(self, pd_obj: Any) -> Any: self.atype, self.box, mixed_types=True, + fparam=self.fparam, ) def eval_jax(self, jax_obj: Any) -> Any: @@ -330,6 +347,7 @@ def eval_jax(self, jax_obj: Any) -> Any: self.atype, self.box, mixed_types=True, + fparam=self.fparam, ) def eval_pt_expt(self, pt_expt_obj: Any) -> Any: @@ -340,6 +358,7 @@ def eval_pt_expt(self, pt_expt_obj: Any) -> Any: self.atype, self.box, mixed_types=True, + fparam=self.fparam, ) def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: @@ -350,6 +369,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: self.atype, self.box, mixed_types=True, + fparam=self.fparam, ) def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: @@ -373,6 +393,7 @@ def rtol(self) -> float: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param if precision == "float64": return 1e-10 @@ -399,6 +420,7 @@ def atol(self) -> float: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py index c2a53ba18d..12b0be4532 100644 --- a/source/tests/pt/model/test_dpa3.py +++ b/source/tests/pt/model/test_dpa3.py @@ -55,6 +55,7 @@ def test_consistency( nme, prec, ect, + add_chg_spin, ) in itertools.product( [True, False], # update_angle ["res_residual"], # update_style @@ -65,6 +66,7 @@ def test_consistency( [1, 2], # n_multi_edge_message ["float64"], # precision [False], # use_econf_tebd + [False, True], # add_chg_spin_ebd ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -102,16 +104,28 @@ def test_consistency( precision=prec, use_econf_tebd=ect, type_map=["O", "H"] if ect else None, + add_chg_spin_ebd=add_chg_spin, seed=GLOBAL_SEED, ).to(env.DEVICE) dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + + # Prepare fparam if needed + fparam = None + fparam_np = None + if add_chg_spin: + fparam = torch.tensor([[5, 1]], dtype=dtype, device=env.DEVICE).expand( + nf, -1 + ) + fparam_np = np.array([[5, 1]], dtype=np.float64).repeat(nf, axis=0) + rd0, _, _, _, _ = dd0( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + fparam=fparam, ) # serialization dd1 = DescrptDPA3.deserialize(dd0.serialize()) @@ -120,6 +134,7 @@ def test_consistency( torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + fparam=fparam, ) np.testing.assert_allclose( rd0.detach().cpu().numpy(), @@ -130,7 +145,11 @@ def test_consistency( # dp impl dd2 = DPDescrptDPA3.deserialize(dd0.serialize()) rd2, _, _, _, _ = dd2.call( - self.coord_ext, self.atype_ext, self.nlist, self.mapping + self.coord_ext, + self.atype_ext, + self.nlist, + self.mapping, + fparam=fparam_np, ) np.testing.assert_allclose( rd0.detach().cpu().numpy(), diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 4e5281e2e2..6f5a337b4d 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -488,6 +488,7 @@ def DescriptorParamDPA3( use_dynamic_sel=False, precision="float64", use_loc_mapping=True, + add_chg_spin_ebd=False, ): input_dict = { # kwargs for repformer @@ -535,6 +536,7 @@ def DescriptorParamDPA3( "use_econf_tebd": False, "use_tebd_bias": False, "use_loc_mapping": use_loc_mapping, + "add_chg_spin_ebd": add_chg_spin_ebd, "type_map": type_map, "seed": GLOBAL_SEED, } @@ -547,7 +549,7 @@ def DescriptorParamDPA3( { "update_residual_init": ("const",), "exclude_types": ([], [[0, 1]]), - "update_angle": (True, False), + "update_angle": (True,), "a_compress_rate": (1,), "a_compress_e_rate": (2,), "a_compress_use_split": (True,), @@ -561,6 +563,7 @@ def DescriptorParamDPA3( "env_protection": (0.0, 1e-8), "precision": ("float64",), "use_loc_mapping": (True, False), + "add_chg_spin_ebd": (False, True), } ), ) diff --git a/source/tests/universal/dpmodel/model/test_model.py b/source/tests/universal/dpmodel/model/test_model.py index c82074c601..b5c6bd82ee 100644 --- a/source/tests/universal/dpmodel/model/test_model.py +++ b/source/tests/universal/dpmodel/model/test_model.py @@ -62,6 +62,14 @@ def skip_model_tests(test_obj): + if test_obj.input_dict_ds.get("add_chg_spin_ebd", False): + import inspect + + (FittingParam, _) = test_obj.param[1] + sig = inspect.signature(FittingParam) + numb_param = sig.parameters.get("numb_param") + if numb_param is None or numb_param.default != 2: + return True, "add_chg_spin_ebd requires numb_fparam=2" if not test_obj.input_dict_ds.get( "smooth_type_embedding", True ) or not test_obj.input_dict_ds.get("smooth", True): From 623269f4809804aeadc181a913dfb87a7878c89a Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:58:53 +0800 Subject: [PATCH 2/6] fix ut --- deepmd/dpmodel/atomic_model/dp_atomic_model.py | 12 ++++++++++-- deepmd/pt/model/atomic_model/dp_atomic_model.py | 7 ++++++- source/tests/consistent/descriptor/common.py | 8 ++++---- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 994121d5f7..6d1ec50b32 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -184,7 +184,12 @@ def forward_atomic( nframes, nloc, nnei = nlist.shape atype = xp_take_first_n(extended_atype, 1, nloc) - if self.fitting_net.get_dim_fparam() > 0 and fparam is None: + # Handle default fparam if fitting net supports it + if ( + hasattr(self.fitting_net, "get_dim_fparam") + and self.fitting_net.get_dim_fparam() > 0 + and fparam is None + ): # use default fparam from deepmd.dpmodel.array_api import ( array_api_compat, @@ -193,8 +198,11 @@ def forward_atomic( default_fparam = self.fitting_net.get_default_fparam() assert default_fparam is not None xp = array_api_compat.array_namespace(extended_coord) + default_fparam_array = xp.asarray( + default_fparam, dtype=extended_coord.dtype + ) fparam_input_for_des = xp.tile( - xp.reshape(default_fparam, (1, -1)), (nframes, 1) + xp.reshape(default_fparam_array, (1, -1)), (nframes, 1) ) else: fparam_input_for_des = fparam diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index a586cc56bd..62f9b6042d 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -274,7 +274,12 @@ def forward_atomic( if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) - if self.fitting_net.get_dim_fparam() > 0 and fparam is None: + # Handle default fparam if fitting net supports it + if ( + hasattr(self.fitting_net, "get_dim_fparam") + and self.fitting_net.get_dim_fparam() > 0 + and fparam is None + ): # use default fparam default_fparam_tensor = self.fitting_net.get_default_fparam() assert default_fparam_tensor is not None diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 73433420c7..33bf7312de 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -246,13 +246,13 @@ def eval_pd_descriptor( pd_obj.get_sel(), distinguish_types=(not mixed_types), ) - fparam_pd = ( - paddle.to_tensor(fparam).to(PD_DEVICE) if fparam is not None else None - ) return [ x.detach().cpu().numpy() if paddle.is_tensor(x) else x for x in pd_obj( - ext_coords, ext_atype, nlist=nlist, mapping=mapping, fparam=fparam_pd + ext_coords, + ext_atype, + nlist=nlist, + mapping=mapping, ) ] From 7a965360d0bb26e6726fbb94a20ed43afb012104 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 6 Mar 2026 22:22:27 +0800 Subject: [PATCH 3/6] fix ut --- deepmd/jax/descriptor/dpa3.py | 11 +++++++++-- source/tests/array_api_strict/descriptor/dpa3.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/deepmd/jax/descriptor/dpa3.py b/deepmd/jax/descriptor/dpa3.py index 9f734bd553..226acc48db 100644 --- a/deepmd/jax/descriptor/dpa3.py +++ b/deepmd/jax/descriptor/dpa3.py @@ -23,6 +23,9 @@ flax_version, nnx, ) +from deepmd.jax.utils.network import ( + NativeLayer, +) from deepmd.jax.utils.type_embed import ( TypeEmbedNet, ) @@ -40,8 +43,12 @@ def __setattr__(self, name: str, value: Any) -> None: value = nnx.data(value) elif name in {"repflows"}: value = DescrptBlockRepflows.deserialize(value.serialize()) - elif name in {"type_embedding"}: - value = TypeEmbedNet.deserialize(value.serialize()) + elif name in {"type_embedding", "chg_embedding", "spin_embedding"}: + if value is not None: + value = TypeEmbedNet.deserialize(value.serialize()) + elif name in {"mix_cs_mlp"}: + if value is not None: + value = NativeLayer.deserialize(value.serialize()) else: pass return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/dpa3.py b/source/tests/array_api_strict/descriptor/dpa3.py index 19071ee59c..0086713e93 100644 --- a/source/tests/array_api_strict/descriptor/dpa3.py +++ b/source/tests/array_api_strict/descriptor/dpa3.py @@ -8,6 +8,9 @@ from ..common import ( to_array_api_strict_array, ) +from ..utils.network import ( + NativeLayer, +) from ..utils.type_embed import ( TypeEmbedNet, ) @@ -26,8 +29,12 @@ def __setattr__(self, name: str, value: Any) -> None: value = to_array_api_strict_array(value) elif name in {"repflows"}: value = DescrptBlockRepflows.deserialize(value.serialize()) - elif name in {"type_embedding"}: - value = TypeEmbedNet.deserialize(value.serialize()) + elif name in {"type_embedding", "chg_embedding", "spin_embedding"}: + if value is not None: + value = TypeEmbedNet.deserialize(value.serialize()) + elif name in {"mix_cs_mlp"}: + if value is not None: + value = NativeLayer.deserialize(value.serialize()) else: pass return super().__setattr__(name, value) From 6cd1076909ed38bf15690329cb0825191fa50f89 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sat, 7 Mar 2026 16:13:13 +0800 Subject: [PATCH 4/6] fix device --- deepmd/dpmodel/atomic_model/dp_atomic_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 6d1ec50b32..466e3ddd95 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -199,7 +199,9 @@ def forward_atomic( assert default_fparam is not None xp = array_api_compat.array_namespace(extended_coord) default_fparam_array = xp.asarray( - default_fparam, dtype=extended_coord.dtype + default_fparam, + dtype=extended_coord.dtype, + device=array_api_compat.device(extended_coord), ) fparam_input_for_des = xp.tile( xp.reshape(default_fparam_array, (1, -1)), (nframes, 1) From 8979f8138966108e6d2aba273bba3b1e28b14e56 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 11 Mar 2026 19:58:39 +0800 Subject: [PATCH 5/6] add ut --- source/tests/consistent/model/test_ener.py | 140 +++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 038b428fd9..4c914bff41 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -1859,3 +1859,143 @@ def raise_error(): # 5. Cross-backend consistency after loading compare_variables_recursive(dp_ser_loaded, pt_ser_loaded) compare_variables_recursive(dp_ser_loaded, pe_ser_loaded) + + +@parameterized( + ("no_fparam", "explicit_fparam", "default_fparam"), # fparam_mode +) +@unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") +class TestEnerChgSpinEbdFparam(unittest.TestCase): + """Test dp/pt/pt_expt model forward consistency for add_chg_spin_ebd with three fparam modes. + + - no_fparam: numb_fparam=0, add_chg_spin_ebd=False (baseline) + - explicit_fparam: numb_fparam=2, add_chg_spin_ebd=True, fparam provided + - default_fparam: numb_fparam=2, default_fparam set, add_chg_spin_ebd=True, fparam=None + """ + + def setUp(self) -> None: + (self.fparam_mode,) = self.param + + add_chg_spin_ebd = self.fparam_mode != "no_fparam" + fitting_cfg: dict[str, Any] = { + "neuron": [10, 10], + "precision": "float64", + "seed": 1, + } + if self.fparam_mode != "no_fparam": + fitting_cfg["numb_fparam"] = 2 + if self.fparam_mode == "default_fparam": + fitting_cfg["default_fparam"] = [5, 1] + + data = model_args().normalize_value( + { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 20, + "e_dim": 10, + "a_dim": 8, + "nlayers": 3, + "e_rcut": 6.0, + "e_rcut_smth": 5.0, + "e_sel": 10, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 8, + "axis_neuron": 4, + "update_angle": True, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + }, + "precision": "float64", + "seed": 1, + "add_chg_spin_ebd": add_chg_spin_ebd, + }, + "fitting_net": fitting_cfg, + }, + trim_pattern="_*", + ) + + self.dp_model = get_model_dp(data) + serialized = self.dp_model.serialize() + self.pt_model = EnergyModelPT.deserialize(serialized) + self.pt_expt_model = EnergyModelPTExpt.deserialize(serialized) + + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 0.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + + # fparam: charge=5, spin=1 + if self.fparam_mode == "explicit_fparam": + self.fparam_np = np.array([[5, 1]], dtype=GLOBAL_NP_FLOAT_PRECISION) + else: + self.fparam_np = None + + def test_forward_consistency(self) -> None: + dp_ret = self.dp_model( + self.coords, self.atype, box=self.box, fparam=self.fparam_np + ) + pt_ret = { + kk: torch_to_numpy(vv) + for kk, vv in self.pt_model( + numpy_to_torch(self.coords), + numpy_to_torch(self.atype), + box=numpy_to_torch(self.box), + fparam=numpy_to_torch(self.fparam_np), + do_atomic_virial=True, + ).items() + } + coord_t = pt_expt_numpy_to_torch(self.coords) + coord_t.requires_grad_(True) + pe_ret = { + k: v.detach().cpu().numpy() + for k, v in self.pt_expt_model( + coord_t, + pt_expt_numpy_to_torch(self.atype), + box=pt_expt_numpy_to_torch(self.box), + fparam=pt_expt_numpy_to_torch(self.fparam_np), + do_atomic_virial=True, + ).items() + } + for key in ("energy", "atom_energy"): + np.testing.assert_allclose( + dp_ret[key], + pt_ret[key], + rtol=1e-10, + atol=1e-10, + err_msg=f"dp vs pt mismatch in {key} (mode={self.fparam_mode})", + ) + np.testing.assert_allclose( + dp_ret[key], + pe_ret[key], + rtol=1e-10, + atol=1e-10, + err_msg=f"dp vs pt_expt mismatch in {key} (mode={self.fparam_mode})", + ) From 8a7bf3a30a883429216fb0ed8dff230e72ce1bfb Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 13 Mar 2026 03:17:09 +0800 Subject: [PATCH 6/6] fix seed --- deepmd/dpmodel/descriptor/dpa3.py | 2 +- deepmd/pt/model/descriptor/dpa3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 5ffcae5004..8222ee2d81 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -456,7 +456,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: 2 * self.tebd_dim, self.tebd_dim, precision=precision, - seed=child_seed(seed, 3), + seed=child_seed(seed, 5), ) else: self.chg_embedding = None diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 0ae41cd03c..0c6982afe5 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -215,7 +215,7 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any: 2 * self.tebd_dim, self.tebd_dim, precision=precision, - seed=child_seed(seed, 3), + seed=child_seed(seed, 5), ) else: self.chg_embedding = None