diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 1120078bb2..debddba6e7 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -231,6 +231,7 @@ def forward_common_atomic( mapping: Array | None = None, fparam: Array | None = None, aparam: Array | None = None, + comm_dict: dict | None = None, ) -> dict[str, Array]: """Common interface for atomic inference. @@ -252,6 +253,9 @@ def forward_common_atomic( frame parameters, shape: nf x dim_fparam aparam atomic parameter, shape: nf x nloc x dim_aparam + comm_dict + MPI communication metadata for parallel inference. ``None`` for + non-parallel inference (default). Returns ------- @@ -279,6 +283,7 @@ def forward_common_atomic( mapping=mapping, fparam=fparam, aparam=aparam, + comm_dict=comm_dict, ) ret_dict = self.apply_out_stat(ret_dict, atype) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 466e3ddd95..0505f63d83 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -157,6 +157,7 @@ def forward_atomic( mapping: Array | None = None, fparam: Array | None = None, aparam: Array | None = None, + comm_dict: dict | None = None, ) -> dict[str, Array]: """Models' atomic predictions. @@ -174,6 +175,9 @@ def forward_atomic( frame parameter. nf x ndf aparam atomic parameter. nf x nloc x nda + comm_dict + MPI communication metadata for parallel inference. ``None`` for + non-parallel inference (default). Forwarded to the descriptor. 
Returns ------- @@ -215,6 +219,7 @@ def forward_atomic( nlist, mapping=mapping, fparam=fparam_input_for_des if self.add_chg_spin_ebd else None, + comm_dict=comm_dict, ) ret = self.fitting_net( descriptor, diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 3ed9077df7..05ff8499f8 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -224,6 +224,7 @@ def forward_atomic( mapping: Array | None = None, fparam: Array | None = None, aparam: Array | None = None, + comm_dict: dict | None = None, ) -> dict[str, Array]: """Return atomic prediction. @@ -241,6 +242,10 @@ def forward_atomic( frame parameter. (nframes, ndf) aparam atomic parameter. (nframes, nloc, nda) + comm_dict + MPI communication metadata. Forwarded to each sub-model so GNN + sub-descriptors can perform parallel ghost exchange. ``None`` for + non-parallel inference (default). Returns ------- @@ -280,6 +285,7 @@ def forward_atomic( mapping, fparam, aparam, + comm_dict, )["energy"] ) weights = self._compute_weight(extended_coord, extended_atype, nlists_) diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 51c370eca0..c1ec9d2a00 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -253,7 +253,9 @@ def forward_atomic( mapping: Array | None = None, fparam: Array | None = None, aparam: Array | None = None, + comm_dict: dict | None = None, ) -> dict[str, Array]: + del comm_dict # pairtab is local; no MPI ghost exchange needed. 
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) nframes, nloc, nnei = nlist.shape extended_coord = xp.reshape(extended_coord, (nframes, -1, 3)) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index bc2a04a836..04d0420009 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -397,6 +397,14 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.se_atten.has_message_passing() + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange. + + DPA1 (se_atten) is single-layer and does not exchange features + across ranks; same as the base se_e2_a path. + """ + return False + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return self.se_atten.need_sorted_nlist_for_lower() @@ -500,6 +508,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> Array: """Compute the descriptor. diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 2fa765f04b..e530398ca6 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -687,6 +687,16 @@ def has_message_passing(self) -> bool: [self.repinit.has_message_passing(), self.repformers.has_message_passing()] ) + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange. + + DPA2's repformers always passes ``g1`` in ``[nb, nall, n_dim]`` + layout (no ``use_loc_mapping`` opt-out exists at the block level), + so multi-rank deployment always needs cross-rank exchange of + per-atom features between layers. 
+ """ + return self.repformers.has_message_passing_across_ranks() + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return True @@ -831,6 +841,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> tuple[Array, Array, Array, Array, Array]: """Compute the descriptor. @@ -844,6 +855,11 @@ def call( The neighbor list. shape: nf x nloc x nnei mapping The index mapping, maps extended region index to local region. + comm_dict + MPI communication metadata for parallel inference. Forwarded to + the repformer block (the message-passing part). The repinit + sub-block does no message passing and does not receive it. + ``None`` for non-parallel inference (default). Returns ------- @@ -912,9 +928,18 @@ def call( assert self.tebd_transform is not None g1 = g1 + self.tebd_transform(g1_inp) # mapping g1 - assert mapping is not None - mapping_ext = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1])) - g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1) + if comm_dict is None: + # non-parallel: gather g1 -> g1_ext via mapping, hand the + # nall-sized embedding to the repformer block. + assert mapping is not None + mapping_ext = xp.tile( + xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1]) + ) + g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1) + else: + # parallel mode: hand the local-only g1 to the repformer block; + # its per-layer override fills ghosts via the MPI exchange. 
+ g1_ext = g1 # repformer g1, g2, h2, rot_mat, sw = self.repformers( nlist_dict[ @@ -926,6 +951,7 @@ def call( atype_ext, g1_ext, mapping, + comm_dict=comm_dict, ) if self.concat_output_tebd: g1 = xp.concat([g1, g1_inp], axis=-1) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 5f5aea50e5..c1d9531357 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -527,6 +527,17 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.repflows.has_message_passing() + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange. + + Delegates to repflows: ``False`` when ``use_loc_mapping=True`` + (per-layer messages stay within each rank's local atoms), + ``True`` when ``use_loc_mapping=False`` (ghost slots in + ``[nb, nall, n_dim]`` layout must be filled by cross-rank + exchange before each layer). + """ + return self.repflows.has_message_passing_across_ranks() + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return True @@ -616,6 +627,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> tuple[Array, Array, Array, Array, Array]: """Compute the descriptor. @@ -629,6 +641,9 @@ def call( The neighbor list. shape: nf x nloc x nnei mapping The index mapping, mapps extended region index to local region. + comm_dict + MPI communication metadata for parallel inference. Forwarded to + the repflows block. ``None`` for non-parallel inference (default). 
Returns ------- @@ -695,6 +710,7 @@ def call( atype_ext, node_ebd_ext, mapping, + comm_dict=comm_dict, ) if self.concat_output_tebd: node_ebd = xp.concat([node_ebd, node_ebd_inp], axis=-1) diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index b15fbc15d2..a51220c5e2 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -168,6 +168,16 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return any(descrpt.has_message_passing() for descrpt in self.descrpt_list) + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange. + + ``True`` if any child descriptor needs cross-rank message passing + (e.g. a hybrid wrapping a DPA3 with ``use_loc_mapping=False``). + """ + return any( + descrpt.has_message_passing_across_ranks() for descrpt in self.descrpt_list + ) + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return True @@ -276,6 +286,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> tuple[ Array, Array | None, @@ -332,7 +343,9 @@ def call( # mixed_types is True, but descrpt.mixed_types is False assert nl_distinguish_types is not None nl = nl_distinguish_types[:, :, nci] - odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping) + odescriptor, gr, _g2, _h2, _sw = descrpt( + coord_ext, atype_ext, nl, mapping, comm_dict=comm_dict + ) out_descriptor.append(odescriptor) if gr is not None: out_gr.append(gr) diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 47245898ce..8184b4e42a 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -107,6 +107,24 @@ def mixed_types(self) -> bool: 
def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" + def has_message_passing_across_ranks(self) -> bool: + """Returns whether the descriptor's message passing extends across rank + boundaries — i.e. whether it requires cross-rank exchange of intermediate + atomic features (per-layer node embeddings) during the forward pass. + + Distinct from generic ghost-coord/force exchange that every LAMMPS + pair_style does. This question gates whether the pt_expt backend + compiles a second "with-comm" AOTI artifact for multi-rank deployment. + + Concrete default ``False`` (non-GNN behavior) so pt and pd backend + descriptors that subclass ``BaseDescriptor`` directly do not have + to implement this method until they grow a multi-rank GNN path of + their own. GNN descriptors that need MPI ghost-feature exchange + (DPA2, DPA3 with ``use_loc_mapping=False``, hybrids wrapping such + children) override to return ``True``. + """ + return False + @abstractmethod def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 30637dc75a..bc94b877ea 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -506,6 +506,32 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def _exchange_ghosts( + self, + node_ebd: Array, + mapping_tiled: Array | None, + comm_dict: dict | None, + nall: int, + nloc: int, + ) -> Array: + """Build node_ebd_ext (the ghost-aware embedding) for the per-layer loop. + + Default: array-api gather via the pre-tiled `mapping_tiled`, or pass the + local-only `node_ebd` through when ``self.use_loc_mapping`` is set. 
+ ``comm_dict``, ``nall``, ``nloc`` are unused in this default impl; they + exist so the pt_expt subclass can perform the per-layer MPI ghost + exchange (``deepmd_export::border_op``) when ``comm_dict is not None``. + """ + del comm_dict, nall, nloc + if self.use_loc_mapping: + return node_ebd + if mapping_tiled is None: + raise ValueError( + "`mapping` is required when use_loc_mapping=False unless " + "`_exchange_ghosts` is overridden for parallel comm handling." + ) + return xp_take_along_axis(node_ebd, mapping_tiled, axis=1) + def call( self, nlist: Array, @@ -514,6 +540,7 @@ def call( atype_embd_ext: Array | None = None, mapping: Array | None = None, type_embedding: Array | None = None, + comm_dict: dict | None = None, ) -> tuple[Array, Array, Array, Array, Array]: xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) nframes, nloc, nnei = nlist.shape @@ -641,15 +668,24 @@ def call( # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim angle_ebd = self.angle_embd(angle_input) - # nb x nall x n_dim - mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim)) + # nb x nall x n_dim (pre-tiled mapping reused across layers when not + # using comm_dict). Skip the tile when mapping is None — pt_expt's + # parallel-mode override consults comm_dict instead. 
+ mapping_tiled = ( + xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim)) + if mapping is not None + else None + ) for idx, ll in enumerate(self.layers): # node_ebd: nb x nloc x n_dim - # node_ebd_ext: nb x nall x n_dim - node_ebd_ext = ( - node_ebd - if self.use_loc_mapping - else xp_take_along_axis(node_ebd, mapping, axis=1) + # node_ebd_ext: nb x nall x n_dim (or nb x nloc x n_dim when + # use_loc_mapping=True) + node_ebd_ext = self._exchange_ghosts( + node_ebd, + mapping_tiled, + comm_dict, + nall, + nloc, ) node_ebd, edge_ebd, angle_ebd = ll.call( node_ebd_ext, @@ -696,6 +732,16 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return True + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange. + + Repflows passes ``node_ebd`` either in ``[nb, nloc, n_dim]`` layout + (``use_loc_mapping=True``: messages stay within the rank's local atoms) + or ``[nb, nall, n_dim]`` layout (``use_loc_mapping=False``: ghost slots + must be filled by cross-rank exchange before each layer). + """ + return not self.use_loc_mapping + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return True diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 5881b3a0b3..799ab0c3c3 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -480,6 +480,32 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def _exchange_ghosts( + self, + g1: Array, + mapping_tiled: Array | None, + comm_dict: dict | None, + nall: int, + nloc: int, + ) -> Array: + """Build g1_ext (the ghost-aware single-atom embedding) for the + per-layer loop. + + Default: array-api gather via the pre-tiled ``mapping_tiled``. 
+ ``comm_dict``, ``nall``, ``nloc`` are unused in this default impl; + they exist so the pt_expt subclass can perform the per-layer MPI + ghost exchange (``deepmd_export::border_op``) when ``comm_dict is + not None``. + """ + del comm_dict, nall, nloc + if mapping_tiled is None: + raise ValueError( + "`mapping` is required by the default `_exchange_ghosts` " + "implementation; pass a valid mapping or override the method " + "for parallel comm handling." + ) + return xp_take_along_axis(g1, mapping_tiled, axis=1) + def call( self, nlist: Array, @@ -488,6 +514,7 @@ def call( atype_embd_ext: Array | None = None, mapping: Array | None = None, type_embedding: Array | None = None, + comm_dict: dict | None = None, ) -> Array: xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) @@ -524,12 +551,27 @@ def call( # set all padding positions to index of 0 # if a neighbor is real or not is indicated by nlist_mask nlist = xp.where(nlist == -1, xp.zeros_like(nlist), nlist) - # nf x nall x ng1 - mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.g1_dim)) + # nall computed for the pt_expt parallel-mode override (uses nall to + # size the pad before MPI ghost exchange). dpmodel default ignores it. + nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3 + # nf x nall x ng1 (pre-tiled mapping reused across layers when not + # using comm_dict). Skip the tile when mapping is None — pt_expt's + # parallel-mode override consults comm_dict instead. 
+ mapping_tiled = ( + xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.g1_dim)) + if mapping is not None + else None + ) for idx, ll in enumerate(self.layers): # g1: nf x nloc x ng1 # g1_ext: nf x nall x ng1 - g1_ext = xp_take_along_axis(g1, mapping, axis=1) + g1_ext = self._exchange_ghosts( + g1, + mapping_tiled, + comm_dict, + nall, + nloc, + ) g1, g2, h2 = ll.call( g1_ext, g2, @@ -558,6 +600,15 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return True + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer g1 needs MPI ghost exchange. + + Repformers has no ``use_loc_mapping`` opt-out; it always passes + ``g1`` in ``[nb, nall, n_dim]`` layout, so multi-rank always needs + cross-rank exchange of the per-atom feature tensor. + """ + return True + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return False diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 8997412325..f72b6f75e8 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -278,6 +278,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return False + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange.""" + return False + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return False @@ -399,6 +403,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> Array: """Compute the descriptor. 
diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index b5ba7a282f..6846710735 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -257,6 +257,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return False + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange.""" + return False + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return False @@ -371,6 +375,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> Array: """Compute the descriptor. diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index e599669068..2d61736235 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -249,6 +249,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return False + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange.""" + return False + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return False @@ -346,6 +350,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> tuple[Array, Array]: """Compute the descriptor. 
diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 2d36994d61..2f6e749e19 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -255,6 +255,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.se_ttebd.has_message_passing() + def has_message_passing_across_ranks(self) -> bool: + """Returns whether per-layer node embeddings need MPI ghost exchange.""" + return self.se_ttebd.has_message_passing_across_ranks() + def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" return self.se_ttebd.need_sorted_nlist_for_lower() @@ -354,6 +358,7 @@ def call( nlist: Array, mapping: Array | None = None, fparam: Array | None = None, + comm_dict: dict | None = None, ) -> tuple[Array, Array]: """Compute the descriptor. diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index fb77838b4c..d9617e981a 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -326,6 +326,7 @@ def call_common_lower( aparam: Array | None = None, do_atomic_virial: bool = False, extended_coord_corr: Array | None = None, + comm_dict: dict | None = None, ) -> dict[str, Array]: """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -351,6 +352,11 @@ def call_common_lower( extended_coord_corr coordinates correction for virial in extended region. nf x (nall x 3) + comm_dict + MPI communication metadata for parallel inference (e.g. + LAMMPS multi-rank). Carries send/recv lists, processor IDs, + the MPI communicator handle, and per-rank nlocal/nghost. + ``None`` for non-parallel inference (default). 
Returns ------- @@ -379,6 +385,7 @@ def call_common_lower( aparam=ap, do_atomic_virial=do_atomic_virial, extended_coord_corr=extended_coord_corr, + comm_dict=comm_dict, ) model_predict = self._output_type_cast(model_predict, input_prec) return model_predict @@ -393,6 +400,7 @@ def forward_common_atomic( aparam: Array | None = None, do_atomic_virial: bool = False, extended_coord_corr: Array | None = None, + comm_dict: dict | None = None, ) -> dict[str, Array]: atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, @@ -401,6 +409,7 @@ def forward_common_atomic( mapping=mapping, fparam=fparam, aparam=aparam, + comm_dict=comm_dict, ) return fit_output_to_model_output( atomic_ret, diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index be6566e303..2de41945f3 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -748,6 +748,7 @@ def call_common_lower( fparam: Array | None = None, aparam: Array | None = None, do_atomic_virial: bool = False, + comm_dict: dict | None = None, ) -> dict[str, Array]: """Return model prediction with raw internal keys. 
Lower interface that takes extended atomic coordinates, types and spins, nlist, and mapping @@ -800,6 +801,7 @@ def call_common_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, extended_coord_corr=extended_coord_corr, + comm_dict=comm_dict, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py index 7227839f1f..319b8e94a2 100644 --- a/deepmd/jax/atomic_model/dp_atomic_model.py +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -57,7 +57,9 @@ def forward_common_atomic( mapping: jnp.ndarray | None = None, fparam: jnp.ndarray | None = None, aparam: jnp.ndarray | None = None, + comm_dict: dict | None = None, ) -> dict[str, jnp.ndarray]: + del comm_dict # JAX path has no MPI ghost exchange return super().forward_common_atomic( extended_coord, extended_atype, diff --git a/deepmd/jax/atomic_model/linear_atomic_model.py b/deepmd/jax/atomic_model/linear_atomic_model.py index 1c183db7ac..ecfc74cf95 100644 --- a/deepmd/jax/atomic_model/linear_atomic_model.py +++ b/deepmd/jax/atomic_model/linear_atomic_model.py @@ -61,7 +61,9 @@ def forward_common_atomic( mapping: jnp.ndarray | None = None, fparam: jnp.ndarray | None = None, aparam: jnp.ndarray | None = None, + comm_dict: dict | None = None, ) -> dict[str, jnp.ndarray]: + del comm_dict # JAX path has no MPI ghost exchange return super().forward_common_atomic( extended_coord, extended_atype, diff --git a/deepmd/jax/atomic_model/pairtab_atomic_model.py b/deepmd/jax/atomic_model/pairtab_atomic_model.py index 7f18a6403c..0117bf1d2c 100644 --- a/deepmd/jax/atomic_model/pairtab_atomic_model.py +++ b/deepmd/jax/atomic_model/pairtab_atomic_model.py @@ -46,7 +46,9 @@ def forward_common_atomic( mapping: jnp.ndarray | None = None, fparam: jnp.ndarray | None = None, aparam: jnp.ndarray | None = None, + comm_dict: dict | None = None, ) -> dict[str, jnp.ndarray]: + del comm_dict # JAX path has 
no MPI ghost exchange return super().forward_common_atomic( extended_coord, extended_atype, diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 4522e25586..f99fccd276 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -26,7 +26,9 @@ def forward_common_atomic( aparam: jnp.ndarray | None = None, do_atomic_virial: bool = False, extended_coord_corr: jnp.ndarray | None = None, + comm_dict: dict | None = None, ) -> dict[str, jnp.ndarray]: + del comm_dict # JAX path has no MPI ghost exchange atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, extended_atype, diff --git a/deepmd/jax/model/dp_model.py b/deepmd/jax/model/dp_model.py index 3e96eb6689..55239bb608 100644 --- a/deepmd/jax/model/dp_model.py +++ b/deepmd/jax/model/dp_model.py @@ -56,7 +56,9 @@ def forward_common_atomic( aparam: jnp.ndarray | None = None, do_atomic_virial: bool = False, extended_coord_corr: jnp.ndarray | None = None, + comm_dict: dict | None = None, ) -> dict[str, jnp.ndarray]: + del comm_dict # JAX path has no MPI ghost exchange return forward_common_atomic( self, extended_coord, diff --git a/deepmd/jax/model/dp_zbl_model.py b/deepmd/jax/model/dp_zbl_model.py index 7751d22a1f..f2aa68ea1f 100644 --- a/deepmd/jax/model/dp_zbl_model.py +++ b/deepmd/jax/model/dp_zbl_model.py @@ -38,7 +38,9 @@ def forward_common_atomic( aparam: jnp.ndarray | None = None, do_atomic_virial: bool = False, extended_coord_corr: jnp.ndarray | None = None, + comm_dict: dict | None = None, ) -> dict[str, jnp.ndarray]: + del comm_dict # JAX path has no MPI ghost exchange return forward_common_atomic( self, extended_coord, diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 1667182d84..8253ed6338 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later # Import to register converters -from . 
import se_t_tebd_block # noqa: F401 +from . import ( # noqa: F401 + repflows, + repformers, + se_t_tebd_block, +) from .base_descriptor import ( BaseDescriptor, ) diff --git a/deepmd/pt_expt/descriptor/dpa1.py b/deepmd/pt_expt/descriptor/dpa1.py index 01df91abd6..c43b07f9c2 100644 --- a/deepmd/pt_expt/descriptor/dpa1.py +++ b/deepmd/pt_expt/descriptor/dpa1.py @@ -183,6 +183,7 @@ def call( nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, + comm_dict: dict | None = None, ) -> Any: if not self.compress: return DescrptDPA1DP.call.__wrapped__( diff --git a/deepmd/pt_expt/descriptor/dpa2.py b/deepmd/pt_expt/descriptor/dpa2.py index 1723df5a30..21c392cd3c 100644 --- a/deepmd/pt_expt/descriptor/dpa2.py +++ b/deepmd/pt_expt/descriptor/dpa2.py @@ -233,11 +233,19 @@ def call( nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, + comm_dict: dict | None = None, ) -> Any: if not self.compress: return DescrptDPA2DP.call.__wrapped__( - self, coord_ext, atype_ext, nlist, mapping + self, + coord_ext, + atype_ext, + nlist, + mapping, + fparam, + comm_dict=comm_dict, ) + # Compressed path is local-only (no message passing during compress). return self._call_compressed(coord_ext, atype_ext, nlist, mapping) def _call_compressed( diff --git a/deepmd/pt_expt/descriptor/repflows.py b/deepmd/pt_expt/descriptor/repflows.py new file mode 100644 index 0000000000..dacab9f464 --- /dev/null +++ b/deepmd/pt_expt/descriptor/repflows.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""pt_expt wrapper around dpmodel ``DescrptBlockRepflows``. + +The wrapper overrides ``_exchange_ghosts`` so that, when running under +LAMMPS multi-rank with a non-None ``comm_dict``, each layer of the +RepFlow message-passing block exchanges ghost-atom embeddings via the +opaque ``deepmd_export::border_op`` wrapper (registered in +``deepmd/pt_expt/utils/comm.py``). 
This survives ``torch.export`` and +AOTInductor packaging. + +When ``comm_dict is None`` (single-rank inference / training), the +default array-api ``_exchange_ghosts`` from the dpmodel block is used — +zero behavioural change. +""" + +from __future__ import ( + annotations, +) + +import torch + +from deepmd.dpmodel.descriptor.repflows import ( + DescrptBlockRepflows as DescrptBlockRepflowsDP, +) +from deepmd.pt.utils.spin import ( + concat_switch_virtual, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + + +@torch_module +class DescrptBlockRepflows(DescrptBlockRepflowsDP): + """pt_expt wrapper for the RepFlow descriptor block.""" + + def _exchange_ghosts( + self, + node_ebd: torch.Tensor, + mapping_tiled: torch.Tensor | None, + comm_dict: dict | None, + nall: int, + nloc: int, + ) -> torch.Tensor: + if comm_dict is None: + return super()._exchange_ghosts( + node_ebd, + mapping_tiled, + comm_dict, + nall, + nloc, + ) + # Pt's parallel branch (repflows.py:580-587) requires the + # extended-region pathway (use_loc_mapping=False). The + # local-mapping codepath skips the per-layer ghost exchange + # entirely, so combining it with comm_dict is contradictory. + # Surface this as a clear error rather than producing silently + # wrong results. + if getattr(self, "use_loc_mapping", False): + raise RuntimeError( + "DescrptBlockRepflows._exchange_ghosts: comm_dict is " + "set but use_loc_mapping=True. Multi-rank parallel " + "inference requires use_loc_mapping=False so per-layer " + "ghost exchange is meaningful." + ) + # The squeeze(0) / unsqueeze(0) dance below assumes a single + # frame. LAMMPS always feeds nb=1 in production; refuse loudly + # if a Python caller batches frames so the mismatch surfaces + # here rather than as a malformed border_op tensor downstream. + if node_ebd.shape[0] != 1: + raise RuntimeError( + "DescrptBlockRepflows._exchange_ghosts: comm_dict path " + "only supports nf=1 (got nf=" + f"{node_ebd.shape[0]}). 
Multi-frame batching with " + "comm_dict is not supported." + ) + + has_spin = "has_spin" in comm_dict + if has_spin: + real_nloc, real_nall = nloc // 2, nall // 2 + real_pad = real_nall - real_nloc + node_real, node_virt = torch.split( + node_ebd, + [real_nloc, real_nloc], + dim=1, + ) + # combine real + virtual along feature dim, then pad to nall. + mix = torch.cat([node_real, node_virt], dim=2) + padded = torch.nn.functional.pad( + mix.squeeze(0), + (0, 0, 0, real_pad), + value=0.0, + ) + else: + padded = torch.nn.functional.pad( + node_ebd.squeeze(0), + (0, 0, 0, nall - nloc), + value=0.0, + ) + + exchanged = torch.ops.deepmd_export.border_op( + comm_dict["send_list"], + comm_dict["send_proc"], + comm_dict["recv_proc"], + comm_dict["send_num"], + comm_dict["recv_num"], + padded, + comm_dict["communicator"], + comm_dict["nlocal"], + comm_dict["nghost"], + ).unsqueeze(0) + + if has_spin: + n_dim = node_ebd.shape[-1] + real_ext, virt_ext = torch.split(exchanged, [n_dim, n_dim], dim=2) + return concat_switch_virtual(real_ext, virt_ext, real_nloc) + return exchanged + + +# Register the converter so dpmodel's auto-wrap path picks up our pt_expt +# subclass instead of the generic _auto_wrap_native_op fallback. Without +# this, the override above would never fire. +register_dpmodel_mapping( + DescrptBlockRepflowsDP, + lambda v: DescrptBlockRepflows.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/descriptor/repformers.py b/deepmd/pt_expt/descriptor/repformers.py new file mode 100644 index 0000000000..9b8ddb4a85 --- /dev/null +++ b/deepmd/pt_expt/descriptor/repformers.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""pt_expt wrapper around dpmodel ``DescrptBlockRepformers``. + +Mirrors ``deepmd/pt_expt/descriptor/repflows.py``: overrides +``_exchange_ghosts`` so the per-layer ghost exchange uses the opaque +``deepmd_export::border_op`` when a ``comm_dict`` is provided. 
+""" + +from __future__ import ( + annotations, +) + +import torch + +from deepmd.dpmodel.descriptor.repformers import ( + DescrptBlockRepformers as DescrptBlockRepformersDP, +) +from deepmd.pt.utils.spin import ( + concat_switch_virtual, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + + +@torch_module +class DescrptBlockRepformers(DescrptBlockRepformersDP): + """pt_expt wrapper for the Repformers descriptor block.""" + + def _exchange_ghosts( + self, + g1: torch.Tensor, + mapping_tiled: torch.Tensor | None, + comm_dict: dict | None, + nall: int, + nloc: int, + ) -> torch.Tensor: + if comm_dict is None: + return super()._exchange_ghosts( + g1, + mapping_tiled, + comm_dict, + nall, + nloc, + ) + # The squeeze(0) / unsqueeze(0) dance below assumes a single + # frame. LAMMPS always feeds nb=1 in production; refuse loudly + # if a Python caller batches frames so the mismatch surfaces + # here rather than as a malformed border_op tensor downstream. + if g1.shape[0] != 1: + raise RuntimeError( + "DescrptBlockRepformers._exchange_ghosts: comm_dict path " + "only supports nf=1 (got nf=" + f"{g1.shape[0]}). Multi-frame batching with comm_dict is " + "not supported." 
+ ) + + has_spin = "has_spin" in comm_dict + if has_spin: + real_nloc, real_nall = nloc // 2, nall // 2 + real_pad = real_nall - real_nloc + g1_real, g1_virt = torch.split(g1, [real_nloc, real_nloc], dim=1) + mix = torch.cat([g1_real, g1_virt], dim=2) + padded = torch.nn.functional.pad( + mix.squeeze(0), + (0, 0, 0, real_pad), + value=0.0, + ) + else: + padded = torch.nn.functional.pad( + g1.squeeze(0), + (0, 0, 0, nall - nloc), + value=0.0, + ) + + exchanged = torch.ops.deepmd_export.border_op( + comm_dict["send_list"], + comm_dict["send_proc"], + comm_dict["recv_proc"], + comm_dict["send_num"], + comm_dict["recv_num"], + padded, + comm_dict["communicator"], + comm_dict["nlocal"], + comm_dict["nghost"], + ).unsqueeze(0) + + if has_spin: + ng1 = g1.shape[-1] + real_ext, virt_ext = torch.split(exchanged, [ng1, ng1], dim=2) + return concat_switch_virtual(real_ext, virt_ext, real_nloc) + return exchanged + + +register_dpmodel_mapping( + DescrptBlockRepformersDP, + lambda v: DescrptBlockRepformers.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 61d611036e..45120c6d5d 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -139,6 +139,7 @@ def call( nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, + comm_dict: dict | None = None, ) -> Any: if not self.compress: return DescrptSeADP.call.__wrapped__( diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 22302f54e6..ab32be1131 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -128,6 +128,7 @@ def call( nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, + comm_dict: dict | None = None, ) -> Any: if not self.compress: return DescrptSeRDP.call.__wrapped__( diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index 
061306f281..69d6183642 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -139,6 +139,7 @@ def call( nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, + comm_dict: dict | None = None, ) -> Any: if not self.compress: return DescrptSeTDP.call.__wrapped__( diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index c0ae308971..cbcaf3822c 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -166,6 +166,7 @@ def call( nlist: torch.Tensor, mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, + comm_dict: dict | None = None, ) -> Any: if not self.compress: return DescrptSeTTebdDP.call.__wrapped__( diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index b28b81ffb1..0ef1f8c0b7 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -280,6 +280,7 @@ def forward_common_atomic( aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, extended_coord_corr: torch.Tensor | None = None, + comm_dict: dict | None = None, ) -> dict[str, torch.Tensor]: atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, @@ -288,6 +289,7 @@ def forward_common_atomic( mapping=mapping, fparam=fparam, aparam=aparam, + comm_dict=comm_dict, ) model_ret = fit_output_to_model_output( atomic_ret, @@ -402,4 +404,109 @@ def fn( model.need_sorted_nlist_for_lower = _orig_need_sort return traced + def forward_common_lower_exportable_with_comm( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + send_list: torch.Tensor, + send_proc: torch.Tensor, + recv_proc: torch.Tensor, + send_num: torch.Tensor, + recv_num: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: 
torch.Tensor, + do_atomic_virial: bool = False, + **make_fx_kwargs: Any, + ) -> torch.nn.Module: + """Trace forward_common_lower with comm_dict tensors as positional inputs. + + Used to compile a parallel-inference variant of the model + (.pt2 with-comm artifact) that drives MPI ghost-atom exchange + for GNN descriptors via the opaque + ``deepmd_export::border_op`` wrapper. The comm tensors enter + the exported program as 8 additional positional inputs after + the usual (coord, atype, nlist, mapping, fparam, aparam) — + this fixes the C++ ABI for ``DeepPotPTExpt`` (Phase 4). + + Tracing requires ``nswap >= 1`` (Phase 0 finding); with + ``nswap == 0`` the dim specializes and the artifact would + only run for that exact value. The C++ caller must always + provide ``nswap >= 1``. + """ + model = self + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + send_list: torch.Tensor, + send_proc: torch.Tensor, + recv_proc: torch.Tensor, + send_num: torch.Tensor, + recv_num: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: torch.Tensor, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + # Same nnei-dynamic-axis workaround as the regular variant + # (see ``_pad_nlist_for_export``). Without it the with-comm + # trace specialises ``nnei`` to the sample width. 
+ nlist = _pad_nlist_for_export(nlist) + comm_dict = { + "send_list": send_list, + "send_proc": send_proc, + "recv_proc": recv_proc, + "send_num": send_num, + "recv_num": recv_num, + "communicator": communicator, + "nlocal": nlocal, + "nghost": nghost, + } + return model.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + ) + + # Force the sort branch in ``_format_nlist`` (mirrors the regular + # variant) so the compiled graph's ``nnei`` axis stays dynamic. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType( + lambda self: True, model + ) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + send_list, + send_proc, + recv_proc, + send_num, + recv_num, + communicator, + nlocal, + nghost, + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort + return traced + return CM diff --git a/deepmd/pt_expt/model/spin_model.py b/deepmd/pt_expt/model/spin_model.py index e69ee29f5a..707d46f70e 100644 --- a/deepmd/pt_expt/model/spin_model.py +++ b/deepmd/pt_expt/model/spin_model.py @@ -135,6 +135,114 @@ def fn( backbone.need_sorted_nlist_for_lower = _orig_need_sort return traced + def forward_common_lower_exportable_with_comm( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + send_list: torch.Tensor, + send_proc: torch.Tensor, + recv_proc: torch.Tensor, + send_num: torch.Tensor, + recv_num: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: torch.Tensor, + do_atomic_virial: bool = False, + **make_fx_kwargs: Any, + ) -> torch.nn.Module: + """Spin variant of ``forward_common_lower_exportable_with_comm``. 
+ + Mirrors the non-spin version (see ``make_model.py``) but threads + ``extended_spin`` through and injects ``has_spin`` into + ``comm_dict`` so the pt_expt Repflow/Repformer override takes + the spin branch (split real/virtual + concat_switch_virtual). + """ + model = self + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + send_list: torch.Tensor, + send_proc: torch.Tensor, + recv_proc: torch.Tensor, + send_num: torch.Tensor, + recv_num: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: torch.Tensor, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + # Same nnei-dynamic-axis workaround as the regular variant. + nlist = _pad_nlist_for_export(nlist) + comm_dict = { + "send_list": send_list, + "send_proc": send_proc, + "recv_proc": recv_proc, + "send_num": send_num, + "recv_num": recv_num, + "communicator": communicator, + "nlocal": nlocal, + "nghost": nghost, + # Trace-time marker so the override takes the spin path. + # Value is irrelevant — only key presence matters. + "has_spin": torch.tensor( + [1], + dtype=torch.int32, + device=extended_coord.device, + ), + } + return model.forward_common_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + ) + + # Force the sort branch in ``_format_nlist`` so the compiled + # graph's ``nnei`` axis stays dynamic (mirrors the regular + # spin variant; backbone-level override is required). 
+ backbone = self.backbone_model + _orig_need_sort = backbone.need_sorted_nlist_for_lower + backbone.need_sorted_nlist_for_lower = types.MethodType( + lambda self: True, backbone + ) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + send_list, + send_proc, + recv_proc, + send_num, + recv_num, + communicator, + nlocal, + nghost, + ) + finally: + backbone.need_sorted_nlist_for_lower = _orig_need_sort + return traced + def forward_common_lower( self, *args: Any, **kwargs: Any ) -> dict[str, torch.Tensor]: diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index efb026f7f1..99da68fe4f 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -22,7 +22,10 @@ # as it's a stateless utility class register_dpmodel_mapping(EnvMat, lambda v: v) +# Register opaque deepmd_export::border_op wrapper (used by GNN MPI +# parallel inference; see comm.py module docstring). # Register fake tensor implementations for custom tabulate ops +from deepmd.pt_expt.utils import comm # noqa: F401 from deepmd.pt_expt.utils import tabulate_ops # noqa: F401 __all__ = [ diff --git a/deepmd/pt_expt/utils/comm.py b/deepmd/pt_expt/utils/comm.py new file mode 100644 index 0000000000..434d2a97b0 --- /dev/null +++ b/deepmd/pt_expt/utils/comm.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Python-side fake / autograd registration for the C++-defined opaque +``deepmd_export::border_op`` and ``deepmd_export::border_op_backward``. + +The op schemas and concrete CPU/CUDA implementations are defined in +``source/op/pt/comm.cc`` (registered under explicit dispatch keys so +``torch.export`` records them as opaque external calls instead of +decomposing into the C++ kernel — which would hit ``data_ptr()`` on +FakeTensors and fail). 
Defining the schema in C++ also means a +``.pt2`` archive loaded by a pure-C++ process (LAMMPS via +``DeepPotPTExpt``) can dispatch through the registered op without +needing a Python interpreter. + +This module adds the Python-only metadata that the ops still need: + * ``register_fake`` so ``make_fx`` / ``torch.export`` can trace + through them with FakeTensor inputs. + * ``register_autograd`` so ``torch.autograd.grad`` (used inside + ``forward_common_lower_exportable_with_comm``) flows gradients + through the forward op back to its inputs. + +Constraints discovered during de-risking (scratch/derisk_border_op.py): + 1. Both forward and backward outputs must NOT alias their inputs + (the C++ kernels return the same tensor they modified) — the + C++ wrapper layer in ``comm.cc`` clones them before exposing. + 2. The fake impls honour ``g1.dtype`` (no float64 hardcoding). + 3. ``register_autograd`` makes the forward op differentiable; the + backward callback dispatches to the opaque + ``deepmd_export::border_op_backward`` op so ``make_fx`` tracing + through ``autograd.grad`` also sees a black box. +""" + +from __future__ import ( + annotations, +) + +import torch + + +def _check_underlying_ops_loaded() -> None: + """Surface a clearer error when libdeepmd_op_pt.so isn't loaded. + + pt_expt depends on libdeepmd_op_pt.so for the ``deepmd_export::*`` + op schemas + impls. Without it, the ops can't be registered for + fake/autograd metadata and callers get a cryptic AttributeError + on ``torch.ops.deepmd_export.border_op``. + + The .so is loaded as a side effect of ``import deepmd.pt`` (via + ``deepmd/pt/cxx_op.py``). We trigger that import here so callers + don't have to remember to do it first — important for environments + like DDP-spawned subprocesses that re-import modules from scratch + and never see the test conftest's ``import deepmd.pt``. 
+ """ + if not ( + hasattr(torch.ops, "deepmd_export") + and hasattr(torch.ops.deepmd_export, "border_op") + and hasattr(torch.ops.deepmd_export, "border_op_backward") + ): + # Triggers cxx_op.py which torch.ops.load_library's the .so. + try: + import deepmd.pt # noqa: F401 + except Exception: + # If deepmd.pt itself fails to import, fall through to the + # explicit RuntimeError below — clearer than re-raising a + # potentially-unrelated import error. + pass + + if not ( + hasattr(torch.ops, "deepmd_export") + and hasattr(torch.ops.deepmd_export, "border_op") + and hasattr(torch.ops.deepmd_export, "border_op_backward") + ): + raise RuntimeError( + "torch.ops.deepmd_export.{border_op,border_op_backward} " + "are not registered. Build libdeepmd_op_pt.so and ensure " + "deepmd.pt is importable before this module." + ) + + +_check_underlying_ops_loaded() + + +# --------------------------------------------------------------------------- +# Fake (meta) impls — let make_fx / torch.export trace through. 
+# --------------------------------------------------------------------------- + + +@torch.library.register_fake("deepmd_export::border_op") +def _border_op_fake( + sendlist: torch.Tensor, + sendproc: torch.Tensor, + recvproc: torch.Tensor, + sendnum: torch.Tensor, + recvnum: torch.Tensor, + g1: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(g1) + + +@torch.library.register_fake("deepmd_export::border_op_backward") +def _border_op_backward_fake( + sendlist: torch.Tensor, + sendproc: torch.Tensor, + recvproc: torch.Tensor, + sendnum: torch.Tensor, + recvnum: torch.Tensor, + grad_g1: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(grad_g1) + + +# --------------------------------------------------------------------------- +# Autograd: route the forward op's backward through the backward op so +# ``make_fx`` tracing through ``torch.autograd.grad`` records both as +# opaque external calls. 
+# --------------------------------------------------------------------------- + + +def _border_op_setup_context( + ctx: torch.autograd.function.FunctionCtx, + inputs: tuple, + output: torch.Tensor, +) -> None: + ( + sendlist, + sendproc, + recvproc, + sendnum, + recvnum, + _g1, + communicator, + nlocal, + nghost, + ) = inputs + ctx.save_for_backward( + sendlist, + sendproc, + recvproc, + sendnum, + recvnum, + communicator, + nlocal, + nghost, + ) + + +def _border_op_backward( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple: + (sendlist, sendproc, recvproc, sendnum, recvnum, communicator, nlocal, nghost) = ( + ctx.saved_tensors + ) + grad_in = torch.ops.deepmd_export.border_op_backward( + sendlist, + sendproc, + recvproc, + sendnum, + recvnum, + grad_output, + communicator, + nlocal, + nghost, + ) + return ( + None, + None, + None, + None, + None, # sendlist..recvnum + grad_in, # g1 + None, + None, + None, # communicator, nlocal, nghost + ) + + +torch.library.register_autograd( + "deepmd_export::border_op", + _border_op_backward, + setup_context=_border_op_setup_context, +) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 7b2559db4f..d85a334493 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import ctypes import json import numpy as np @@ -97,6 +98,87 @@ def _json_to_numpy(model_obj: dict) -> dict: ) +def _needs_with_comm_artifact(model: torch.nn.Module) -> bool: + """Return ``True`` if the model needs a "with-comm" AOTI artifact compiled. + + The with-comm artifact carries the per-layer ``deepmd_export::border_op`` + calls that exchange node-embedding tensors across MPI ranks. Multi-rank + LAMMPS dispatches to it when the descriptor's message passing extends + across rank boundaries (i.e. layers consume neighbour features that + live on a different rank). 
Non-GNN descriptors and GNN descriptors with + ``use_loc_mapping=True`` keep all per-layer messaging local to each + rank's owned atoms; they need only the regular artifact. + + Delegates to ``descriptor.has_message_passing_across_ranks()``, which + descriptor classes implement explicitly. Returns ``False`` defensively + when the model has no single descriptor (linear/zbl/frozen) or when + the method is somehow missing or raises. + """ + desc = getattr(getattr(model, "atomic_model", None), "descriptor", None) + if desc is None or not hasattr(desc, "has_message_passing_across_ranks"): + return False + try: + return bool(desc.has_message_passing_across_ranks()) + except (AttributeError, NotImplementedError): + return False + + +# Module-level cache for the trace-time sendlist buffer. The pointer +# value embedded in ``send_list_tensor`` references this numpy array's +# data; the array must outlive the trace + export call. Caching here +# (rather than per-call) is fine because the contents are never read by +# the exported graph at runtime — only by the eager call inside +# ``make_fx`` when extracting output keys, and by ``torch.export`` when +# materializing example inputs. +_TRACE_SENDLIST_KEEPALIVE: list[np.ndarray] = [] + + +def _make_comm_sample_inputs( + nloc: int, + nghost: int, + device: torch.device, +) -> tuple[torch.Tensor, ...]: + """Build trivial-but-valid comm tensors for tracing the with-comm variant. + + Phase 0 finding: tracing with ``nswap == 0`` causes the dim to + specialize, so we must use ``nswap >= 1``. We use ``nswap == 1`` + with a single self-send swap whose sendlist points to ``max(1, nghost)`` + local atoms (the actual indices don't matter for the trace — only + the validity of the pointer matters; ``border_op`` is opaque to + ``torch.export`` via the ``deepmd_export::border_op`` wrapper). 
+ + Returns ``(send_list, send_proc, recv_proc, send_num, recv_num, + communicator, nlocal_ts, nghost_ts)`` — 8 tensors, matching the + canonical positional order of + ``forward_common_lower_exportable_with_comm``. + """ + nswap = 1 + send_count = max(1, nghost) + # The trace-time sendlist must be a real ``int**``: a tensor of + # int64 values, each value the address of a contiguous int32 array. + indices = np.zeros(send_count, dtype=np.int32) + _TRACE_SENDLIST_KEEPALIVE.append(indices) + addr = indices.ctypes.data_as(ctypes.c_void_p).value + send_list = torch.tensor([addr], dtype=torch.int64, device=device) + send_proc = torch.zeros(nswap, dtype=torch.int32, device=device) + recv_proc = torch.zeros(nswap, dtype=torch.int32, device=device) + send_num = torch.tensor([send_count], dtype=torch.int32, device=device) + recv_num = torch.tensor([send_count], dtype=torch.int32, device=device) + communicator = torch.zeros(1, dtype=torch.int64, device=device) + nlocal_ts = torch.tensor(nloc, dtype=torch.int32, device=device) + nghost_ts = torch.tensor(nghost, dtype=torch.int32, device=device) + return ( + send_list, + send_proc, + recv_proc, + send_num, + recv_num, + communicator, + nlocal_ts, + nghost_ts, + ) + + def _make_sample_inputs( model: torch.nn.Module, nframes: int = 1, @@ -200,6 +282,7 @@ def _make_sample_inputs( def _build_dynamic_shapes( *sample_inputs: torch.Tensor | None, has_spin: bool = False, + with_comm_dict: bool = False, model_nnei: int = 1, ) -> tuple: """Build dynamic shape specifications for torch.export. @@ -207,19 +290,44 @@ def _build_dynamic_shapes( Marks nframes, nloc, nall and nnei as dynamic dimensions so the exported program handles arbitrary frame, atom and neighbor counts. + When ``with_comm_dict`` is True, 8 additional entries (one per comm + tensor) are appended to the returned tuple, in the positional order of + ``forward_common_lower_exportable_with_comm``. All 8 are ``None``: + every comm tensor, including the ``nswap`` axis, keeps a static shape. 
+ Parameters ---------- *sample_inputs : torch.Tensor | None - Sample inputs: either 6 tensors (non-spin) or 7 tensors (spin). + Sample inputs: 6 tensors (non-spin) or 7 (spin), optionally + followed by 8 comm tensors when ``with_comm_dict``. has_spin : bool Whether the inputs include an extended_spin tensor. + with_comm_dict : bool + Whether the inputs include the 8 comm tensors. model_nnei : int The model's sum(sel). Used as the min for the dynamic nnei dim. Returns a tuple (not dict) to match positional args of the make_fx traced module, whose arg names may have suffixes like ``_1``. """ - nframes_dim = torch.export.Dim("nframes", min=1) - nall_dim = torch.export.Dim("nall", min=1) + # When tracing the with-comm variant, nframes is static at 1. + # Rationale: pt_expt's Repflow/Repformer parallel-mode override + # mirrors pt's repflows.py:593 ``node_ebd.squeeze(0)`` / + # ``…unsqueeze(0)`` pattern, which only works for nb=1. LAMMPS + # always drives inference with one frame so this matches reality. + # Marking nframes static (not dynamic) means it does not + # participate in duck-sizing — so the nframes==2 collision-avoidance + # chosen for the regular variant is *not* needed here, and the + # static value (1) is safe regardless of other tensors' sizes. + nframes_dim: torch.export.Dim | int = ( + 1 if with_comm_dict else torch.export.Dim("nframes", min=1) + ) + # Spin models double atom count internally (real + virtual). Some + # GNN ops in the spin path generate a min=4 constraint on the + # *pre-doubling* nall axis (matches "Suggested fixes" from + # torch.export's CONSTRAINT_VIOLATION error). Bump the min for spin + # so the export does not error on the inferred guard. 
+ nall_min = 4 if has_spin else 1 + nall_dim = torch.export.Dim("nall", min=nall_min) nloc_dim = torch.export.Dim("nloc", min=1) nnei_dim = torch.export.Dim("nnei", min=max(1, model_nnei)) @@ -227,7 +335,7 @@ def _build_dynamic_shapes( # (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) fparam = sample_inputs[5] aparam = sample_inputs[6] - return ( + base = ( {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) {0: nframes_dim, 1: nall_dim}, # extended_spin: (nframes, nall, 3) @@ -244,7 +352,7 @@ def _build_dynamic_shapes( # (ext_coord, ext_atype, nlist, mapping, fparam, aparam) fparam = sample_inputs[4] aparam = sample_inputs[5] - return ( + base = ( {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) { @@ -257,6 +365,21 @@ def _build_dynamic_shapes( {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam ) + if not with_comm_dict: + return base + + # All 8 comm tensors have static shapes: + # send_list, send_proc, recv_proc, send_num, recv_num: (nswap,) + # communicator: (1,) + # nlocal, nghost: scalar + # nswap is fixed once at LAMMPS init (it depends on the processor + # grid which doesn't change at runtime), so it's safe to bake it + # in as static at the trace value. Marking nswap dynamic instead + # raises a Constraints-violated error because the trace specialises + # it to the sample value (1) downstream of border_op anyway — + # there is no graph variation across nswap values. + return (*base, None, None, None, None, None, None, None, None) + def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict: """Collect metadata from the model for C++ inference. 
@@ -315,6 +438,10 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict: if is_spin: meta["ntypes_spin"] = model.spin.get_ntypes_spin() meta["use_spin"] = [bool(v) for v in model.spin.use_spin] + # Whether multi-rank LAMMPS needs a second "with-comm" AOTI artifact + # (per-layer ghost-feature MPI exchange via deepmd_export::border_op). + # The C++ DeepPotPTExpt / DeepSpinPTExpt loaders branch on this flag. + meta["has_comm_artifact"] = _needs_with_comm_artifact(model) return meta @@ -423,11 +550,38 @@ def deserialize_to_file( def _trace_and_export( data: dict, model_json_override: dict | None = None, + with_comm_dict: bool = False, do_atomic_virial: bool = False, ) -> tuple: """Common logic: build model, trace, export. - Returns (exported, metadata, data_for_json, output_keys). + Parameters + ---------- + data + Serialized model dict (with "model" and optionally + "model_def_script" keys). + model_json_override + Optional alternate dict to embed as model.json (used by + ``dp compress`` to store the compressed model dict while + tracing the uncompressed one). + with_comm_dict + If True, trace ``forward_common_lower_exportable_with_comm`` + instead of the regular variant. The resulting exported program + accepts 8 additional positional comm tensors (``send_list``, + ``send_proc``, ``recv_proc``, ``send_num``, ``recv_num``, + ``communicator``, ``nlocal``, ``nghost``) used by the pt_expt + Repflow/Repformer override to drive MPI ghost-atom exchange. + Only valid for models that need cross-rank ghost-feature exchange + (see ``_needs_with_comm_artifact``). + do_atomic_virial + If True, the traced graph computes per-atom virial (extra + autograd.grad backward passes); off by default to keep .pt2 + inference fast. Mirrors PR #5407 in upstream master. + + Returns + ------- + tuple + ``(exported, metadata, data_for_json, output_keys)``. 
""" from copy import ( deepcopy, @@ -470,19 +624,37 @@ def _trace_and_export( _orig_device = _env.DEVICE _env.DEVICE = torch.device("cpu") try: - nframes = 2 - sample_inputs = _make_sample_inputs(model, nframes=nframes, has_spin=is_spin) - # Collect all dimension sizes except dim-0 (nframes) from every tensor - other_dims: set[int] = set() - for t in sample_inputs: - if t is not None: - other_dims.update(t.shape[1:]) - while nframes in other_dims: - nframes += 1 - if nframes != 2: + if with_comm_dict: + # The pt_expt parallel-mode override (in pt's repflows.py + # line 593 too) uses ``squeeze(0)`` / ``unsqueeze(0)`` on + # ``node_ebd`` and so requires ``nframes == 1``. LAMMPS + # always drives inference with one frame, so this is the + # only realistic shape — and we mark dim 0 static in + # ``_build_dynamic_shapes`` to match. + nframes = 1 sample_inputs = _make_sample_inputs( - model, nframes=nframes, has_spin=is_spin + model, + nframes=nframes, + has_spin=is_spin, ) + else: + nframes = 2 + sample_inputs = _make_sample_inputs( + model, + nframes=nframes, + has_spin=is_spin, + ) + # Collect all dimension sizes except dim-0 (nframes) from every tensor + other_dims: set[int] = set() + for t in sample_inputs: + if t is not None: + other_dims.update(t.shape[1:]) + while nframes in other_dims: + nframes += 1 + if nframes != 2: + sample_inputs = _make_sample_inputs( + model, nframes=nframes, has_spin=is_spin + ) finally: _env.DEVICE = _orig_device @@ -493,40 +665,89 @@ def _trace_and_export( else: ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = sample_inputs + # 3b. Build comm-tensor sample inputs when tracing the with-comm + # variant (only valid for GNN models). The actual values don't + # matter for tracing — only that they're valid tensors of the right + # shape and dtype. See ``_make_comm_sample_inputs``. 
+ if with_comm_dict: + if not _needs_with_comm_artifact(model): + raise ValueError( + "with_comm_dict=True requested but the model's descriptor " + "does not need cross-rank message passing " + "(has_message_passing_across_ranks() is False) — " + "there's nothing to compile." + ) + nloc_sample = nlist_t.shape[1] + nall_sample = ext_atype.shape[1] + nghost_sample = nall_sample - nloc_sample + comm_inputs = _make_comm_sample_inputs( + nloc=nloc_sample, + nghost=nghost_sample, + device=torch.device("cpu"), + ) + sample_inputs = sample_inputs + comm_inputs + # 4. Trace via make_fx on CPU. # This decomposes torch.autograd.grad into aten ops so the resulting # GraphModule no longer contains autograd calls. if is_spin: - traced = model.forward_common_lower_exportable( - ext_coord, - ext_atype, - ext_spin, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - tracing_mode="symbolic", - _allow_non_fake_inputs=True, - ) + if with_comm_dict: + traced = model.forward_common_lower_exportable_with_comm( + ext_coord, + ext_atype, + ext_spin, + nlist_t, + mapping_t, + fparam, + aparam, + *comm_inputs, + do_atomic_virial=do_atomic_virial, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + else: + traced = model.forward_common_lower_exportable( + ext_coord, + ext_atype, + ext_spin, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) # 5. Extract output keys from the CPU-traced module. 
- sample_out = traced( - ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam - ) + sample_out = traced(*sample_inputs) else: - traced = model.forward_common_lower_exportable( - ext_coord, - ext_atype, - nlist_t, - mapping_t, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - tracing_mode="symbolic", - _allow_non_fake_inputs=True, - ) + if with_comm_dict: + traced = model.forward_common_lower_exportable_with_comm( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam, + aparam, + *comm_inputs, + do_atomic_virial=do_atomic_virial, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + else: + traced = model.forward_common_lower_exportable( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) # 5. Extract output keys from the CPU-traced module. - sample_out = traced(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) + sample_out = traced(*sample_inputs) output_keys = list(sample_out.keys()) @@ -536,7 +757,10 @@ def _trace_and_export( # ExportedProgram to the target device afterwards via the official # move_to_device_pass (avoids FakeTensor device-propagation errors). 
dynamic_shapes = _build_dynamic_shapes( - *sample_inputs, has_spin=is_spin, model_nnei=sum(model.get_sel()) + *sample_inputs, + has_spin=is_spin, + with_comm_dict=with_comm_dict, + model_nnei=sum(model.get_sel()), ) exported = torch.export.export( traced, @@ -583,7 +807,7 @@ def _deserialize_to_file_pte( ) -> None: """Deserialize a dictionary to a .pte model file.""" exported, metadata, data_for_json, output_keys = _trace_and_export( - data, model_json_override, do_atomic_virial + data, model_json_override, do_atomic_virial=do_atomic_virial ) model_def_script = data.get("model_def_script") or {} @@ -608,27 +832,51 @@ def _deserialize_to_file_pt2( Uses torch._inductor.aoti_compile_and_package to compile the exported program into a .pt2 package (ZIP archive with compiled shared libraries), then embeds metadata into the archive. + + For models whose descriptor reports + ``has_message_passing_across_ranks() == True`` (DPA2, DPA3 with + ``use_loc_mapping=False``, or hybrids wrapping such children), + compiles a SECOND ``with-comm`` artifact and packs it alongside the + regular one. The ``with-comm`` variant accepts comm-dict tensors as + additional positional inputs and drives MPI ghost-atom exchange via + ``deepmd_export::border_op``. The C++ ``DeepPotPTExpt`` loader picks + the artifact based on the LAMMPS rank count at runtime. + + Layout inside the .pt2 ZIP (PyTorch 2.11 strict layout): + regular → artifact at ``model/`` (AOTInductor's own layout) + with-comm → ``model/extra/forward_lower_with_comm.pt2`` (nested ZIP) + metadata → ``model/extra/metadata.json`` with + ``has_comm_artifact`` flag. The C++ reader matches + by ``/``-delimited suffix so the legacy root-level + ``extra/`` layout still loads. + + Old .pt2 files (pre-this-change) lack ``has_comm_artifact`` so the + C++ loader must default to ``False`` when the field is missing. 
""" + import os + import tempfile import zipfile from torch._inductor import ( aoti_compile_and_package, ) + # First artifact: regular (no comm). Always produced. exported, metadata, data_for_json, output_keys = _trace_and_export( - data, model_json_override, do_atomic_virial + data, model_json_override, do_atomic_virial=do_atomic_virial ) + metadata["output_keys"] = output_keys # On CUDA, aggressive kernel fusion (default realize_opcount_threshold=30) # causes NaN in the backward pass (force/virial) of attention-based # descriptors (DPA1, DPA2). Setting threshold=0 prevents fusion and # avoids the NaN. Only applied on CUDA; CPU compilation is unaffected. # - # NOTE: `torch._inductor.config` is a process-wide singleton. The + # NOTE: ``torch._inductor.config`` is a process-wide singleton. The # save/restore pattern here is NOT thread-safe — concurrent AOTInductor - # compilations from multiple threads would race on this global. Callers - # must serialise `.pt2` exports if running under a thread pool. Processes - # are fine (each has its own inductor config). + # compilations from multiple threads would race on this global. Callers + # must serialise ``.pt2`` exports if running under a thread pool. + # Processes are fine (each has its own inductor config). import torch._inductor.config as _inductor_config import deepmd.pt_expt.utils.env as _env @@ -642,13 +890,50 @@ def _deserialize_to_file_pt2( finally: _inductor_config.realize_opcount_threshold = saved_threshold - # Embed metadata into the .pt2 ZIP archive. Entries are placed under - # ``model/extra/`` so the strict PyTorch 2.11 ``load_pt2`` loader - # accepts the archive without emitting the "outdated pt2 file" - # fallback warning. See the module-level comment on - # ``PT2_EXTRA_PREFIX`` for the rationale. + # Second artifact: with-comm. Only for descriptors whose message + # passing extends across rank boundaries. 
The flag was computed + # from the model in ``_collect_metadata`` and is already in + # ``metadata`` here. + has_comm_artifact = bool(metadata.get("has_comm_artifact")) + with_comm_bytes: bytes | None = None + with_comm_output_keys: list[str] | None = None + if has_comm_artifact: + exported_wc, _meta_wc, _data_wc, with_comm_output_keys = _trace_and_export( + data, + model_json_override, + with_comm_dict=True, + do_atomic_virial=do_atomic_virial, + ) + with tempfile.TemporaryDirectory() as td: + wc_path = os.path.join(td, "forward_lower_with_comm.pt2") + saved_threshold = _inductor_config.realize_opcount_threshold + if is_cuda: + _inductor_config.realize_opcount_threshold = 0 + try: + aoti_compile_and_package(exported_wc, package_path=wc_path) + finally: + _inductor_config.realize_opcount_threshold = saved_threshold + with open(wc_path, "rb") as f: + with_comm_bytes = f.read() + # The output keys are identical between the two artifacts (same + # forward_lower output dict); record only one set in metadata. + # If they ever diverge we'll surface a hard error here. + if with_comm_output_keys != output_keys: + raise RuntimeError( + "with-comm artifact output keys diverge from regular: " + f"regular={output_keys} vs with_comm={with_comm_output_keys}" + ) + + # Embed metadata + supplementary files into the .pt2 ZIP archive. + # Entries are placed under ``model/extra/`` so the strict PyTorch + # 2.11 ``load_pt2`` loader accepts the archive without emitting the + # "outdated pt2 file" fallback warning. See the module-level + # comment on ``PT2_EXTRA_PREFIX`` for the rationale. The C++ reader + # (``commonPTExpt.h::read_zip_entry``) accepts both the legacy + # root-level ``extra/`` layout and the new ``model/extra/`` layout + # via suffix matching, so the with-comm artifact moves with the + # rest. 
model_def_script = data.get("model_def_script") or {} - metadata["output_keys"] = output_keys with zipfile.ZipFile(model_file, "a") as zf: zf.writestr(PT2_EXTRA_PREFIX + "metadata.json", json.dumps(metadata)) zf.writestr( @@ -659,3 +944,7 @@ def _deserialize_to_file_pt2( PT2_EXTRA_PREFIX + "model.json", json.dumps(data_for_json, separators=(",", ":")), ) + if with_comm_bytes is not None: + zf.writestr( + PT2_EXTRA_PREFIX + "forward_lower_with_comm.pt2", with_comm_bytes + ) diff --git a/source/api_cc/include/DeepPotPTExpt.h b/source/api_cc/include/DeepPotPTExpt.h index 1bcf44a885..3559702f6a 100644 --- a/source/api_cc/include/DeepPotPTExpt.h +++ b/source/api_cc/include/DeepPotPTExpt.h @@ -16,6 +16,12 @@ #include "DeepPot.h" +// Forward-declare to keep TempFile out of public header. Defined in +// commonPTExpt.h. +namespace deepmd::ptexpt { +class TempFile; +} + namespace torch::inductor { class AOTIModelPackageLoader; } @@ -214,6 +220,14 @@ class DeepPotPTExpt : public DeepPotBackend { at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) at::Tensor firstneigh_tensor; // cached nlist tensor (LAMMPS path) std::unique_ptr loader; + // Optional second AOTInductor artifact for the multi-rank GNN code + // path (Phase 4). Loaded only if the .pt2 metadata reports + // ``has_comm_artifact == true`` AND the model has GNN message + // passing. ``with_comm_tempfile_`` owns the extracted nested .pt2 + // for the lifetime of ``with_comm_loader``. + bool has_comm_artifact_ = false; + std::unique_ptr with_comm_tempfile_; + std::unique_ptr with_comm_loader; /** * @brief Multi-frame loop for standalone compute (no nlist). @@ -266,6 +280,24 @@ class DeepPotPTExpt : public DeepPotBackend { const torch::Tensor& fparam, const torch::Tensor& aparam); + /** + * @brief Run the with-comm .pt2 artifact with comm tensors appended. + * + * @param[in] coord,atype,nlist,mapping,fparam,aparam Base inputs, same as + * ``run_model``; fparam/aparam are appended only when dfparam/daparam > 0. 
+ * @param[in] comm_tensors 8 comm tensors in canonical positional + * order: send_list, send_proc, recv_proc, send_num, + * recv_num, communicator, nlocal, nghost. + */ + std::vector run_model_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& nlist, + const torch::Tensor& mapping, + const torch::Tensor& fparam, + const torch::Tensor& aparam, + const std::vector& comm_tensors); + /** * @brief Extract outputs from flat tensor list using output_keys. */ diff --git a/source/api_cc/include/DeepSpinPTExpt.h b/source/api_cc/include/DeepSpinPTExpt.h index 47b38767d4..cc1304c69e 100644 --- a/source/api_cc/include/DeepSpinPTExpt.h +++ b/source/api_cc/include/DeepSpinPTExpt.h @@ -14,6 +14,11 @@ #include "DeepSpin.h" +// Forward-declare the temp-file helper from commonPTExpt.h. +namespace deepmd::ptexpt { +class TempFile; +} + namespace torch::inductor { class AOTIModelPackageLoader; } @@ -189,6 +194,10 @@ class DeepSpinPTExpt : public DeepSpinBackend { at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) at::Tensor firstneigh_tensor; // cached nlist tensor (LAMMPS path) std::unique_ptr loader; + // Optional with-comm artifact for multi-rank GNN spin inference. + bool has_comm_artifact_ = false; + std::unique_ptr with_comm_tempfile_; + std::unique_ptr with_comm_loader; std::vector run_model(const torch::Tensor& coord, const torch::Tensor& atype, @@ -198,6 +207,20 @@ class DeepSpinPTExpt : public DeepSpinBackend { const torch::Tensor& fparam, const torch::Tensor& aparam); + /** + * @brief Run with-comm spin artifact: 5-7 base inputs (incl. + * extended_spin) + 8 comm tensors. 
+ */ + std::vector run_model_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& spin, + const torch::Tensor& nlist, + const torch::Tensor& mapping, + const torch::Tensor& fparam, + const torch::Tensor& aparam, + const std::vector& comm_tensors); + void extract_outputs(std::map& output_map, const std::vector& flat_outputs); diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index db099ed464..910c2f6f7a 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "SimulationRegion.h" @@ -62,6 +63,13 @@ void DeepPotPTExpt::init(const std::string& model, return; } + // Load libdeepmd_op_pt.so so its TORCH_LIBRARY_FRAGMENT entries + // (deepmd::*, deepmd_export::*) are visible to torch's dispatcher + // before the AOTI module loads. Without this, multi-rank GNN .pt2 + // archives fail at pair_style time with + // ``Could not find schema for deepmd_export::border_op``. + deepmd::load_op_library(); + if (!file_content.empty()) { throw deepmd::deepmd_exception( "In-memory file_content loading is not supported for .pt2 models. " @@ -154,6 +162,40 @@ void DeepPotPTExpt::init(const std::string& model, gpu_enabled ? static_cast(gpu_id) : static_cast(-1)); + // Phase 4: load the optional with-comm artifact for multi-rank GNN + // inference. Pre-Phase-3 .pt2 files lack ``has_comm_artifact``; + // default to false so old artifacts keep working. If the metadata + // flag is set but the nested artifact fails to extract or compile, + // keep ``has_comm_artifact_=true`` and let single-rank dispatch + // continue working; multi-rank dispatch then fails fast at + // ``run_model_with_comm()`` rather than silently dropping the MPI + // exchange and producing wrong results. 
+ has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") && + metadata["has_comm_artifact"].as_bool(); + if (has_comm_artifact_) { + try { + // Extract the nested ``extra/forward_lower_with_comm.pt2`` into a + // temp file and load it as a second AOTI module. The TempFile + // unlinks the temp file on destruction. + with_comm_tempfile_ = std::make_unique( + deepmd::ptexpt::TempFile::from_zip_entry( + model, "extra/forward_lower_with_comm.pt2")); + with_comm_loader = + std::make_unique( + with_comm_tempfile_->path(), "model", false, 1, + gpu_enabled ? static_cast(gpu_id) + : static_cast(-1)); + } catch (const std::exception& e) { + std::cerr << "DeepPotPTExpt: failed to load with-comm artifact (" + << e.what() + << "); single-rank inference will still work, but multi-rank " + "LAMMPS dispatch will throw." + << std::endl; + with_comm_tempfile_.reset(); + with_comm_loader.reset(); + } + } + int num_intra_nthreads, num_inter_nthreads; get_env_nthreads(num_intra_nthreads, num_inter_nthreads); if (num_inter_nthreads) { @@ -194,6 +236,43 @@ std::vector DeepPotPTExpt::run_model( return loader->run(inputs); } +std::vector DeepPotPTExpt::run_model_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& nlist, + const torch::Tensor& mapping, + const torch::Tensor& fparam, + const torch::Tensor& aparam, + const std::vector& comm_tensors) { + if (!with_comm_loader) { + throw deepmd::deepmd_exception( + "run_model_with_comm called but the with-comm artifact is not " + "available. Either the .pt2 file has no with-comm artifact compiled " + "(programming error: the caller should check has_comm_artifact_ " + "before invoking this path), or the artifact was present in the " + ".pt2 metadata but failed to load at init time (see earlier stderr " + "log). 
Multi-rank LAMMPS requires a working with-comm artifact."); + } + if (comm_tensors.size() != 8) { + throw deepmd::deepmd_exception( + "run_model_with_comm: comm_tensors must contain exactly 8 tensors " + "(send_list, send_proc, recv_proc, send_num, recv_num, " + "communicator, nlocal, nghost). Got " + + std::to_string(comm_tensors.size()) + "."); + } + std::vector inputs = {coord, atype, nlist, mapping}; + if (dfparam > 0) { + inputs.push_back(fparam); + } + if (daparam > 0) { + inputs.push_back(aparam); + } + for (const auto& t : comm_tensors) { + inputs.push_back(t); + } + return with_comm_loader->run(inputs); +} + void DeepPotPTExpt::extract_outputs( std::map& output_map, const std::vector& flat_outputs) { @@ -349,9 +428,45 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // Run the .pt2 model - auto flat_outputs = run_model(coord_Tensor, atype_Tensor, firstneigh_tensor, - mapping_tensor, fparam_tensor, aparam_tensor); + // Phase 4 dispatch: use the with-comm artifact when LAMMPS is + // running multi-rank. ``lmp_list.nswap > 0`` is the proxy for + // "multi-rank with cross-domain communication"; in single-rank + // mode LAMMPS sets nswap=0. Falling back to the regular artifact + // for nswap=0 is correct because that artifact uses the mapping + // tensor to gather ghost embeddings from local atoms. + std::vector flat_outputs; + bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0; + if (use_with_comm && !with_comm_loader) { + throw deepmd::deepmd_exception( + "Multi-rank LAMMPS requires the with-comm artifact, but it failed " + "to load at init time. See the earlier stderr log for the underlying " + "error."); + } + // When NULL-type atoms exist, remapped storage must outlive comm + // tensors (the int** pointer-array tensor references it). 
+ std::vector> remapped_sendlist; + std::vector remapped_sendlist_ptrs; + std::vector remapped_sendnum, remapped_recvnum; + if (use_with_comm) { + bool has_null_atoms = (nall_real < nall); + std::vector comm_tensors; + if (has_null_atoms) { + comm_tensors = + deepmd::ptexpt::build_comm_tensors_positional_with_virtual_atoms( + lmp_list, fwd_map, nloc, nghost_real, remapped_sendlist, + remapped_sendlist_ptrs, remapped_sendnum, remapped_recvnum); + } else { + comm_tensors = deepmd::ptexpt::build_comm_tensors_positional( + lmp_list, lmp_list.sendlist, lmp_list.sendnum, lmp_list.recvnum, nloc, + nghost_real); + } + flat_outputs = run_model_with_comm( + coord_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, + fparam_tensor, aparam_tensor, comm_tensors); + } else { + flat_outputs = run_model(coord_Tensor, atype_Tensor, firstneigh_tensor, + mapping_tensor, fparam_tensor, aparam_tensor); + } // Map flat outputs to internal keys std::map output_map; diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index dcd7df55a4..2ac4369f5f 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "SimulationRegion.h" @@ -62,6 +63,11 @@ void DeepSpinPTExpt::init(const std::string& model, return; } + // Load libdeepmd_op_pt.so so deepmd_export::* schemas are visible + // to torch's dispatcher before the AOTI module loads. See + // DeepPotPTExpt::init for the full rationale. + deepmd::load_op_library(); + if (!file_content.empty()) { throw deepmd::deepmd_exception( "In-memory file_content loading is not supported for .pt2 models. " @@ -166,6 +172,34 @@ void DeepSpinPTExpt::init(const std::string& model, gpu_enabled ? static_cast(gpu_id) : static_cast(-1)); + // Phase 4: load the optional with-comm artifact for multi-rank GNN + // spin inference. 
Mirrors DeepPotPTExpt; see its init() comment for + // the rationale on keeping ``has_comm_artifact_=true`` on load + // failure so multi-rank dispatch fails fast rather than silently + // dropping the MPI exchange. + has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") && + metadata["has_comm_artifact"].as_bool(); + if (has_comm_artifact_) { + try { + with_comm_tempfile_ = std::make_unique( + deepmd::ptexpt::TempFile::from_zip_entry( + model, "extra/forward_lower_with_comm.pt2")); + with_comm_loader = + std::make_unique( + with_comm_tempfile_->path(), "model", false, 1, + gpu_enabled ? static_cast(gpu_id) + : static_cast(-1)); + } catch (const std::exception& e) { + std::cerr << "DeepSpinPTExpt: failed to load with-comm artifact (" + << e.what() + << "); single-rank inference will still work, but multi-rank " + "LAMMPS dispatch will throw." + << std::endl; + with_comm_tempfile_.reset(); + with_comm_loader.reset(); + } + } + int num_intra_nthreads, num_inter_nthreads; get_env_nthreads(num_intra_nthreads, num_inter_nthreads); if (num_inter_nthreads) { @@ -207,6 +241,42 @@ std::vector DeepSpinPTExpt::run_model( return loader->run(inputs); } +std::vector DeepSpinPTExpt::run_model_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& spin, + const torch::Tensor& nlist, + const torch::Tensor& mapping, + const torch::Tensor& fparam, + const torch::Tensor& aparam, + const std::vector& comm_tensors) { + if (!with_comm_loader) { + throw deepmd::deepmd_exception( + "DeepSpinPTExpt::run_model_with_comm called but the with-comm " + "artifact is not available. Either the .pt2 file has no with-comm " + "artifact compiled, or the artifact was present in the .pt2 metadata " + "but failed to load at init time (see earlier stderr log). 
Multi-rank " + "LAMMPS requires a working with-comm artifact."); + } + if (comm_tensors.size() != 8) { + throw deepmd::deepmd_exception( + "DeepSpinPTExpt::run_model_with_comm: comm_tensors must contain " + "exactly 8 tensors. Got " + + std::to_string(comm_tensors.size()) + "."); + } + std::vector inputs = {coord, atype, spin, nlist, mapping}; + if (dfparam > 0) { + inputs.push_back(fparam); + } + if (daparam > 0) { + inputs.push_back(aparam); + } + for (const auto& t : comm_tensors) { + inputs.push_back(t); + } + return with_comm_loader->run(inputs); +} + void DeepSpinPTExpt::extract_outputs( std::map& output_map, const std::vector& flat_outputs) { @@ -376,10 +446,44 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // Run the .pt2 model (7 args for spin) - auto flat_outputs = - run_model(coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, - mapping_tensor, fparam_tensor, aparam_tensor); + // Phase 4 dispatch: route to with-comm artifact in multi-rank mode. + // ``has_spin=tensor([1])`` is baked into the with-comm graph at + // trace time (Phase 3, spin_model.forward_common_lower_exportable + // _with_comm), so C++ supplies the same 8 comm tensors as the + // non-spin path. ``nlocal``/``nghost`` carry the real-atom counts + // (pre atom-doubling); the spin override halves them internally. + std::vector flat_outputs; + bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0; + if (use_with_comm && !with_comm_loader) { + throw deepmd::deepmd_exception( + "Multi-rank LAMMPS requires the with-comm artifact, but it failed " + "to load at init time. 
See the earlier stderr log for the underlying " + "error."); + } + std::vector> remapped_sendlist; + std::vector remapped_sendlist_ptrs; + std::vector remapped_sendnum, remapped_recvnum; + if (use_with_comm) { + bool has_null_atoms = (nall_real < nall); + std::vector comm_tensors; + if (has_null_atoms) { + comm_tensors = + deepmd::ptexpt::build_comm_tensors_positional_with_virtual_atoms( + lmp_list, fwd_map, nloc, nghost_real, remapped_sendlist, + remapped_sendlist_ptrs, remapped_sendnum, remapped_recvnum); + } else { + comm_tensors = deepmd::ptexpt::build_comm_tensors_positional( + lmp_list, lmp_list.sendlist, lmp_list.sendnum, lmp_list.recvnum, nloc, + nghost_real); + } + flat_outputs = run_model_with_comm( + coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, + mapping_tensor, fparam_tensor, aparam_tensor, comm_tensors); + } else { + flat_outputs = + run_model(coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, + mapping_tensor, fparam_tensor, aparam_tensor); + } std::map output_map; extract_outputs(output_map, flat_outputs); diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 7aae726242..0f59bb0e04 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -276,18 +276,18 @@ void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist, int inum = natoms >= 0 ? natoms : inlist.inum; ilist.resize(inum); jlist.resize(inum); + // Guard against an empty subdomain (inum == 0) and atoms with zero + // neighbours (jnum == 0): `&vec[0]` is undefined behaviour for an + // empty vector and libstdc++ debug-mode asserts on it. Use data() + // and skip the copy when the count is zero. 
if (inum > 0) { - memcpy(&ilist[0], inlist.ilist, inum * sizeof(int)); + memcpy(ilist.data(), inlist.ilist, inum * sizeof(int)); } for (int ii = 0; ii < inum; ++ii) { int jnum = inlist.numneigh[ii]; jlist[ii].resize(jnum); - // Guard against empty jlist[ii]: `&vec[0]` is undefined behaviour for - // empty vectors and libstdc++ debug mode asserts on it. This happens - // when a subdomain's local atoms legitimately have zero neighbours - // within cutoff (e.g. under spatial partitioning). if (jnum > 0) { - memcpy(&jlist[ii][0], inlist.firstneigh[ii], jnum * sizeof(int)); + memcpy(jlist[ii].data(), inlist.firstneigh[ii], jnum * sizeof(int)); for (int jj = 0; jj < jnum; ++jj) { jlist[ii][jj] &= inlist.mask; } diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h index ddc8ad5014..2d5d773b02 100644 --- a/source/api_cc/src/commonPTExpt.h +++ b/source/api_cc/src/commonPTExpt.h @@ -1,17 +1,26 @@ // SPDX-License-Identifier: LGPL-3.0-or-later // Shared utilities for pt_expt (.pt2 / AOTInductor) backend classes. -// Provides: JSON parser, ZIP archive reader, and type-sorted nlist builder. +// Provides: JSON parser, ZIP archive reader, type-sorted nlist builder, +// and helpers for the with-comm dual-artifact layout (Phase 4 of the +// GNN MPI plumbing). #pragma once +#include +#include + #include #include +#include +#include #include #include #include #include #include +#include "common.h" // for remap_comm_sendlist #include "errors.h" +#include "neighbor_list.h" namespace deepmd { namespace ptexpt { @@ -438,11 +447,183 @@ inline std::string read_zip_entry(const std::string& zip_path, } // ============================================================================ -// Create raw neighbor list tensor. -// The .pt2 compiled graph already contains format_nlist (distance sort + -// truncation) and an internal +1 pad that guarantees the sort branch fires. -// The C++ side just flattens the jagged nlist into a rectangular tensor. 
+// With-comm artifact extraction (Phase 4) +// +// GNN .pt2 archives carry a nested ``extra/forward_lower_with_comm.pt2`` +// alongside the regular forward_lower artifact. AOTInductor's +// ``ModelPackageLoader`` reads .pt2 files from disk, so to load the +// nested artifact we extract it to a temp file. +// ============================================================================ + +/** + * @brief RAII handle for a temp file on disk. + * + * Used to hold the extracted with-comm .pt2 artifact for the lifetime + * of the loader. Destructor unlinks the file. + */ +class TempFile { + public: + TempFile() = default; + TempFile(const TempFile&) = delete; + TempFile& operator=(const TempFile&) = delete; + TempFile(TempFile&& other) noexcept : path_(std::move(other.path_)) { + other.path_.clear(); + } + TempFile& operator=(TempFile&& other) noexcept { + if (this != &other) { + cleanup(); + path_ = std::move(other.path_); + other.path_.clear(); + } + return *this; + } + ~TempFile() { cleanup(); } + + const std::string& path() const { return path_; } + bool empty() const { return path_.empty(); } + + /** + * @brief Write the content of an existing .pt2 ZIP entry to a temp + * file and return a TempFile owning that path. + * + * The temp file is created via ``mkstemp(3)`` (atomic, unique, + * 0600 permissions) under the system tempdir (TMPDIR or /tmp). + */ + static TempFile from_zip_entry(const std::string& outer_pt2_path, + const std::string& entry_name) { + std::string content = read_zip_entry(outer_pt2_path, entry_name); + const char* tmpdir = std::getenv("TMPDIR"); + std::string tmpl = + std::string(tmpdir ? 
tmpdir : "/tmp") + "/dp_pt2_with_comm_XXXXXX"; + std::vector buf(tmpl.begin(), tmpl.end()); + buf.push_back('\0'); + int fd = mkstemp(buf.data()); + if (fd < 0) { + throw deepmd::deepmd_exception( + "Failed to create temp file for nested .pt2 artifact: " + tmpl); + } + std::string path(buf.data()); + // Write content to the fd so we don't race with another process + // opening the same path. + ssize_t written = 0; + const char* p = content.data(); + ssize_t remain = static_cast(content.size()); + while (remain > 0) { + ssize_t n = ::write(fd, p + written, static_cast(remain)); + if (n < 0) { + ::close(fd); + ::unlink(path.c_str()); + throw deepmd::deepmd_exception( + "Failed to write nested .pt2 artifact to temp file: " + path); + } + written += n; + remain -= n; + } + ::close(fd); + TempFile tf; + tf.path_ = std::move(path); + return tf; + } + + private: + void cleanup() { + if (!path_.empty()) { + ::unlink(path_.c_str()); + path_.clear(); + } + } + std::string path_; +}; + +// ============================================================================ +// comm_dict tensor packing for the with-comm artifact (Phase 4) +// +// The with-comm AOTInductor artifact accepts comm tensors as 8 additional +// positional inputs (after the regular 4-6 inputs) in this canonical order: +// send_list (nswap, int64 ptr-array packed as int64 tensor) +// send_proc (nswap, int32) +// recv_proc (nswap, int32) +// send_num (nswap, int32) +// recv_num (nswap, int32) +// communicator (1, int64 — MPI handle as opaque int) +// nlocal (scalar int32) +// nghost (scalar int32) +// This mirrors deepmd_export::border_op's argument order in +// deepmd/pt_expt/utils/comm.py. // ============================================================================ +/** + * @brief Build the 8 comm-tensor positional inputs from LAMMPS data + * (Phase 5 working signature, restored after the consolidation + * attempt regressed). 
+ */
+inline std::vector<at::Tensor> build_comm_tensors_positional(
+    const InputNlist& lmp_list,
+    int** sendlist,
+    int* sendnum,
+    int* recvnum,
+    int nlocal,
+    int nghost) {
+  // Returned order: send_list, send_proc, recv_proc, send_num, recv_num,
+  // communicator, nlocal, nghost.
+  //
+  // Every tensor except nlocal/nghost is created with torch::from_blob,
+  // i.e. it aliases the caller's memory instead of copying it. The
+  // buffers behind ``sendlist``, ``sendnum``, ``recvnum`` and the
+  // ``lmp_list`` members must therefore stay alive and unmoved for as
+  // long as the returned tensors are in use.
+  int nswap = lmp_list.nswap;
+  auto int32_option =
+      torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32);
+  auto int64_option =
+      torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64);
+
+  // The array of per-swap ``int*`` pointers is reinterpreted as nswap
+  // int64 values; that is only valid when a pointer is exactly 8 bytes.
+  static_assert(sizeof(int*) == sizeof(std::int64_t),
+                "sendlist pointers are packed as int64 tensor elements");
+  at::Tensor sendlist_tensor =
+      torch::from_blob(static_cast<void*>(sendlist), {nswap}, int64_option);
+  at::Tensor sendproc_tensor =
+      torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
+  at::Tensor recvproc_tensor =
+      torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
+  at::Tensor sendnum_tensor = torch::from_blob(sendnum, {nswap}, int32_option);
+  at::Tensor recvnum_tensor = torch::from_blob(recvnum, {nswap}, int32_option);
+
+  // Function-local static so the blob behind the "no communicator"
+  // tensor outlives this call.
+  static std::int64_t null_communicator = 0;
+  at::Tensor communicator_tensor;
+  if (lmp_list.world == nullptr) {
+    communicator_tensor =
+        torch::from_blob(&null_communicator, {1}, int64_option);
+  } else {
+    // NOTE(review): this views the 8 bytes stored at ``world`` as one
+    // int64 — confirm the pointed-to communicator handle is at least that
+    // wide on every supported MPI implementation.
+    communicator_tensor =
+        torch::from_blob(const_cast<void*>(lmp_list.world), {1}, int64_option);
+  }
+
+  // These two are owned copies (torch::tensor), not views.
+  at::Tensor nlocal_tensor = torch::tensor(nlocal, int32_option);
+  at::Tensor nghost_tensor = torch::tensor(nghost, int32_option);
+
+  return {sendlist_tensor, sendproc_tensor, recvproc_tensor, sendnum_tensor,
+          recvnum_tensor,  communicator_tensor, nlocal_tensor, nghost_tensor};
+}
+
+/**
+ * @brief Build the 8 comm-tensor positional inputs with NULL-type-atom
+ * remapping. When ``select_real_atoms_coord`` filters atoms (atype <
+ * 0), ``fwd_map`` translates original sendlist indices into real-atom
+ * indices (with ``-1`` for filtered). Mirrors
+ * ``commonPT.h::build_comm_dict_with_virtual_atoms``. The remapped
+ * storage must outlive the returned tensors.
+ */ +inline std::vector build_comm_tensors_positional_with_virtual_atoms( + const InputNlist& lmp_list, + const std::vector& fwd_map, + int nlocal, + int nghost, + std::vector>& remapped_sendlist, + std::vector& remapped_sendlist_ptrs, + std::vector& remapped_sendnum, + std::vector& remapped_recvnum) { + remap_comm_sendlist(remapped_sendlist, remapped_sendnum, remapped_recvnum, + lmp_list, fwd_map); + int nswap = lmp_list.nswap; + remapped_sendlist_ptrs.resize(nswap); + for (int s = 0; s < nswap; ++s) { + remapped_sendlist_ptrs[s] = remapped_sendlist[s].data(); + } + return build_comm_tensors_positional(lmp_list, remapped_sendlist_ptrs.data(), + remapped_sendnum.data(), + remapped_recvnum.data(), nlocal, nghost); +} + } // namespace ptexpt } // namespace deepmd diff --git a/source/api_cc/tests/CMakeLists.txt b/source/api_cc/tests/CMakeLists.txt index a812f776fc..a570747f29 100644 --- a/source/api_cc/tests/CMakeLists.txt +++ b/source/api_cc/tests/CMakeLists.txt @@ -11,6 +11,9 @@ if(ENABLE_TENSORFLOW) endif() if(ENABLE_PYTORCH) target_compile_definitions(runUnitTests_cc PRIVATE BUILD_PYTORCH) + # Link torch so __has_include() succeeds and + # BUILD_PT_EXPT is set for the test binary; otherwise pt_expt tests all + # GTEST_SKIP() with "PyTorch support is not enabled". target_link_libraries(runUnitTests_cc "${TORCH_LIBRARIES}") endif() if(ENABLE_JAX) diff --git a/source/api_cc/tests/test_with_comm_load_failure_ptexpt.cc b/source/api_cc/tests/test_with_comm_load_failure_ptexpt.cc new file mode 100644 index 0000000000..10111a41b7 --- /dev/null +++ b/source/api_cc/tests/test_with_comm_load_failure_ptexpt.cc @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// Tests for the dispatch-site fail-fast guard when the with-comm AOTI +// artifact failed to load at init time. 
The fixtures are produced by +// source/tests/infer/gen_corrupt_with_comm.py: copies of the valid +// multi-rank .pt2 archives whose nested +// ``model/extra/forward_lower_with_comm.pt2`` entry has been replaced +// with garbage bytes. The outer metadata still claims +// ``has_comm_artifact: true`` so the loader exercises the catch path. +// +// Expectations: +// * init() succeeds (the loader logs and falls back instead of aborting). +// * Single-rank dispatch (nswap == 0) keeps working through the regular +// forward_lower artifact. +// * Multi-rank dispatch (nswap > 0) throws a deepmd::deepmd_exception +// instead of silently dropping the MPI ghost-embedding exchange. +#include + +#include +#include + +#include "DeepPot.h" +// Include the PT_Expt headers so BUILD_PT_EXPT / BUILD_PT_EXPT_SPIN are +// visible to the GTEST_SKIP guard below. +#include "DeepPotPTExpt.h" +#include "DeepSpin.h" +#include "DeepSpinPTExpt.h" +#include "common.h" +#include "neighbor_list.h" +#include "test_utils.h" + +namespace { +constexpr const char* kPotCorrupt = + "../../tests/infer/deeppot_dpa3_mpi_corrupt_with_comm.pt2"; +constexpr const char* kSpinCorrupt = + "../../tests/infer/deeppot_dpa3_spin_mpi_corrupt_with_comm.pt2"; + +bool file_exists(const char* path) { + std::ifstream f(path); + return f.good(); +} +} // namespace + +// ============================================================================ +// DeepPot (non-spin) — corrupted with-comm artifact +// ============================================================================ + +class TestDeepPotPTExptWithCommLoadFailure : public ::testing::Test { + protected: + // Coordinates / atype / box copied from gen_dpa3.py so the regular + // forward_lower artifact has well-formed inputs to evaluate. 
+ std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + + deepmd::DeepPot dp; + + void SetUp() override { +#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT + GTEST_SKIP() << "Skip because PyTorch / pt_expt support is not enabled."; +#endif + if (!file_exists(kPotCorrupt)) { + GTEST_SKIP() << "Skipping: " << kPotCorrupt + << " not found. Run source/tests/infer/" + "gen_corrupt_with_comm.py first."; + } + // Init must succeed: the with-comm loader fails internally and the + // catch block keeps the regular single-rank artifact usable. + ASSERT_NO_THROW(dp.init(kPotCorrupt)); + } +}; + +TEST_F(TestDeepPotPTExptWithCommLoadFailure, single_rank_compute_succeeds) { + // nswap == 0 (default InputNlist) routes through the regular + // forward_lower artifact; the broken with-comm artifact is not + // consulted, so compute must succeed. 
+ float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector> nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, atype, + box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, ilist.data(), numneigh.data(), + firstneigh.data()); + convert_nlist(inlist, nlist_data); + inlist.mapping = mapping.data(); + ASSERT_EQ(inlist.nswap, 0); // pre-condition: single-rank dispatch + + double ener; + std::vector force_, virial; + EXPECT_NO_THROW(dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, + nall - nloc, inlist, 0)); + EXPECT_EQ(force_.size(), nall * 3); + EXPECT_EQ(virial.size(), 9); +} + +TEST_F(TestDeepPotPTExptWithCommLoadFailure, multi_rank_compute_throws) { + // nswap > 0 forces the dispatch site to ``run_model_with_comm``; the + // load-failure guard added by PR #5430 must throw rather than silently + // falling back to the single-rank path. The send/recv arrays remain + // null — the guard fires before any of them are dereferenced. 
+ float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector> nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, atype, + box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, ilist.data(), numneigh.data(), + firstneigh.data()); + convert_nlist(inlist, nlist_data); + inlist.mapping = mapping.data(); + inlist.nswap = 1; // simulate multi-rank without populating send/recv + + double ener; + std::vector force_, virial; + EXPECT_THROW(dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, + nall - nloc, inlist, 0), + deepmd::deepmd_exception); +} + +// ============================================================================ +// DeepSpin — corrupted with-comm artifact +// ============================================================================ + +class TestDeepSpinPTExptWithCommLoadFailure : public ::testing::Test { + protected: + std::vector coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; + // Match deeppot_dpa3_spin_mpi.pt2 spin layout (type 0 has spin, types + // 1+ do not) — spin vector packed alongside coord. + std::vector spin = {0.13, 0.02, 0.03, 0., 0., 0., 0., 0., 0., + 0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.}; + std::vector atype = {0, 1, 1, 0, 1, 1}; + std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; + + deepmd::DeepSpin dp; + + void SetUp() override { +#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT_SPIN + GTEST_SKIP() << "Skip because PyTorch / pt_expt spin support is not " + "enabled."; +#endif + if (!file_exists(kSpinCorrupt)) { + GTEST_SKIP() << "Skipping: " << kSpinCorrupt + << " not found. 
Run source/tests/infer/" + "gen_corrupt_with_comm.py first."; + } + ASSERT_NO_THROW(dp.init(kSpinCorrupt)); + } +}; + +TEST_F(TestDeepSpinPTExptWithCommLoadFailure, single_rank_compute_succeeds) { + // NoPBC + hardcoded all-pairs nlist mirrors the + // ``cpu_lmp_nlist`` pattern in test_deeppot_dpa_ptexpt_spin.cc: + // nloc == natoms == nall, no ghost atoms. + const int natoms = static_cast(atype.size()); + std::vector empty_box; + std::vector> nlist_data = {{1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, + {0, 1, 3, 4, 5}, {0, 1, 2, 4, 5}, + {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}}; + std::vector ilist(natoms), numneigh(natoms); + std::vector firstneigh(natoms); + deepmd::InputNlist inlist(natoms, ilist.data(), numneigh.data(), + firstneigh.data()); + convert_nlist(inlist, nlist_data); + ASSERT_EQ(inlist.nswap, 0); + + double ener; + std::vector force_, force_mag, virial; + EXPECT_NO_THROW(dp.compute(ener, force_, force_mag, virial, coord, spin, + atype, empty_box, 0, inlist, 0)); +} + +TEST_F(TestDeepSpinPTExptWithCommLoadFailure, multi_rank_compute_throws) { + const int natoms = static_cast(atype.size()); + std::vector empty_box; + std::vector> nlist_data = {{1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, + {0, 1, 3, 4, 5}, {0, 1, 2, 4, 5}, + {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}}; + std::vector ilist(natoms), numneigh(natoms); + std::vector firstneigh(natoms); + deepmd::InputNlist inlist(natoms, ilist.data(), numneigh.data(), + firstneigh.data()); + convert_nlist(inlist, nlist_data); + inlist.nswap = 1; // simulate multi-rank without populating send/recv + + double ener; + std::vector force_, force_mag, virial; + EXPECT_THROW(dp.compute(ener, force_, force_mag, virial, coord, spin, atype, + empty_box, 0, inlist, 0), + deepmd::deepmd_exception); +} diff --git a/source/lmp/tests/run_mpi_pair_deepmd_dpa3_pt2.py b/source/lmp/tests/run_mpi_pair_deepmd_dpa3_pt2.py new file mode 100644 index 0000000000..042f47c56c --- /dev/null +++ b/source/lmp/tests/run_mpi_pair_deepmd_dpa3_pt2.py @@ -0,0 +1,214 @@ +# 
SPDX-License-Identifier: LGPL-3.0-or-later +"""Multi-rank LAMMPS driver for DPA3 .pt2 (Phase 5 of GNN MPI). + +Run via ``mpirun -n N python run_mpi_pair_deepmd_dpa3_pt2.py DATAFILE PB_FILE OUTPUT``. +Mirrors ``run_mpi_pair_deepmd.py`` but targets a GNN model whose .pt2 archive +carries the with-comm artifact (Phase 3 dual-artifact layout). The C++ +``DeepPotPTExpt`` (Phase 4) routes to the with-comm artifact when LAMMPS +reports nswap > 0 (multi-rank), driving MPI ghost-atom exchange via +``deepmd_export::border_op`` per layer. + +Rank 0 writes potential energy + per-atom forces (3 cols) + per-atom +virial (9 cols, from ``compute centroid/stress/atom NULL pair`` in +LAMMPS internal units) to ``OUTPUT`` so the parent pytest process can +compare against the single-rank reference. +""" + +from __future__ import ( + annotations, +) + +import argparse + +import numpy as np +from lammps import ( + PyLammps, +) +from mpi4py import ( + MPI, +) + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() + +parser = argparse.ArgumentParser() +parser.add_argument("DATAFILE", type=str, help="LAMMPS data file (atom positions)") +parser.add_argument("PB_FILE", type=str, help=".pt2 model file") +parser.add_argument("OUTPUT", type=str, help="Output file for energies + forces") +parser.add_argument( + "--nsteps", + type=int, + default=0, + help="Number of MD steps to run after the initial force evaluation; " + "with --nsteps > 10 (LAMMPS neigh_modify every=10) the dispatch path " + "is exercised across at least one neighbor-list rebuild.", +) +parser.add_argument( + "--processors", + type=str, + default="2 1 1", + help="LAMMPS processors grid. Default '2 1 1' forces multi-rank " + "domain decomposition (nswap>0). Pass '1 1 1' for a single-rank " + "reference run on the same archive (single-artifact dispatch).", +) +parser.add_argument( + "--pair-coeff", + type=str, + default="* *", + help="pair_coeff arguments (after 'pair_coeff'). 
Default '* *' " + "uses identity LAMMPS-type-to-deepmd-atype mapping (assumes the " + "data file's types match the model's type_map order). For NULL-type " + "tests pass e.g. '* * O H NULL' so the third LAMMPS type becomes " + "deepmd atype=-1 (filtered before model evaluation).", +) +parser.add_argument( + "--mass3", + type=float, + default=None, + help="Optional mass for LAMMPS atom type 3 (and any higher types). " + "Used by the NULL-type fixture; ignored when only 2 types exist.", +) +parser.add_argument( + "--neigh-every", + type=int, + default=10, + help="LAMMPS ``neigh_modify every`` value. Default 10 mirrors the " + "production-realistic interval. Pass 1 for tests that want to " + "trigger nlist rebuilds on every step (and run a small ``--nsteps`` " + "to keep wall time low while still exercising the rebuild path).", +) +parser.add_argument( + "--null-vx", + type=float, + default=None, + help="Optional initial x-velocity (units: Angstrom/ps in metal " + "units) for LAMMPS atom type 3 atoms. Real atoms stay at v=0. " + "Used by the NULL-type rebuild test to make NULL atoms cross the " + "rank boundary in a few MD steps without destabilising real-atom " + "dynamics — the deepmd model never sees NULL atoms (filtered by " + "``select_real_atoms_coord``) so their unphysical velocity is " + "harmless.", +) +parser.add_argument( + "--null-vx-split", + action="store_true", + help="With ``--null-vx X``, split type-3 atoms into two groups by " + "their LAMMPS atom-id parity: even ids get +X, odd ids get -X. " + "Used by the NULL-type rebuild test to send different NULLs in " + "opposite directions, so the cross-rank sendlist composition " + "changes in BOTH directions per rebuild (rank 0 loses one NULL, " + "gains another simultaneously).", +) +parser.add_argument( + "--real-temp", + type=float, + default=None, + help="Optional initial thermal temperature (Kelvin) for non-NULL " + "atoms via ``velocity realgroup create T seed``. 
Each real atom " + "gets a random thermal velocity in a different direction — used " + "to exercise sendlist composition changes from real-atom motion " + "rather than only from NULL motion.", +) +args = parser.parse_args() + +lammps = PyLammps() +# Force the requested domain decomposition. The default "2 1 1" +# combined with the simulation box guarantees nswap > 0 on the C++ +# side, so DeepPotPTExpt routes to the with-comm AOTI artifact. Pass +# "1 1 1" to obtain a single-rank reference using the same archive +# (the regular artifact handles nswap==0). +lammps.processors(args.processors) +lammps.units("metal") +lammps.boundary("p p p") +lammps.atom_style("atomic") +# ``atom_modify map yes`` is required when single-rank dispatch goes +# through the regular artifact of a use_loc_mapping=False .pt2: the +# C++ side needs the LAMMPS global-id->local-index map to build the +# ``mapping`` tensor. It is harmless under multi-rank. +lammps.atom_modify("map yes") +lammps.neighbor("2.0 bin") +lammps.neigh_modify(f"every {args.neigh_every} delay 0 check no") +lammps.read_data(args.DATAFILE) +lammps.mass("1 16") +lammps.mass("2 2") +if args.mass3 is not None: + # Used by the NULL-type test where the data file has a 3rd LAMMPS + # type that maps to a NULL deepmd atype (filtered before model + # evaluation). The mass value is physically irrelevant — these + # atoms get zero force from the deepmd model. + lammps.mass(f"3 {args.mass3}") +lammps.timestep(0.0005) +lammps.fix("1 all nve") +# Initial velocities. Order matters: thermalize real atoms first +# (``velocity all create`` would also affect type 3, so we restrict +# it to a real-atom group), then set the NULL bias on type 3. +if args.real_temp is not None: + lammps.group("realgroup type 1 2") + lammps.velocity(f"realgroup create {args.real_temp:.6f} 12345 mom yes rot yes") +if args.null_vx is not None: + # Restrict initial velocity to LAMMPS type 3 atoms (NULL-type + # in the deepmd plugin's pair_coeff mapping). 
Real atoms stay + # at v=0 (or thermal); only the NULL atoms get the high vx, so + # the deepmd model's force outputs on real atoms remain bounded + # and the NVE integrator stays stable. + lammps.group("nullgroup type 3") + if args.null_vx_split: + # Send NULL atoms with even/odd LAMMPS atom-id in opposite + # directions. Hardcoded to the null_type fixture's NULL ids + # (7, 8); sufficient because the runner is only used by this + # branch's tests, not as a general utility. + lammps.group("null_id7 id 7") + lammps.group("null_id8 id 8") + lammps.velocity(f"null_id7 set {-args.null_vx:.6f} 0.0 0.0 units box") + lammps.velocity(f"null_id8 set {args.null_vx:.6f} 0.0 0.0 units box") + else: + lammps.velocity(f"nullgroup set {args.null_vx:.6f} 0.0 0.0 units box") + +lammps.pair_style(f"deepmd {args.PB_FILE}") +lammps.pair_coeff(args.pair_coeff) +# Per-atom virial from the pair contribution. ``centroid/stress/atom`` +# is parallel-safe (rank-local data, gathered below). LAMMPS computes +# stress*volume per atom in internal units; the parent test reverses +# the unit conversion (divide by ``constants.nktv2p``) before comparing +# against the reference virial. +lammps.compute("virial all centroid/stress/atom NULL pair") +lammps.run(0) + +# Optional: run additional MD steps to exercise the with-comm +# dispatch across neighbor-list rebuilds (LAMMPS rebuilds every +# ``--neigh-every`` steps, 10 by default, so any nsteps >= that +# interval triggers at least one rebuild). +if args.nsteps > 0: + lammps.run(args.nsteps) + +# Forces need to be gathered across ranks. PyLammps's ``atoms[i]`` +# only exposes rank-local atoms; ``gather_atoms`` returns the global, +# id-ordered array on every rank. We also gather ``id`` and reorder +# explicitly by id rather than trusting an implicit ordering — this +# is robust against subdomain layout, empty subdomains, and any +# future LAMMPS change in gather ordering. 
+forces_global = lammps.lmp.gather_atoms("f", 1, 3) +ids_global = lammps.lmp.gather_atoms("id", 0, 1) +# Gather the per-atom virial across ranks. ``lmp.gather`` accepts +# named per-atom computes (``c_<ID>``) and returns the global, +# id-ordered array on every rank. +virial_global = lammps.lmp.gather("c_virial", 1, 9) +# ``PyLammps.eval`` is rank-0-only. +if rank == 0: + pe_global = lammps.eval("pe") + natoms = lammps.atoms.natoms + forces = np.array(forces_global, dtype=np.float64).reshape(natoms, 3) + virials = np.array(virial_global, dtype=np.float64).reshape(natoms, 9) + ids = np.array(ids_global, dtype=np.int64).reshape(natoms) + # Sort by atom id so output is unambiguously id-ordered (id 1 first). + order = np.argsort(ids) + forces = forces[order] + virials = virials[order] + with open(args.OUTPUT, "w") as f: + f.write(f"{pe_global:.16e}\n") + # Each row: 3 force components followed by 9 virial components. + for fi, vi in zip(forces, virials, strict=True): + row = np.concatenate([fi, vi]) + f.write(" ".join(f"{v:.16e}" for v in row) + "\n") + +MPI.Finalize() diff --git a/source/lmp/tests/run_mpi_pair_deepmd_spin_dpa3_pt2.py b/source/lmp/tests/run_mpi_pair_deepmd_spin_dpa3_pt2.py new file mode 100644 index 0000000000..3637238968 --- /dev/null +++ b/source/lmp/tests/run_mpi_pair_deepmd_spin_dpa3_pt2.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Multi-rank LAMMPS driver for the DPA3 spin GNN .pt2 fixture. + +Mirrors ``run_mpi_pair_deepmd_dpa3_pt2.py`` but for spin models: +``atom_style spin`` / ``pair_style deepspin`` and gathers the +per-atom magnetic force ``fm`` in addition to the normal force and +per-atom virial. The DPA3 spin .pt2 with ``use_loc_mapping=False`` +carries a with-comm AOTI artifact (Phase 3 dual-artifact layout); the +C++ ``DeepSpinPTExpt`` (Phase 4c) routes to it when LAMMPS reports +nswap > 0 (multi-rank), driving MPI ghost-atom exchange via +``deepmd_export::border_op``. 
+ +Rank 0 writes potential energy + per-atom forces (3 cols) + +per-atom force_mag (3 cols) + per-atom virial (9 cols, from +``compute centroid/stress/atom NULL pair`` in LAMMPS internal units) +to ``OUTPUT`` so the parent pytest process can compare against the +single-rank reference. +""" + +from __future__ import ( + annotations, +) + +import argparse + +import numpy as np +from lammps import ( + PyLammps, +) +from mpi4py import ( + MPI, +) + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() + +parser = argparse.ArgumentParser() +parser.add_argument( + "DATAFILE", type=str, help="LAMMPS data file (atom positions + spin)" +) +parser.add_argument("PB_FILE", type=str, help=".pt2 model file (spin GNN)") +parser.add_argument( + "OUTPUT", type=str, help="Output file for energies + forces + force_mag + virial" +) +parser.add_argument( + "--nsteps", + type=int, + default=0, + help="Number of MD steps to run after the initial force evaluation. " + "Note: integrating spin requires fix nve/spin which is outside the " + "scope of this multi-rank correctness test; we only run static " + "force/energy evaluations and an optional run > 0 to exercise the " + "with-comm dispatch across neighbour-list rebuilds.", +) +parser.add_argument( + "--processors", + type=str, + default="2 1 1", + help="LAMMPS processors grid. Default '2 1 1' forces multi-rank " + "domain decomposition (nswap>0). Pass '1 1 1' for a single-rank " + "reference run on the same archive.", +) +parser.add_argument( + "--pair-coeff", + type=str, + default="* *", + help="pair_coeff arguments (after 'pair_coeff'). Default '* *' " + "uses identity LAMMPS-type-to-deepmd-atype mapping. For NULL-type " + "tests pass e.g. '* * Ni O NULL' so the third LAMMPS type becomes " + "deepmd atype=-1 (filtered before model evaluation).", +) +parser.add_argument( + "--mass3", + type=float, + default=None, + help="Optional mass for LAMMPS atom type 3 (and any higher types). 
" + "Used by the NULL-type fixture; ignored when only 2 types exist.", +) +args = parser.parse_args() + +lammps = PyLammps() +lammps.processors(args.processors) +lammps.units("metal") +lammps.boundary("p p p") +lammps.atom_style("spin") +lammps.atom_modify("map yes") +lammps.neighbor("2.0 bin") +lammps.neigh_modify("every 10 delay 0 check no") +lammps.read_data(args.DATAFILE) +lammps.mass("1 58") +lammps.mass("2 16") +if args.mass3 is not None: + # NULL-type fixture: third LAMMPS type maps to deepmd atype=-1 + # via pair_coeff and is filtered before model evaluation. Mass + # is physically irrelevant. + lammps.mass(f"3 {args.mass3}") +lammps.timestep(0.0005) +lammps.fix("1 all nve") + +lammps.pair_style(f"deepspin {args.PB_FILE}") +lammps.pair_coeff(args.pair_coeff) +lammps.compute("virial all centroid/stress/atom NULL pair") +# Per-atom magnetic force components. LAMMPS does not expose ``fm`` +# through the legacy ``extract``/``gather_atoms`` registry, so we go +# via ``compute property/atom fmx fmy fmz`` + ``gather`` to obtain a +# global, id-ordered (nlocal+nghost reduced) array on every rank. +lammps.compute("fmprop all property/atom fmx fmy fmz") +lammps.run(0) + +if args.nsteps > 0: + lammps.run(args.nsteps) + +# All per-atom data goes through the LAMMPS global gather API. +# ``c_fmprop`` is the compute defined above (fmx/fmy/fmz columns). 
+forces_global = lammps.lmp.gather_atoms("f", 1, 3) +ids_global = lammps.lmp.gather_atoms("id", 0, 1) +virial_global = lammps.lmp.gather("c_virial", 1, 9) +fm_global = lammps.lmp.gather("c_fmprop", 1, 3) + +if rank == 0: + pe_global = lammps.eval("pe") + natoms = lammps.atoms.natoms + forces = np.array(forces_global, dtype=np.float64).reshape(natoms, 3) + fm = np.array(fm_global, dtype=np.float64).reshape(natoms, 3) + virials = np.array(virial_global, dtype=np.float64).reshape(natoms, 9) + ids = np.array(ids_global, dtype=np.int64).reshape(natoms) + order = np.argsort(ids) + forces = forces[order] + fm = fm[order] + virials = virials[order] + with open(args.OUTPUT, "w") as f: + f.write(f"{pe_global:.16e}\n") + # Each row: 3 force + 3 force_mag + 9 virial = 15 columns. + for fi, fmi, vi in zip(forces, fm, virials, strict=True): + row = np.concatenate([fi, fmi, vi]) + f.write(" ".join(f"{v:.16e}" for v in row) + "\n") + +MPI.Finalize() diff --git a/source/lmp/tests/test_lammps_dpa2_pt2.py b/source/lmp/tests/test_lammps_dpa2_pt2.py new file mode 100644 index 0000000000..48ed966605 --- /dev/null +++ b/source/lmp/tests/test_lammps_dpa2_pt2.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Multi-rank LAMMPS test for DPA2 .pt2 (extends GNN MPI Phase 5 to DPA2). + +DPA2's repformer block participates in the per-layer ghost-atom MPI +exchange just like DPA3's repflows; the with-comm AOTInductor artifact +is produced automatically by ``deepmd/pt_expt/utils/serialization.py`` +because ``_has_message_passing`` returns True for any DPA2 model. + +Unlike DPA3 (which has ``use_loc_mapping``), DPA2's repformer always +takes a ``mapping`` tensor, so a single ``deeppot_dpa2.pt2`` already +carries the dual-artifact layout — no separate ``_mpi.pt2`` needed. + +This file targets the gap "DPA2 multi-rank dispatch never tested +end-to-end" recorded in +``memory/gnn_mpi_untested_paths.md::Dispatch wired, no test fixture``. 
+The reference is a same-archive single-rank run (``mpirun -n 1`` +through the same dual-artifact ``.pt2``); no hardcoded reference +values are needed. +""" + +from __future__ import ( + annotations, +) + +import importlib.util +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import numpy as np +import pytest +from write_lmp_data import ( + write_lmp_data, +) + +# Reuses the same generic mpirun driver as the DPA3 multi-rank tests — +# the script is descriptor-agnostic (just LAMMPS + pair_style deepmd). +RUNNER_PATH = Path(__file__).parent / "run_mpi_pair_deepmd_dpa3_pt2.py" + +pb_file = Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa2.pt2" +data_file = Path(__file__).parent / "data_dpa2_pt2.lmp" + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) + + +def setup_module() -> None: + if os.environ.get("ENABLE_PYTORCH", "1") != "1": + pytest.skip( + "Skip test because PyTorch support is not enabled.", + ) + write_lmp_data(box, coord, type_OH, data_file) + + +def teardown_module() -> None: + if data_file.exists(): + os.remove(data_file) + + +def _run_mpi_subprocess(nprocs: int = 2) -> dict: + """Invoke the generic mpirun driver and parse the output. + + With ``nprocs == 2`` (default) the runner forces ``processors 2 1 1`` + so ``DeepPotPTExpt`` routes to the with-comm artifact. With + ``nprocs == 1`` the runner uses ``processors 1 1 1`` and the C++ + side falls back to the regular artifact — useful as a same-archive + reference for value comparison. 
+ """ + with tempfile.NamedTemporaryFile(mode="r", suffix=".out", delete=False) as f: + out_path = f.name + try: + argv = [ + "mpirun", + "-n", + str(nprocs), + sys.executable, + str(RUNNER_PATH), + str(data_file.resolve()), + str(pb_file.resolve()), + out_path, + ] + if nprocs == 1: + argv.extend(["--processors", "1 1 1"]) + sp.check_call(argv) + with open(out_path) as fh: + lines = fh.read().strip().splitlines() + pe = float(lines[0]) + rows = np.array( + [list(map(float, line.split())) for line in lines[1:]], + dtype=np.float64, + ) + forces = rows[:, :3] + virials = rows[:, 3:] + return {"pe": pe, "forces": forces, "virials": virials} + finally: + if os.path.exists(out_path): + os.remove(out_path) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa2() -> None: + """Multi-rank DPA2 .pt2 dispatch must match the same-archive + single-rank reference for energy, forces, and virial. + + Verifies that: + - ``DeepPotPTExpt::compute`` correctly routes to the with-comm + artifact for DPA2 (descriptor-agnostic dispatch). + - The pt_expt ``DescrptBlockRepformers._exchange_ghosts`` override + drives ``deepmd_export::border_op`` for repformer's per-layer + ghost exchange (the path equivalent to DPA3's repflows). + - Different ``model_nnei`` from DPA3 (DPA2 repformer has nsel=15 + vs DPA3's e_sel=30) — exercises the dynamic-nnei with-comm + trace at a different baked-in value. + + No hardcoded reference; compares against a same-archive single-rank + run (``mpirun -n 1`` + ``processors 1 1 1`` falls back to the + regular artifact). 
+ """ + out_mpi = _run_mpi_subprocess(nprocs=2) + out_ref = _run_mpi_subprocess(nprocs=1) + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-12, abs=1e-12) + np.testing.assert_allclose(out_mpi["forces"], out_ref["forces"], atol=1e-8, rtol=0) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-8, rtol=0 + ) diff --git a/source/lmp/tests/test_lammps_dpa3_pt2.py b/source/lmp/tests/test_lammps_dpa3_pt2.py index 806d3f0b46..ecabe25a28 100644 --- a/source/lmp/tests/test_lammps_dpa3_pt2.py +++ b/source/lmp/tests/test_lammps_dpa3_pt2.py @@ -5,7 +5,12 @@ Reference values from source/tests/infer/gen_dpa3.py / C++ test. """ +import importlib.util import os +import shutil +import subprocess as sp +import sys +import tempfile from pathlib import ( Path, ) @@ -24,12 +29,46 @@ ) pb_file = Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa3.pt2" +# Multi-rank-capable variant (use_loc_mapping=False; carries the +# nested forward_lower_with_comm.pt2 artifact). Produced alongside +# deeppot_dpa3.pt2 by source/tests/infer/gen_dpa3.py. +pb_file_mpi = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa3_mpi.pt2" +) ref_file = ( Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa3.expected" ) data_file = Path(__file__).parent / "data_dpa3_pt2.lmp" data_file_si = Path(__file__).parent / "data_dpa3_pt2.si" data_type_map_file = Path(__file__).parent / "data_type_map_dpa3_pt2.lmp" +# Elongated-box variant for the empty-subdomain MPI test: x is +# extended to 30 Å while atoms remain in x ∈ [0.25, 12.83]. Combined +# with ``processors 2 1 1`` this leaves rank 1 (x ≥ 15) with zero +# local atoms — a corner case the comm-dispatch path must handle +# without crashing or producing wrong forces. +data_file_empty_subdomain = Path(__file__).parent / "data_dpa3_pt2_empty_subdomain.lmp" +# NULL-type variant: 6 real atoms (types 1,2) + 2 type-3 atoms straddling +# the x=6.5 rank boundary. 
With ``pair_coeff * * O H NULL`` LAMMPS type 3 +# maps to deepmd atype=-1, so those atoms are filtered by +# ``select_real_atoms_coord`` and the comm tensors must be remapped via +# ``fwd_map`` before being handed to the with-comm artifact. Forces on +# the 6 real atoms must match the no-NULL baseline; NULL atoms get zero +# force from the deepmd model. +data_file_null_type = Path(__file__).parent / "data_dpa3_pt2_null_type.lmp" +# Isolated-NULL fixture: box=30 Å in x so rank 0 (x ∈ [0, 15]) has a +# subdomain interior that is NOT within rcut of any boundary. With +# rcut=6, boundary-adjacent regions are [0, 6] (PBC of right wall) +# and [9, 15] (left wall of rank 1) — atoms in x in (6, 9) are LOCAL +# but not in any sendlist. Place 1 NULL atom at x=7.5 (in this gap) +# so the remap branch is reached but the sendlists contain no NULL +# entries — exercises ``has_null_atoms=true`` with no-op remap. +data_file_null_isolated = Path(__file__).parent / "data_dpa3_pt2_null_isolated.lmp" +# All-NULL-rank fixture: box=30 Å in x. 6 real atoms in rank 0 +# (x < 13). 2 NULL atoms in rank 1 (x ∈ {20, 25}). Under +# ``processors 2 1 1`` rank 1 owns ONLY NULL atoms, so after +# ``select_real_atoms_coord`` rank 1 has nloc_real=0 (intersection +# of empty-subdomain and NULL-type paths). +data_file_all_null_rank = Path(__file__).parent / "data_dpa3_pt2_all_null_rank.lmp" # Reference values written by source/tests/infer/gen_dpa3.py (PBC case). # Guarded with try/except because gen_dpa3.py only runs when PyTorch is built; @@ -72,10 +111,71 @@ def setup_module() -> None: type_OH, data_file_si, ) + # Elongated x-axis; atoms unchanged. With ``processors 2 1 1`` the + # split is at x = 15 Å and rank 1 owns x ≥ 15, which is empty. 
+ box_empty_subdomain = np.array([0, 30, 0, 13, 0, 13, 0, 0, 0]) + write_lmp_data(box_empty_subdomain, coord, type_OH, data_file_empty_subdomain) + # NULL-type fixture: original 6 real atoms (types 1,2) plus 2 LAMMPS + # type-3 atoms placed within rcut (~6 Å) of real atoms on BOTH sides + # of the x=6.5 rank boundary. The NULL atoms appear in real atoms' + # neighbour lists and in the cross-rank sendlists, so the comm-tensor + # remap (``fwd_map``-based) is genuinely exercised — not trivial. + coord_null_type = np.concatenate( + [ + coord, + np.array( + [ + [5.5, 6.0, 6.0], # rank 0 side, near boundary + [7.5, 7.0, 7.0], # rank 1 side, near boundary + ] + ), + ] + ) + type_null = np.concatenate([type_OH, np.array([3, 3])]) + write_lmp_data(box, coord_null_type, type_null, data_file_null_type) + # Isolated-NULL fixture: same elongated box as empty-subdomain + # plus one NULL atom in rank 0's subdomain interior (x ∈ (6, 9)). + coord_null_isolated = np.concatenate([coord, np.array([[7.5, 6.5, 6.5]])]) + type_null_isolated = np.concatenate([type_OH, np.array([3])]) + write_lmp_data( + box_empty_subdomain, + coord_null_isolated, + type_null_isolated, + data_file_null_isolated, + ) + # All-NULL-rank fixture: box=30 in x. Real atoms in rank 0 + # (their original coords; all x < 13). NULL atoms placed in + # rank 1 (x ∈ {20, 25}). Rank 1 owns ONLY NULL atoms. 
+ coord_all_null_rank = np.concatenate( + [ + coord, + np.array( + [ + [20.0, 6.5, 6.5], + [25.0, 6.5, 6.5], + ] + ), + ] + ) + type_all_null_rank = np.concatenate([type_OH, np.array([3, 3])]) + write_lmp_data( + box_empty_subdomain, + coord_all_null_rank, + type_all_null_rank, + data_file_all_null_rank, + ) def teardown_module() -> None: - for f in [data_file, data_type_map_file, data_file_si]: + for f in [ + data_file, + data_type_map_file, + data_file_si, + data_file_empty_subdomain, + data_file_null_type, + data_file_null_isolated, + data_file_all_null_rank, + ]: if f.exists(): os.remove(f) @@ -240,3 +340,489 @@ def test_pair_deepmd_si(lammps_si) -> None: expected_f[lammps_si.atoms[ii].id - 1] * constants.force_metal2si ) lammps_si.run(1) + + +# --------------------------------------------------------------------------- +# Multi-rank test (Phase 5 of GNN MPI) +# +# Drives the .pt2 model under ``mpirun -n 2`` so the C++ ``DeepPotPTExpt`` +# routes to the with-comm AOTI artifact (Phase 4) and ``border_op`` does +# real MPI ghost exchange between two ranks. The expected energy/forces +# are the same as the single-rank reference (single-rank LAMMPS would +# need ``atom_modify map yes`` to use the regular artifact; multi-rank +# uses the with-comm artifact whose graph reproduces the gather via +# MPI exchange). +# --------------------------------------------------------------------------- + + +def _run_mpi_subprocess( + extra_args: list[str] | None = None, + nprocs: int = 2, + data_path: Path | None = None, + processors: str | None = None, + runner_args: list[str] | None = None, +) -> dict: + """Helper: invoke run_mpi_pair_deepmd_dpa3_pt2.py under + ``mpirun -n <nprocs>`` and return + ``{"pe": float, "forces": (n, 3) array, "virials": (n, 9) array}``.
+ + With ``nprocs == 1`` the runner is invoked with ``--processors 1 1 1`` + so the C++ side sees ``nswap == 0`` and routes to the regular + (single-rank) artifact of the dual-artifact .pt2 — useful as a + same-archive reference for multi-rank comparisons. + + ``data_path`` (default ``data_file``) selects the LAMMPS data file — + the empty-subdomain test points at a non-default elongated-box file. + + ``processors`` overrides the runner's default decomposition string + (``"2 1 1"``); used by the ``test_*_decomposition`` variants to + exercise 2D / 3D processor grids (Px*Py*Pz must equal nprocs). + """ + if data_path is None: + data_path = data_file + with tempfile.NamedTemporaryFile(mode="r", suffix=".out", delete=False) as f: + out_path = f.name + try: + argv = [ + "mpirun", + "-n", + str(nprocs), + sys.executable, + str(Path(__file__).parent / "run_mpi_pair_deepmd_dpa3_pt2.py"), + str(data_path.resolve()), + str(pb_file_mpi.resolve()), + out_path, + ] + if processors is not None: + argv.extend(["--processors", processors]) + elif nprocs == 1: + argv.extend(["--processors", "1 1 1"]) + if extra_args: + argv.extend(extra_args) + if runner_args: + argv.extend(runner_args) + sp.check_call(argv) + with open(out_path) as fh: + lines = fh.read().strip().splitlines() + pe = float(lines[0]) + rows = np.array( + [list(map(float, line.split())) for line in lines[1:]], + dtype=np.float64, + ) + # Each row is (3 force) + (9 virial); see runner script. 
+ forces = rows[:, :3] + virials = rows[:, 3:] + return {"pe": pe, "forces": forces, "virials": virials} + finally: + if os.path.exists(out_path): + os.remove(out_path) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3() -> None: + """Multi-rank LAMMPS run for DPA3 .pt2 must match the single-rank + reference within numerical tolerance for energy, forces, and virial. + + Forces are the autograd output of energy through the with-comm + graph, so they implicitly validate the backward path of + ``deepmd_export::border_op``. Per-atom virial is gathered from + ``compute centroid/stress/atom NULL pair`` (parallel-safe) — the + earlier deadlock comment was specific to ``compute pressure NULL + virial`` + ``lammps.eval(...)``, which we sidestep entirely. + + Requires the .pt2 archive to carry a with-comm artifact (Phase 3 + output for GNN models). If the archive lacks it, the C++ falls + back to the regular artifact and produces wrong cross-rank values + — which the assertion would catch (loud test failure, not silent). + """ + out = _run_mpi_subprocess() + # Energy matches single-rank reference. + assert out["pe"] == pytest.approx(expected_e) + # Per-atom forces match (atoms in id-sorted order from the + # subprocess script). + for ii in range(6): + np.testing.assert_allclose( + out["forces"][ii], + expected_f[ii], + atol=1e-8, + rtol=0, + ) + # Per-atom virial matches the gen_dpa3.py reference. LAMMPS + # centroid/stress/atom returns components in [xx, yy, zz, xy, xz, + # yz, yx, zx, zy] order; ``expected_v`` columns follow the same + # column-major flattening as the single-rank ``test_pair_deepmd_virial`` + # (which uses idx_map [0, 4, 8, 3, 6, 7, 1, 2, 5] from c_virial[1..9] + # to expected_v columns). 
The inverse permutation maps + # ``out["virials"]`` columns back to ``expected_v`` columns. + expected_v_to_lammps = [0, 6, 7, 3, 1, 8, 4, 5, 2] + np.testing.assert_allclose( + out["virials"][:, expected_v_to_lammps] / constants.nktv2p, + expected_v, + atol=1e-8, + rtol=0, + ) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_nlist_rebuild() -> None: + """Multi-rank with neighbor-list rebuilds, validated against a + single-rank reference of the same archive and trajectory. + + Uses ``neigh_modify every 1`` so a rebuild happens on every step, + then runs 3 steps — yields 3 rebuilds in roughly 1/8 the wall + time of a 25-step ``every 10`` run. The same trajectory is then + run under ``mpirun -n 1`` (regular-artifact dispatch on the same + dual-artifact .pt2) to obtain a reference; comparing the two + catches a wrong-but-finite force from a dispatch bug. + + NVE is deterministic up to floating-point summation order, so + the cross-rank divergence after 3 steps is bounded by accumulated + round-off — small for a 6-atom system but non-zero, hence the + relaxed (but still tight) tolerances. 
+ """ + runner_args = ["--neigh-every", "1"] + out_mpi = _run_mpi_subprocess( + extra_args=["--nsteps", "3"], nprocs=2, runner_args=runner_args + ) + out_ref = _run_mpi_subprocess( + extra_args=["--nsteps", "3"], nprocs=1, runner_args=runner_args + ) + np.testing.assert_allclose( + out_mpi["forces"], + out_ref["forces"], + atol=1e-6, + rtol=1e-6, + ) + np.testing.assert_allclose( + out_mpi["virials"], + out_ref["virials"], + atol=1e-6, + rtol=1e-6, + ) + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-8, abs=1e-10) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_empty_subdomain() -> None: + """Multi-rank DPA3 with one rank owning zero local atoms. + + Runs 5 MD steps with ``neigh_modify every 100`` so the nlist is + rebuilt only once (at step 0, ago=0) and the next 4 force + evaluations exercise the cached ``mapping_tensor`` / + ``firstneigh_tensor`` path (PR 5407 caching) under empty + subdomain. Atoms move ~0 (v=0 default) so positions only differ + by tiny round-off, but the C++ dispatch path with cached state + on rank 1 (which has nloc=0) must still produce correct + cross-rank forces. + + Uses a 30 x 13 x 13 box with all six atoms clustered in x in + [0.25, 12.83]. Under ``processors 2 1 1`` the split is at x = 15 + so rank 1 owns an empty subdomain. The comm-dispatch path must + still produce correct forces and virial (compared against a + same-archive single-rank reference of the same trajectory). + + This catches: zero-length send/recv lists in the comm tensors, + division-by-zero in nlocal-dependent reshapes, silent drop of a + rank's contribution when it has no atoms to evaluate, AND + cache-hit (ago>0) bugs specific to the empty-subdomain rank. 
+ """ + runner_args = ["--neigh-every", "100"] + out_mpi = _run_mpi_subprocess( + nprocs=2, + data_path=data_file_empty_subdomain, + extra_args=["--nsteps", "5"], + runner_args=runner_args, + ) + out_ref = _run_mpi_subprocess( + nprocs=1, + data_path=data_file_empty_subdomain, + extra_args=["--nsteps", "5"], + runner_args=runner_args, + ) + np.testing.assert_allclose( + out_mpi["forces"], out_ref["forces"], atol=1e-6, rtol=1e-6 + ) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-6, rtol=1e-6 + ) + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-8, abs=1e-10) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +@pytest.mark.parametrize( + "nprocs,processors", + [ + # 2D ``2 2 1`` is omitted: ``8 @ 2 2 2`` already exercises 2D + # face exchange (it's a superset, in 3D), so the 2D-only case + # is redundant. The two kept variants give complementary + # coverage: 1D-deep sendlist chains vs 3D border exchange. + (4, "4 1 1"), # 1D-deep chain; sendlist depth = 3 (each pair is 1+2 swaps) + (8, "2 2 2"), # 3D decomposition; full xyz border exchange + ], +) +def test_pair_deepmd_mpi_dpa3_decomposition(nprocs, processors) -> None: + """Multi-rank DPA3 .pt2 must match the single-rank reference under + deeper / 3D processor grids beyond the canonical 2x1x1 (N=2) layout. + + Production MD typically runs with 8/16/32+ ranks and 2D/3D + decompositions. Bugs that don't fire at N=2 (deeper sendlist + chains, 3D border swaps, asymmetric subdomains, multiple empty + cells in the 2x2x2 split of a small fixture) have zero coverage + without this test. 
+ + The 6-atom 13x13x13 fixture is intentionally small relative to + the rank count: in the 2x2x2 split each subdomain is + ~6.5x6.5x6.5 A, so several subdomains are empty — exercising the + empty-subdomain ``copy_from_nlist`` guard fix in 3D. + """ + out_mpi = _run_mpi_subprocess(nprocs=nprocs, processors=processors) + # Step-0 evaluation; bit-exact match expected against the + # gen_dpa3.py-derived reference. + assert out_mpi["pe"] == pytest.approx(expected_e, rel=0, abs=1e-8) + for ii in range(6): + np.testing.assert_allclose( + out_mpi["forces"][ii], expected_f[ii], atol=1e-8, rtol=0 + ) + expected_v_to_lammps = [0, 6, 7, 3, 1, 8, 4, 5, 2] + np.testing.assert_allclose( + out_mpi["virials"][:, expected_v_to_lammps] / constants.nktv2p, + expected_v, + atol=1e-8, + rtol=0, + ) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_null_type() -> None: + """Multi-rank DPA3 .pt2 with NULL-type atoms. + + Exercises ``select_real_atoms_coord`` filtering AND + ``build_comm_tensors_positional_with_virtual_atoms`` remapping + under multi-rank dispatch — neither path was reachable in any + previous test fixture. + + Setup: 6 real atoms (types 1,2) at the canonical positions plus + 2 LAMMPS type-3 atoms straddling the x=6.5 rank boundary. With + ``pair_coeff * * O H NULL`` the type-3 atoms map to deepmd + atype=-1 and are filtered before model evaluation. Because the + NULL atoms sit within rcut of real atoms on BOTH sides of the + boundary, they appear in cross-rank sendlists — forcing the + ``fwd_map``-based remap (which translates unfiltered LAMMPS + indices into filtered real-atom indices, dropping ``-1`` slots). + + Assertions: + - Forces on the 6 real atoms (ids 1..6, id-sorted output) match + the no-NULL baseline ``expected_f`` exactly. 
NULL atoms don't + contribute to the deepmd model so real-atom forces are + identical to the 6-atom baseline. + - NULL-atom forces (ids 7,8) are zero — the deepmd model is the + only pair_style and skips them entirely. + - Total energy matches ``expected_e``. + - Per-atom virial on real atoms matches ``expected_v``. + """ + out_mpi = _run_mpi_subprocess( + nprocs=2, + data_path=data_file_null_type, + runner_args=["--pair-coeff", "* * O H NULL", "--mass3", "5.0"], + ) + # Forces on real atoms (ids 1..6) match the no-NULL baseline. + real_forces = out_mpi["forces"][:6] + for ii in range(6): + np.testing.assert_allclose(real_forces[ii], expected_f[ii], atol=1e-8, rtol=0) + # NULL atoms (ids 7,8) get zero force from the deepmd model. + null_forces = out_mpi["forces"][6:] + np.testing.assert_allclose(null_forces, 0.0, atol=1e-12, rtol=0) + # Total potential energy unchanged (NULL atoms contribute 0). + assert out_mpi["pe"] == pytest.approx(expected_e, rel=0, abs=1e-8) + # Per-atom virial on real atoms matches expected_v with the same + # column permutation as test_pair_deepmd_mpi_dpa3. + expected_v_to_lammps = [0, 6, 7, 3, 1, 8, 4, 5, 2] + real_virials = out_mpi["virials"][:6] + np.testing.assert_allclose( + real_virials[:, expected_v_to_lammps] / constants.nktv2p, + expected_v, + atol=1e-8, + rtol=0, + ) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_null_isolated() -> None: + """NULL atom local on a rank but absent from every sendlist. + + Box is 30x13x13 with split at x=15. With rcut=6 the boundary + rcut-windows on rank 0 are x ∈ [0, 6] (PBC of right wall via + x=30) and x ∈ [9, 15] (left wall of rank 1). Atoms in + x ∈ (6, 9) are LOCAL on rank 0 but never appear in any + cross-rank sendlist. Placing a NULL atom at x=7.5 puts it in + that gap. 
+ + Coverage: ``has_null_atoms == True`` triggers the + ``_with_virtual_atoms`` branch, but the remap encounters NO + NULL entries in any sendlist (no-op remap). The + ``test_pair_deepmd_mpi_dpa3_null_type`` test exercises the + remap-with-NULLs case; this one pins the + remap-with-no-NULLs-in-sendlist case. + + Comparison is mpi-vs-single-rank on the same fixture (no hardcoded + reference because the box differs from the canonical 13x13x13). + """ + out_mpi = _run_mpi_subprocess( + nprocs=2, + data_path=data_file_null_isolated, + runner_args=["--pair-coeff", "* * O H NULL", "--mass3", "5.0"], + ) + out_ref = _run_mpi_subprocess( + nprocs=1, + data_path=data_file_null_isolated, + runner_args=["--pair-coeff", "* * O H NULL", "--mass3", "5.0"], + ) + np.testing.assert_allclose(out_mpi["forces"], out_ref["forces"], atol=1e-8, rtol=0) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-8, rtol=0 + ) + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=0, abs=1e-8) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_all_null_rank() -> None: + """Rank that owns ONLY NULL atoms (intersection of empty-subdomain + and NULL-type paths). + + Box=30x13x13, split at x=15. Real atoms (types 1,2) are all in + rank 0 (x < 13). NULL atoms (type 3) are at x ∈ {20, 25}, + both in rank 1. After ``select_real_atoms_coord``: + + - Rank 0: nloc_real=6 (all real local), receives NULL atoms as + ghosts via PBC -> filtered -> nall_real ≤ nall. + - Rank 1: nloc_real=0 (all local atoms filtered out — empty + subdomain after filter), receives real atoms as ghosts. + + Tests that the comm-dispatch path handles a rank with zero real + locals correctly. 
The empty-subdomain ``copy_from_nlist`` guard + must fire on rank 1, AND the ``_with_virtual_atoms`` remap must + handle the case where the local section of the sendlist is + entirely NULL. + """ + out_mpi = _run_mpi_subprocess( + nprocs=2, + data_path=data_file_all_null_rank, + runner_args=["--pair-coeff", "* * O H NULL", "--mass3", "5.0"], + ) + out_ref = _run_mpi_subprocess( + nprocs=1, + data_path=data_file_all_null_rank, + runner_args=["--pair-coeff", "* * O H NULL", "--mass3", "5.0"], + ) + np.testing.assert_allclose(out_mpi["forces"], out_ref["forces"], atol=1e-8, rtol=0) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-8, rtol=0 + ) + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=0, abs=1e-8) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_null_type_nlist_rebuild() -> None: + """NULL atoms cross the boundary in OPPOSITE directions while + real atoms move randomly via thermal motion — sendlist + composition changes both ways per rebuild. + + Initial conditions: + - Real atoms (types 1, 2): thermal velocities at T=10000 K + (``--real-temp 10000``). Each real atom gets a different + random direction; mass-weighted RMS speed is roughly + 3 - 9 A/ps so motion in 3 steps is ~0.005 - 0.015 A. Tiny + but enough to perturb sendlist composition under + ``every 1`` rebuilds. + - NULL atom 7 (id=7) at x=5.5: gets ``v_x = -2000 A/ps`` via + ``--null-vx 2000 --null-vx-split`` (odd id -> negative). + Wraps via PBC: x = 5.5 -> 4.5 -> 3.5 -> 2.5 (stays in rank 0 + but drifts deeper into the PBC ghost region of rank 1). + - NULL atom 8 (id=8) at x=7.5: gets ``v_x = +2000 A/ps`` + (even id -> positive). x = 7.5 -> 8.5 -> 9.5 -> 10.5 (stays + in rank 1 but drifts deeper). 
+ + The +x/-x split means each rebuild sees NULL atoms entering + different sendlists (rank 0's right-edge sendlist gains NULL 7 + even as it loses NULL 8 deeper into rank 1's domain, and vice + versa). Real-atom thermal motion provides additional sendlist + perturbation per atom. + + Coverage: ``has_null_atoms`` must remain True; the + ``_with_virtual_atoms`` remap must produce correct outputs as + NULL atoms migrate in mixed directions and real-atom positions + shift. Compares mpi-2-rank vs mpi-1-rank trajectories + deterministically (both use the same velocity seed 12345). + """ + runner_args = [ + "--pair-coeff", + "* * O H NULL", + "--mass3", + "5.0", + "--neigh-every", + "1", + "--null-vx", + "2000.0", + "--null-vx-split", + "--real-temp", + "10000.0", + ] + out_mpi = _run_mpi_subprocess( + nprocs=2, + data_path=data_file_null_type, + extra_args=["--nsteps", "3"], + runner_args=runner_args, + ) + out_ref = _run_mpi_subprocess( + nprocs=1, + data_path=data_file_null_type, + extra_args=["--nsteps", "3"], + runner_args=runner_args, + ) + np.testing.assert_allclose( + out_mpi["forces"], out_ref["forces"], atol=1e-6, rtol=1e-6 + ) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-6, rtol=1e-6 + ) + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-8, abs=1e-10) diff --git a/source/lmp/tests/test_lammps_dpa3_pt2_fp32.py b/source/lmp/tests/test_lammps_dpa3_pt2_fp32.py new file mode 100644 index 0000000000..1f8eed2512 --- /dev/null +++ b/source/lmp/tests/test_lammps_dpa3_pt2_fp32.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Float32 multi-rank LAMMPS test for DPA3 GNN .pt2. + +The float64 multi-rank test in ``test_lammps_dpa3_pt2.py`` validates the +comm_dict path against a same-archive single-rank reference (atol 1e-8). 
+This file does the same thing for the float32 variant of the fixture +(``deeppot_dpa3_mpi_fp32.pt2``) — the model and trace are byte-identical +in every respect except ``descriptor.precision``/``fitting_net.precision`` +being set to ``float32``. + +Why a separate test file: + 1. The fp32 fixture is not packaged into ``deeppot_dpa3_mpi.pt2``; + it is a sibling artifact produced by the same gen script. + 2. fp32 needs looser tolerances. The C++ ``border_op`` kernel's + ``forward_t<float>`` template path (chosen automatically via + ``g1.dtype()`` dispatch in ``source/op/pt/comm.cc``) keeps only ~7 + decimal digits of precision, versus ~16 on the ``forward_t<double>`` + path. Single-precision GEMM in the AOTI-compiled kernel adds + further drift. + +What this file validates that the float64 test does not: + * ``border_op`` template dispatch on ``g1.dtype() == kFloat`` (vs + ``kDouble``) actually fires under MPI. + * ``register_fake`` returns ``torch.empty_like(g1)`` so the FX trace + preserves float32 dtype through the opaque op. + * ``register_autograd``'s ``border_op_backward`` invocation also + runs under float32, returning float32 gradients. + * MPI exchange uses ``MPI_FLOAT`` (vs ``MPI_DOUBLE``), halving the + bandwidth per ghost atom — relevant for slow interconnects. + +This is a regression-only test for the comm path. It does not pin any +hardcoded numerical values; mpi-2 must agree with same-archive mpi-1 +within float32 tolerances. +""" + +from __future__ import ( + annotations, +) + +import importlib.util +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import numpy as np +import pytest +from write_lmp_data import ( + write_lmp_data, +) + +pb_file_mpi_fp32 = ( + Path(__file__).parent.parent.parent + / "tests" + / "infer" + / "deeppot_dpa3_mpi_fp32.pt2" +) +data_file = Path(__file__).parent / "data_dpa3_pt2_fp32.lmp" + +# Same 6-atom O-H system as the float64 test.
``processors 2 1 1`` +# splits at x=6.5 -> 3 atoms per rank. +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) + + +def setup_module() -> None: + if os.environ.get("ENABLE_PYTORCH", "1") != "1": + pytest.skip("Skip test because PyTorch support is not enabled.") + write_lmp_data(box, coord, type_OH, data_file) + + +def teardown_module() -> None: + if data_file.exists(): + os.remove(data_file) + + +def _run_mpi_subprocess( + nprocs: int, + processors: str | None = None, +) -> dict: + """Run ``run_mpi_pair_deepmd_dpa3_pt2.py`` against the fp32 archive. + + Returns ``{"pe", "forces", "virials"}`` parsed from the runner's + output file. + """ + with tempfile.NamedTemporaryFile(mode="r", suffix=".out", delete=False) as f: + out_path = f.name + try: + argv = [ + "mpirun", + "-n", + str(nprocs), + sys.executable, + str(Path(__file__).parent / "run_mpi_pair_deepmd_dpa3_pt2.py"), + str(data_file.resolve()), + str(pb_file_mpi_fp32.resolve()), + out_path, + ] + if processors is not None: + argv.extend(["--processors", processors]) + elif nprocs == 1: + argv.extend(["--processors", "1 1 1"]) + sp.check_call(argv) + with open(out_path) as fh: + lines = fh.read().strip().splitlines() + pe = float(lines[0]) + rows = np.array( + [list(map(float, line.split())) for line in lines[1:]], + dtype=np.float64, + ) + forces = rows[:, :3] + virials = rows[:, 3:] + return {"pe": pe, "forces": forces, "virials": virials} + finally: + if os.path.exists(out_path): + os.remove(out_path) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_fp32() -> None: + """Float32 DPA3 multi-rank must match same-archive 
single-rank. + + Tolerances follow standard float32 expectations: + * energy: ``rel=1e-5`` (~7 decimal digits, with mantissa noise) + * force: ``atol=1e-4`` absolute (force magnitudes are O(1e-1) for + this system, so ``rel=1e-3``) + * virial: ``atol=5e-4`` per component + + Single-rank uses the regular artifact (nswap=0); multi-rank uses + the with-comm artifact -- so any divergence beyond float32 noise + is necessarily in the multi-rank dispatch (border_op template + dispatch, MPI_FLOAT exchange, register_fake/register_autograd + dtype handling). + """ + out_mpi = _run_mpi_subprocess(nprocs=2) + out_ref = _run_mpi_subprocess(nprocs=1) + + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-5, abs=1e-7) + np.testing.assert_allclose( + out_mpi["forces"], out_ref["forces"], atol=1e-4, rtol=1e-3 + ) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=5e-4, rtol=1e-3 + ) diff --git a/source/lmp/tests/test_lammps_spin_dpa3_pt2.py b/source/lmp/tests/test_lammps_spin_dpa3_pt2.py new file mode 100644 index 0000000000..7c7c5787a7 --- /dev/null +++ b/source/lmp/tests/test_lammps_spin_dpa3_pt2.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Multi-rank LAMMPS test for the DPA3 spin GNN .pt2 fixture. + +The DPA3 spin .pt2 (``deeppot_dpa3_spin_mpi.pt2``) is generated by +``source/tests/infer/gen_spin.py`` with ``use_loc_mapping=False``, +producing a dual-artifact archive whose nested +``forward_lower_with_comm.pt2`` is selected by ``DeepSpinPTExpt`` when +LAMMPS reports ``nswap > 0`` (multi-rank). This test exercises the +spin GNN multi-rank dispatch end-to-end: + +1. Eager parity is already covered by + ``source/tests/pt_expt/model/test_spin_export_with_comm.py + ::test_spin_dpa3_eager_parity`` (Python override only). +2. AOTI compile of the with-comm artifact is verified at fixture + generation time (``gen_spin.py`` calls ``convert_backend`` which + triggers the compile). +3. 
**This test** wires the loaded artifact through ``DeepSpinPTExpt``, + ``commPTExpt::build_comm_tensors_positional``, the C++ + ``deepmd_export::border_op`` registration, and real MPI ghost + exchange between two LAMMPS subdomains. A passing test means the + full chain (Python override + AOTI export + C++ load + comm-tensor + build + custom op invocation + MPI exchange) produces forces / + force_mag / virial identical to a same-archive single-rank + reference within numerical tolerance. + +Compares mpi-2 vs same-archive mpi-1 to avoid hardcoding numerical +references (the same approach used for the DPA3 / DPA2 multi-rank +tests). Same-archive means the regular and with-comm artifacts come +from the same trace, so any divergence is purely the multi-rank +dispatch path's responsibility. +""" + +from __future__ import ( + annotations, +) + +import importlib.util +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import numpy as np +import pytest +from write_lmp_data import ( + write_lmp_data_spin, +) + +pb_file_mpi = ( + Path(__file__).parent.parent.parent + / "tests" + / "infer" + / "deeppot_dpa3_spin_mpi.pt2" +) +data_file = Path(__file__).parent / "data_dpa3_spin_pt2.lmp" +# Elongated-box fixture for the spin empty-subdomain MPI test: x is +# extended to 30 A while atoms remain in x in [3, 13]. Combined with +# ``processors 2 1 1`` this leaves rank 1 (x >= 15) with zero local +# atoms, exercising the ``copy_from_nlist`` empty-rank guard for spin. +data_file_empty_subdomain = ( + Path(__file__).parent / "data_dpa3_spin_pt2_empty_subdomain.lmp" +) +# NULL-type fixture: 4 real Ni-O atoms + 2 LAMMPS type-3 atoms +# straddling the x=6.5 rank boundary. With ``pair_coeff * * Ni O NULL`` +# LAMMPS type 3 maps to deepmd atype=-1, so those atoms are filtered +# by ``select_real_atoms_coord`` and the comm tensors must be remapped +# via ``fwd_map`` before being handed to the with-comm artifact. 
+# Forces / force_mag on the 4 real atoms must match the no-NULL +# baseline (mpi-1 reference run); NULL atoms get zero forces from the +# deepmd model. +data_file_null_type = Path(__file__).parent / "data_dpa3_spin_pt2_null_type.lmp" + +# 4-atom Ni-O system; same layout as ``test_lammps_spin_pt2.py``. With +# ``processors 2 1 1`` the split sits at x=6.5 -> 2 atoms per rank. +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +spin = np.array( + [ + [0, 0, 1.2737], + [0, 0, 1.2737], + [0, 0, 0], + [0, 0, 0], + ] +) +type_NiO = np.array([1, 1, 2, 2]) + + +def setup_module() -> None: + if os.environ.get("ENABLE_PYTORCH", "1") != "1": + pytest.skip( + "Skip test because PyTorch support is not enabled.", + ) + write_lmp_data_spin(box, coord, spin, type_NiO, data_file) + # Elongated x-axis; atoms unchanged. ``processors 2 1 1`` splits + # at x=15 A and rank 1 owns x >= 15, which is empty. + box_empty = np.array([0, 30, 0, 13, 0, 13, 0, 0, 0]) + write_lmp_data_spin(box_empty, coord, spin, type_NiO, data_file_empty_subdomain) + # NULL-type fixture: append 2 LAMMPS type-3 atoms within rcut + # (~4 A) of real atoms on BOTH sides of the x=6.5 rank boundary, + # so they appear in cross-rank sendlists and the fwd_map-based + # comm-tensor remap is genuinely exercised. NULL atoms still need + # spin coordinates (write_lmp_data_spin format); we give them + # zero spin like the type-2 (O) atoms. 
+ coord_null = np.concatenate( + [ + coord, + np.array( + [ + [5.5, 6.0, 6.0], # rank 0 side, near boundary + [7.5, 7.0, 7.0], # rank 1 side, near boundary + ] + ), + ] + ) + spin_null = np.concatenate([spin, np.zeros((2, 3))]) + type_null = np.concatenate([type_NiO, np.array([3, 3])]) + write_lmp_data_spin(box, coord_null, spin_null, type_null, data_file_null_type) + + +def teardown_module() -> None: + for f in [data_file, data_file_empty_subdomain, data_file_null_type]: + if f.exists(): + os.remove(f) + + +def _run_mpi_subprocess( + nprocs: int, + extra_args: list[str] | None = None, + processors: str | None = None, + data_path: Path | None = None, + runner_args: list[str] | None = None, +) -> dict: + """Run ``run_mpi_pair_deepmd_spin_dpa3_pt2.py`` under + ``mpirun -n <nprocs>`` and return + ``{"pe", "forces", "force_mag", "virials"}``. + + ``data_path`` (default ``data_file``) selects the LAMMPS data file + -- the empty-subdomain and NULL-type tests point at non-default + fixtures. ``runner_args`` forwards additional flags (e.g. + ``--pair-coeff``, ``--mass3``) to the subprocess runner. + """ + if data_path is None: + data_path = data_file + with tempfile.NamedTemporaryFile(mode="r", suffix=".out", delete=False) as f: + out_path = f.name + try: + argv = [ + "mpirun", + "-n", + str(nprocs), + sys.executable, + str(Path(__file__).parent / "run_mpi_pair_deepmd_spin_dpa3_pt2.py"), + str(data_path.resolve()), + str(pb_file_mpi.resolve()), + out_path, + ] + if processors is not None: + argv.extend(["--processors", processors]) + elif nprocs == 1: + argv.extend(["--processors", "1 1 1"]) + if extra_args: + argv.extend(extra_args) + if runner_args: + argv.extend(runner_args) + sp.check_call(argv) + with open(out_path) as fh: + lines = fh.read().strip().splitlines() + pe = float(lines[0]) + rows = np.array( + [list(map(float, line.split())) for line in lines[1:]], + dtype=np.float64, + ) + # Each row: 3 force + 3 force_mag + 9 virial = 15 cols (see runner).
+ forces = rows[:, :3] + force_mag = rows[:, 3:6] + virials = rows[:, 6:] + return { + "pe": pe, + "forces": forces, + "force_mag": force_mag, + "virials": virials, + } + finally: + if os.path.exists(out_path): + os.remove(out_path) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_spin() -> None: + """Multi-rank LAMMPS run for spin DPA3 .pt2 must match the + same-archive single-rank reference within numerical tolerance for + energy, forces, force_mag, and per-atom virial. + + Going via mpi-1 (rather than a hardcoded reference array) means we + are validating the multi-rank dispatch path itself, isolated from + any tracing / AOTI precision drift that might appear at fixture + generation time. Single-rank uses the regular artifact (nswap=0); + multi-rank uses the with-comm artifact — so a divergence here is + necessarily a problem in either ``DeepSpinPTExpt`` multi-rank + dispatch, the spin branch of ``_exchange_ghosts``, the C++ + ``deepmd_export::border_op`` invocation, or the comm-tensor + builder. + """ + out_mpi = _run_mpi_subprocess(nprocs=2) + out_ref = _run_mpi_subprocess(nprocs=1) + + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-10, abs=1e-12) + np.testing.assert_allclose(out_mpi["forces"], out_ref["forces"], atol=1e-8, rtol=0) + np.testing.assert_allclose( + out_mpi["force_mag"], out_ref["force_mag"], atol=1e-8, rtol=0 + ) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-8, rtol=0 + ) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_spin_empty_subdomain() -> None: + """Spin DPA3 multi-rank with one empty rank. 
+ + Elongated x box (30 A) + ``processors 2 1 1`` puts all 4 atoms on + rank 0; rank 1 has nloc=0. Exercises the C++ ``copy_from_nlist`` + empty-rank guard for the spin path (the with-comm artifact still + runs on rank 1 with nloc_real=0). Compares against same-archive + mpi-1 reference. + """ + out_mpi = _run_mpi_subprocess(nprocs=2, data_path=data_file_empty_subdomain) + out_ref = _run_mpi_subprocess(nprocs=1, data_path=data_file_empty_subdomain) + + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-10, abs=1e-12) + np.testing.assert_allclose(out_mpi["forces"], out_ref["forces"], atol=1e-8, rtol=0) + np.testing.assert_allclose( + out_mpi["force_mag"], out_ref["force_mag"], atol=1e-8, rtol=0 + ) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-8, rtol=0 + ) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +def test_pair_deepmd_mpi_dpa3_spin_null_type() -> None: + """Spin DPA3 multi-rank with NULL-type atoms straddling the + rank boundary. + + Two LAMMPS type-3 atoms (mapped to deepmd atype=-1 via + ``pair_coeff * * Ni O NULL``) sit at x=5.5 and x=7.5, just inside + the rcut window of either side of the x=6.5 boundary. They appear + in the cross-rank sendlists and are filtered by + ``select_real_atoms_coord`` -- so the spin path goes through + ``DeepSpinPTExpt::compute`` with ``nall_real < nall``, triggering + the ``has_null_atoms`` branch that calls + ``build_comm_tensors_positional_with_virtual_atoms`` (fwd_map-based + sendlist remap). Compares mpi-2 vs same-archive mpi-1 reference + (nullifying NULL forces and using the same fwd_map remap on rank 0 + too). 
+ """ + runner_args = ["--pair-coeff", "* * Ni O NULL", "--mass3", "1.0"] + out_mpi = _run_mpi_subprocess( + nprocs=2, data_path=data_file_null_type, runner_args=runner_args + ) + out_ref = _run_mpi_subprocess( + nprocs=1, data_path=data_file_null_type, runner_args=runner_args + ) + + assert out_mpi["pe"] == pytest.approx(out_ref["pe"], rel=1e-10, abs=1e-12) + np.testing.assert_allclose(out_mpi["forces"], out_ref["forces"], atol=1e-8, rtol=0) + np.testing.assert_allclose( + out_mpi["force_mag"], out_ref["force_mag"], atol=1e-8, rtol=0 + ) + np.testing.assert_allclose( + out_mpi["virials"], out_ref["virials"], atol=1e-8, rtol=0 + ) + # Sanity: NULL atoms (ids 5, 6) get exactly zero forces from the + # deepmd model. ``write_lmp_data_spin`` writes atoms in the order + # given (id 1..N), so type-3 NULL atoms are ids 5, 6 (after the 4 + # real Ni-O atoms). + np.testing.assert_array_equal(out_mpi["forces"][4:], np.zeros((2, 3))) + np.testing.assert_array_equal(out_mpi["force_mag"][4:], np.zeros((2, 3))) diff --git a/source/op/pt/comm.cc b/source/op/pt/comm.cc index 97466a4833..31691d5e7d 100644 --- a/source/op/pt/comm.cc +++ b/source/op/pt/comm.cc @@ -139,20 +139,22 @@ class Border : public torch::autograd::Function { } else { #endif #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) -#ifdef USE_MPI - if (cuda_aware == 0) { - memcpy(recv_g1, send_g1, - (unsigned long)nsend * tensor_size * sizeof(FPTYPE)); - } else { + // Self-send branch: choose the host-vs-device memcpy based on + // where the data actually lives, not on MPI state. The buffer + // we read/write is ``recv_g1_tensor`` whose device is either + // (a) the original ``g1`` device, or (b) CPU after the + // non-cuda-aware MPI fallback above. Reading that device + // directly is the only correct dispatch for build configs + // where USE_MPI is on but the call site uses CPU tensors + // (e.g. unit tests of border_op without MPI init). 
+ if (recv_g1_tensor.is_cuda()) { gpuMemcpy(recv_g1, send_g1, (unsigned long)nsend * tensor_size * sizeof(FPTYPE), gpuMemcpyDeviceToDevice); + } else { + memcpy(recv_g1, send_g1, + (unsigned long)nsend * tensor_size * sizeof(FPTYPE)); } -#else - gpuMemcpy(recv_g1, send_g1, - (unsigned long)nsend * tensor_size * sizeof(FPTYPE), - gpuMemcpyDeviceToDevice); -#endif #else memcpy(recv_g1, send_g1, (unsigned long)nsend * tensor_size * sizeof(FPTYPE)); @@ -163,8 +165,29 @@ class Border : public torch::autograd::Function { recv_g1 += nrecv * tensor_size; } #ifdef USE_MPI + // Drain pending eager-send ACKs before returning. In the + // asymmetric ghost-exchange pattern (one rank only Sends, the + // other only Irecvs at a given swap — e.g. an empty subdomain + // under ``processors 2 1 1``) the sender's MPI_Send returns once + // the eager-buffered message is queued, but MPICH's internal + // accounting marks the message as "in flight" until the sender's + // progress engine processes the receiver's ACK. In the symmetric + // case the sender's own MPI_Wait on its Irecv drains those ACKs. + // In the asymmetric case there is no such Wait, and the message + // stays "in flight" all the way to MPI_Finalize, which then + // reports ``Communicator (...) being freed has N unmatched + // message(s)``. An MPI_Barrier on the same communicator forces a + // round-trip on every rank, drains ACKs, and clears the counter. + if (mpi_init && world_size >= 1) { + MPI_Barrier(world); + } #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) - if (cuda_aware == 0) { + // Only copy back when ``recv_g1_tensor`` was actually moved to a + // different device above (the cuda_aware==0 CPU fallback). When + // ``recv_g1_tensor`` still aliases ``g1`` no copy is needed; the + // is_alias_of check is the precise correctness condition and works + // for both CUDA and CPU call sites. 
+ if (!recv_g1_tensor.is_alias_of(g1)) { g1.copy_(recv_g1_tensor); } #endif @@ -174,17 +197,6 @@ class Border : public torch::autograd::Function { static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) { - bool type_flag = (grad_output[0].dtype() == torch::kDouble) ? true : false; - if (type_flag) { - return backward_t(ctx, grad_output); - } else { - return backward_t(ctx, grad_output); - } - } - template - static torch::autograd::variable_list backward_t( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); torch::Tensor sendlist_tensor = saved_variables[0]; torch::Tensor sendproc_tensor = saved_variables[1]; @@ -194,8 +206,41 @@ class Border : public torch::autograd::Function { torch::Tensor communicator_tensor = saved_variables[5]; torch::Tensor nlocal_tensor = saved_variables[6]; torch::Tensor nghost_tensor = saved_variables[7]; + torch::Tensor d_in = border_op_backward_dispatch( + sendlist_tensor, sendproc_tensor, recvproc_tensor, sendnum_tensor, + recvnum_tensor, grad_output[0], communicator_tensor, nlocal_tensor, + nghost_tensor); + return {torch::Tensor(), torch::Tensor(), torch::Tensor(), + torch::Tensor(), torch::Tensor(), d_in, + torch::Tensor(), torch::Tensor(), torch::Tensor(), + torch::Tensor()}; + } + + // Forward declaration; defined as a free function below so it can be + // registered as a separate torch op (deepmd::border_op_backward) used by + // the pt_expt opaque-op autograd wrapper. 
+ static torch::Tensor border_op_backward_dispatch( + const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& grad_g1, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor); - torch::Tensor d_local_g1_tensor = grad_output[0].contiguous(); + template + static torch::Tensor backward_t(const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& grad_g1, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor) { + torch::Tensor d_local_g1_tensor = grad_g1.contiguous(); #ifdef USE_MPI int mpi_init = 0; MPI_Initialized(&mpi_init); @@ -216,8 +261,8 @@ class Border : public torch::autograd::Function { cuda_aware = MPIX_Query_cuda_support(); #endif if (cuda_aware == 0) { - d_local_g1_tensor = torch::empty_like(grad_output[0]).to(torch::kCPU); - d_local_g1_tensor.copy_(grad_output[0]); + d_local_g1_tensor = torch::empty_like(grad_g1).to(torch::kCPU); + d_local_g1_tensor.copy_(grad_g1); } } #endif @@ -282,20 +327,20 @@ class Border : public torch::autograd::Function { #endif if (nrecv) { #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) -#ifdef USE_MPI - if (cuda_aware == 0) { - memcpy(recv_g1, send_g1, - (unsigned long)nrecv * tensor_size * sizeof(FPTYPE)); - } else { + // Self-send branch: dispatch on the actual device of the + // ``recv_g1_tensor`` buffer, not on MPI state. Same rationale + // as the forward kernel — USE_MPI builds may be called with + // CPU tensors (unit tests of border_op_backward) where the + // gpuMemcpy path silently fails with cudaErrorInvalidValue + // and leaves recv_g1 uninitialized. 
+ if (recv_g1_tensor.is_cuda()) { gpuMemcpy(recv_g1, send_g1, (unsigned long)nrecv * tensor_size * sizeof(FPTYPE), gpuMemcpyDeviceToDevice); + } else { + memcpy(recv_g1, send_g1, + (unsigned long)nrecv * tensor_size * sizeof(FPTYPE)); } -#else - gpuMemcpy(recv_g1, send_g1, - (unsigned long)nrecv * tensor_size * sizeof(FPTYPE), - gpuMemcpyDeviceToDevice); -#endif #else memcpy(recv_g1, send_g1, (unsigned long)nrecv * tensor_size * sizeof(FPTYPE)); @@ -310,17 +355,27 @@ class Border : public torch::autograd::Function { } } #ifdef USE_MPI + // Drain pending eager-send ACKs before returning — see forward_t + // for the full rationale. Backward has the same asymmetric + // Send/Irecv pattern (now in the reverse direction) and the same + // unmatched-message trap when one rank only Sends. + if (mpi_init && world_size >= 1) { + MPI_Barrier(world); + } #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) - if (cuda_aware == 0) { - grad_output[0].copy_(d_local_g1_tensor); + // Move result back to the device of the input grad only when + // ``d_local_g1_tensor`` was actually moved to a different device + // above (the cuda_aware==0 CPU fallback). The is_alias_of check + // is the precise correctness condition and works for both CUDA + // and CPU call sites (no-op when the tensor still aliases input). 
+ if (!d_local_g1_tensor.is_alias_of(grad_g1)) { + d_local_g1_tensor = d_local_g1_tensor.to(grad_g1.device()); } #endif #endif - - return {torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), - torch::Tensor(), grad_output[0], torch::Tensor(), torch::Tensor(), - torch::Tensor(), torch::Tensor()}; + return d_local_g1_tensor; } + #ifdef USE_MPI static void unpack_communicator(const torch::Tensor& communicator_tensor, MPI_Comm& mpi_comm) { @@ -363,4 +418,170 @@ std::vector border_op(const torch::Tensor& sendlist_tensor, communicator_tensor, nlocal_tensor, nghost_tensor); } -TORCH_LIBRARY_FRAGMENT(deepmd, m) { m.def("border_op", border_op); } +// Define Border::border_op_backward_dispatch out-of-line so the type-flag +// dispatch can refer to the templated backward_t members declared in the +// class. +torch::Tensor Border::border_op_backward_dispatch( + const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& grad_g1, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor) { + bool type_flag = (grad_g1.dtype() == torch::kDouble); + if (type_flag) { + return backward_t(sendlist_tensor, sendproc_tensor, recvproc_tensor, + sendnum_tensor, recvnum_tensor, grad_g1, + communicator_tensor, nlocal_tensor, + nghost_tensor); + } else { + return backward_t(sendlist_tensor, sendproc_tensor, recvproc_tensor, + sendnum_tensor, recvnum_tensor, grad_g1, + communicator_tensor, nlocal_tensor, nghost_tensor); + } +} + +/** + * @brief Standalone backward of border_op for use by pt_expt's opaque-op + * autograd wrapper. 
Performs the symmetric MPI exchange that the autograd + * Border::backward applies, but without an autograd context — comm tensors + * are passed directly so the op can be registered as a torch op and + * embedded in an AOTInductor graph. + * + * The comm topology is symmetric: the same sendlist/sendnum/recvnum buffers + * encode the forward exchange; backward simply swaps send <-> recv and + * accumulates gradients into the local atom slots. + * + * @param[in] sendlist_tensor send-list pointer-array (forward direction) + * @param[in] sendproc_tensor send-proc IDs (forward direction) + * @param[in] recvproc_tensor recv-proc IDs (forward direction) + * @param[in] sendnum_tensor atoms sent per swap (forward direction) + * @param[in] recvnum_tensor atoms received per swap (forward direction) + * @param[in] grad_g1 upstream gradient w.r.t. g1 of forward + * @param[in] communicator_tensor MPI communicator handle as int64 + * @param[in] nlocal_tensor number of local atoms (per rank) + * @param[in] nghost_tensor number of ghost atoms (per rank) + * @return d_in (gradient w.r.t. forward g1 input), same shape as grad_g1. + */ +torch::Tensor border_op_backward(const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& grad_g1, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor) { + return Border::border_op_backward_dispatch( + sendlist_tensor, sendproc_tensor, recvproc_tensor, sendnum_tensor, + recvnum_tensor, grad_g1, communicator_tensor, nlocal_tensor, + nghost_tensor); +} + +TORCH_LIBRARY_FRAGMENT(deepmd, m) { + m.def("border_op", border_op); + m.def("border_op_backward", border_op_backward); +} + +// ============================================================================ +// Opaque wrappers for the pt_expt (.pt2 / AOTInductor) export path. 
+// +// ``deepmd::border_op`` and ``deepmd::border_op_backward`` are registered +// without an explicit dispatch key, which makes them +// ``CompositeImplicitAutograd`` ops. ``torch.export`` decomposes such ops +// during tracing — i.e., it tries to inline the C++ kernel — and that +// fails because the kernel calls ``data_ptr()`` on FakeTensors. +// +// These ``deepmd_export::*`` wrappers are registered with explicit +// ``CPU`` and ``CUDA`` dispatch keys so ``torch.export`` records them as +// opaque external calls in the graph. The .pt2 archive embeds the call +// sites; at runtime the dispatcher routes back to the underlying +// ``deepmd::*`` op. Both clones because ``deepmd::border_op`` returns +// the same tensor it modified in place, which violates AOTInductor's +// no-aliasing rule for graph outputs. +// +// Python (``deepmd/pt_expt/utils/comm.py``) layers ``register_fake`` and +// ``register_autograd`` on top of these C++-defined ops so traced graphs +// can run their fake/backward. +// ============================================================================ + +namespace { +// ``DEEPMD_MAYBE_UNUSED`` silences CodeQL's ``cpp/unused-static-function`` +// query — the functions ARE used: ``TORCH_LIBRARY_IMPL(...)`` below +// registers them as op implementations via function-pointer arguments, +// which CodeQL's static dataflow can't see through. The attribute is +// C++17, so guard it for the legacy-torch (< 2.1) build path which +// CMakeLists.txt holds at C++14. 
+#if __cplusplus >= 201703L +#define DEEPMD_MAYBE_UNUSED [[maybe_unused]] +#else +#define DEEPMD_MAYBE_UNUSED +#endif +DEEPMD_MAYBE_UNUSED torch::Tensor border_op_export( + const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& g1_tensor, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor) { + auto out = border_op(sendlist_tensor, sendproc_tensor, recvproc_tensor, + sendnum_tensor, recvnum_tensor, g1_tensor, + communicator_tensor, nlocal_tensor, nghost_tensor); + // border_op returns {g1_tensor} — a list whose first element aliases + // g1_tensor. Clone for AOTI graph-output correctness. + if (out.empty()) { + throw std::runtime_error( + "border_op_export: border_op returned an empty output list, which " + "indicates an internal error in the underlying border_op kernel."); + } + return out[0].clone(); +} + +DEEPMD_MAYBE_UNUSED torch::Tensor border_op_backward_export( + const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& grad_g1, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor) { + return border_op_backward(sendlist_tensor, sendproc_tensor, recvproc_tensor, + sendnum_tensor, recvnum_tensor, grad_g1, + communicator_tensor, nlocal_tensor, nghost_tensor) + .clone(); +} +} // namespace +#undef DEEPMD_MAYBE_UNUSED + +TORCH_LIBRARY_FRAGMENT(deepmd_export, m) { + m.def( + "border_op(Tensor sendlist, Tensor sendproc, Tensor recvproc, " + "Tensor sendnum, Tensor recvnum, Tensor g1, Tensor communicator, " + "Tensor nlocal, Tensor nghost) -> Tensor"); + m.def( + "border_op_backward(Tensor sendlist, Tensor sendproc, 
Tensor recvproc, " + "Tensor sendnum, Tensor recvnum, Tensor grad_g1, Tensor communicator, " + "Tensor nlocal, Tensor nghost) -> Tensor"); +} + +// Register CPU + CUDA implementations under explicit dispatch keys so +// torch.export sees opaque external calls (vs CompositeImplicitAutograd +// which gets decomposed during trace). +TORCH_LIBRARY_IMPL(deepmd_export, CPU, m) { + m.impl("border_op", border_op_export); + m.impl("border_op_backward", border_op_backward_export); +} +#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) +TORCH_LIBRARY_IMPL(deepmd_export, CUDA, m) { + m.impl("border_op", border_op_export); + m.impl("border_op_backward", border_op_backward_export); +} +#endif diff --git a/source/tests/infer/gen_corrupt_with_comm.py b/source/tests/infer/gen_corrupt_with_comm.py new file mode 100644 index 0000000000..ff0d16158c --- /dev/null +++ b/source/tests/infer/gen_corrupt_with_comm.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Generate ``deeppot_*_corrupt_with_comm.pt2`` fixtures. + +The fixtures are copies of the corresponding multi-rank ``.pt2`` archives +in which the nested ``model/extra/forward_lower_with_comm.pt2`` entry has +been overwritten with garbage bytes. The outer metadata still claims +``has_comm_artifact: true``, so: + +- ``DeepPotPTExpt::init`` / ``DeepSpinPTExpt::init`` exercise the + try/catch fallback path on the with-comm AOTI loader. +- Single-rank dispatch (``nswap == 0``) keeps working via the regular + artifact. +- Multi-rank dispatch (``nswap > 0``) hits the explicit dispatch-site + throw added in PR #5430, instead of silently dropping the MPI + ghost-embedding exchange. + +Consumed by ``source/api_cc/tests/test_with_comm_load_failure_ptexpt.cc``. 
+""" + +import os +import zipfile + +WITH_COMM_ENTRY = "model/extra/forward_lower_with_comm.pt2" +GARBAGE = b"NOT_A_VALID_AOTI_ARCHIVE_" * 32 + + +def corrupt_with_comm(src: str, dst: str) -> None: + """Copy ``src`` to ``dst`` with the nested with-comm entry replaced.""" + with ( + zipfile.ZipFile(src, "r") as zin, + zipfile.ZipFile(dst, "w", compression=zipfile.ZIP_STORED) as zout, + ): + replaced = False + for info in zin.infolist(): + data = zin.read(info.filename) + if info.filename == WITH_COMM_ENTRY: + data = GARBAGE + replaced = True + zout.writestr(info, data) + if not replaced: + raise RuntimeError( + f"{src} does not contain {WITH_COMM_ENTRY}; cannot corrupt." + ) + + +def main() -> None: + base_dir = os.path.dirname(__file__) + pairs = [ + ("deeppot_dpa3_mpi.pt2", "deeppot_dpa3_mpi_corrupt_with_comm.pt2"), + ( + "deeppot_dpa3_spin_mpi.pt2", + "deeppot_dpa3_spin_mpi_corrupt_with_comm.pt2", + ), + ] + for src_name, dst_name in pairs: + src = os.path.join(base_dir, src_name) + dst = os.path.join(base_dir, dst_name) + if not os.path.exists(src): + print(f"Skipping {dst_name}: source {src_name} not found.") # noqa: T201 + continue + corrupt_with_comm(src, dst) + print(f"Wrote {dst}") # noqa: T201 + + +if __name__ == "__main__": + main() diff --git a/source/tests/infer/gen_dpa2.py b/source/tests/infer/gen_dpa2.py index fba5341b02..093000bc59 100644 --- a/source/tests/infer/gen_dpa2.py +++ b/source/tests/infer/gen_dpa2.py @@ -108,6 +108,11 @@ def main(): pt2_path = os.path.join(base_dir, "deeppot_dpa2.pt2") print(f"Exporting to {pt2_path} ...") # noqa: T201 + # DPA2's repformer block has no ``use_loc_mapping`` knob (unlike + # DPA3), so a single .pt2 already carries the dual-artifact layout + # (regular + with-comm) — ``has_message_passing_across_ranks`` + # returns True and the serializer produces both. No separate _mpi.pt2 + # needed. 
pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data), do_atomic_virial=True) pth_path = os.path.join(base_dir, "deeppot_dpa2.pth") diff --git a/source/tests/infer/gen_dpa3.py b/source/tests/infer/gen_dpa3.py index bf0d0967ae..c0a4434e33 100644 --- a/source/tests/infer/gen_dpa3.py +++ b/source/tests/infer/gen_dpa3.py @@ -88,6 +88,53 @@ def main(): print(f"Exporting to {pt2_path} ...") # noqa: T201 pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data), do_atomic_virial=True) + # Multi-rank LAMMPS variant (use_loc_mapping=False) — produces a + # dual-artifact .pt2 with the with-comm AOTI module nested inside + # so the C++ DeepPotPTExpt routes to it under mpirun. See + # source/lmp/tests/test_lammps_dpa3_pt2.py::test_pair_deepmd_mpi_dpa3. + config_mpi = copy.deepcopy(config) + config_mpi["descriptor"]["use_loc_mapping"] = False + # Defensive deep copy: get_model is allowed to mutate its argument + # in place, and we still need ``config_mpi`` intact below for + # ``model_def_script``. + model_mpi = get_model(copy.deepcopy(config_mpi)) + data_mpi = { + "model": model_mpi.serialize(), + "model_def_script": config_mpi, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + pt2_mpi_path = os.path.join(base_dir, "deeppot_dpa3_mpi.pt2") + print(f"Exporting to {pt2_mpi_path} ...") # noqa: T201 + pt_expt_deserialize_to_file( + pt2_mpi_path, copy.deepcopy(data_mpi), do_atomic_virial=True + ) + + # Float32 multi-rank variant — same architecture as the float64 + # MPI fixture but with ``precision: float32``. Used by + # source/lmp/tests/test_lammps_dpa3_pt2_fp32.py to validate that + # the comm_dict path (border_op + register_fake/register_autograd) + # is dtype-agnostic in practice, not just by inspection. 
+ config_mpi_fp32 = copy.deepcopy(config_mpi) + config_mpi_fp32["descriptor"]["precision"] = "float32" + config_mpi_fp32["fitting_net"]["precision"] = "float32" + model_mpi_fp32 = get_model(copy.deepcopy(config_mpi_fp32)) + data_mpi_fp32 = { + "model": model_mpi_fp32.serialize(), + "model_def_script": config_mpi_fp32, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + pt2_mpi_fp32_path = os.path.join(base_dir, "deeppot_dpa3_mpi_fp32.pt2") + print(f"Exporting to {pt2_mpi_fp32_path} ...") # noqa: T201 + pt_expt_deserialize_to_file( + pt2_mpi_fp32_path, + copy.deepcopy(data_mpi_fp32), + do_atomic_virial=True, + ) + pth_path = os.path.join(base_dir, "deeppot_dpa3.pth") print(f"Exporting to {pth_path} ...") # noqa: T201 try: diff --git a/source/tests/infer/gen_spin.py b/source/tests/infer/gen_spin.py index c7c170b7d1..c17546504b 100644 --- a/source/tests/infer/gen_spin.py +++ b/source/tests/infer/gen_spin.py @@ -84,6 +84,66 @@ def _build_yaml(yaml_path: str) -> None: save_dp_model(yaml_path, data) +def _build_dpa3_mpi_yaml(yaml_path: str) -> None: + """Build a DPA3 spin model with use_loc_mapping=False (multi-rank). + + The default ``deeppot_dpa_spin.yaml`` uses se_atten (DPA1) which + is non-GNN — single-artifact .pt2, no multi-rank ghost exchange. + This variant uses DPA3 (repflows, GNN) with use_loc_mapping=False + so the dual-artifact .pt2 carries the with-comm AOTI module that + DeepSpinPTExpt routes to under mpirun > 1. + + Type map matches the existing 4-atom Ni-O test data + (``write_lmp_data_spin``): two types, Ni (spin-active), O (no spin). 
+ """ + from deepmd.dpmodel.model.model import ( + get_model, + ) + from deepmd.dpmodel.utils.serialization import ( + save_dp_model, + ) + + config = { + "type_map": ["Ni", "O"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 6, + "a_dim": 4, + "nlayers": 1, + "e_rcut": 4.0, + "e_rcut_smth": 0.5, + "e_sel": 8, + "a_rcut": 3.5, + "a_rcut_smth": 0.5, + "a_sel": 4, + "axis_neuron": 4, + "update_angle": False, + }, + "use_loc_mapping": False, + "precision": "float64", + "seed": 1, + }, + "fitting_net": {"neuron": [5, 5, 5], "resnet_dt": True, "seed": 1}, + "spin": {"use_spin": [True, False], "virtual_scale": [0.3140, 0.0]}, + } + + model = get_model(copy.deepcopy(config)) + model_dict = model.serialize() + + data = { + "model": model_dict, + "model_def_script": config, + "backend": "dpmodel", + "software": "deepmd-kit", + "version": "3.0.0", + } + + print(f"Building DPA3 spin dpmodel and saving to {yaml_path} ...") # noqa: T201 + save_dp_model(yaml_path, data) + + def main(): from deepmd.entrypoints.convert_backend import ( convert_backend, @@ -96,12 +156,23 @@ def main(): pth_path = os.path.join(base_dir, "deeppot_dpa_spin.pth") pt2_path = os.path.join(base_dir, "deeppot_dpa_spin.pt2") - # ---- 1. Build .yaml if it doesn't exist ---- + # Multi-rank GNN spin variant (DPA3 + use_loc_mapping=False). + # Produces a dual-artifact .pt2 that DeepSpinPTExpt routes to + # under mpirun > 1 (Phase 4c spin multi-rank dispatch). + yaml_dpa3_path = os.path.join(base_dir, "deeppot_dpa3_spin_mpi.yaml") + pt2_dpa3_path = os.path.join(base_dir, "deeppot_dpa3_spin_mpi.pt2") + + # ---- 1. Build .yamls if they don't exist ---- if not os.path.exists(yaml_path): _build_yaml(yaml_path) else: print(f"Using existing {yaml_path}") # noqa: T201 + if not os.path.exists(yaml_dpa3_path): + _build_dpa3_mpi_yaml(yaml_dpa3_path) + else: + print(f"Using existing {yaml_dpa3_path}") # noqa: T201 + # ---- 2. 
Convert .yaml -> .pth and .yaml -> .pt2 ---- # Import deepmd.pt to register the backend (needed for convert_backend) import deepmd.pt # noqa: F401 @@ -114,6 +185,9 @@ def main(): print(f"Converting to {pt2_path} ...") # noqa: T201 convert_backend(INPUT=yaml_path, OUTPUT=pt2_path, atomic_virial=True) + print(f"Converting to {pt2_dpa3_path} ...") # noqa: T201 + convert_backend(INPUT=yaml_dpa3_path, OUTPUT=pt2_dpa3_path, atomic_virial=True) + print("Export done.") # noqa: T201 # ---- 3. Run inference for PBC test ---- diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py index f2a9b07a6a..d4d987fe95 100644 --- a/source/tests/pt_expt/conftest.py +++ b/source/tests/pt_expt/conftest.py @@ -17,6 +17,10 @@ _get_current_function_mode_stack, ) +# ``deepmd.pt_expt.utils.comm`` self-bootstraps libdeepmd_op_pt.so via +# ``_check_underlying_ops_loaded()``, so we no longer need to preload +# ``deepmd.pt`` here. + def _pop_device_contexts() -> list: """Pop all stale DeviceContext modes from the torch function mode stack.""" diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index 0fc971ba58..c5a2ed57a6 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -294,3 +294,40 @@ def test_share_params(self, shared_level) -> None: # invalid level raises with pytest.raises(NotImplementedError): dd1.share_params(dd0, shared_level=2) + + +def test_has_message_passing_across_ranks() -> None: + """DPA1 (se_atten) is single-layer attention; no cross-rank + feature exchange is needed at multi-rank deployment. 
+ """ + import copy + + from deepmd.dpmodel.model.model import ( + get_model, + ) + + config = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_atten", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": 20, + "neuron": [2, 4], + "axis_neuron": 2, + "attn": 5, + "attn_layer": 1, + "type_one_side": True, + "precision": "float64", + "seed": 1, + }, + "fitting_net": { + "neuron": [4, 4], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + desc = get_model(copy.deepcopy(config)).atomic_model.descriptor + assert desc.has_message_passing() is False + assert desc.has_message_passing_across_ranks() is False diff --git a/source/tests/pt_expt/descriptor/test_dpa2.py b/source/tests/pt_expt/descriptor/test_dpa2.py index fb0005e13a..217bcdb230 100644 --- a/source/tests/pt_expt/descriptor/test_dpa2.py +++ b/source/tests/pt_expt/descriptor/test_dpa2.py @@ -426,3 +426,60 @@ def fn(coord_ext, atype_ext, nlist, mapping): rtol=rtol, atol=atol, ) + + +def test_has_message_passing_across_ranks() -> None: + """DPA2's repformer always passes ``g1`` in ``[nb, nall, n_dim]`` + layout (no ``use_loc_mapping`` opt-out exists), so cross-rank + message passing is always required for multi-rank deployment. 
+ """ + import copy + + from deepmd.dpmodel.model.model import ( + get_model, + ) + + config = { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 6.0, + "rcut_smth": 2.0, + "nsel": 20, + "neuron": [2, 4], + "axis_neuron": 4, + "tebd_dim": 8, + "tebd_input_mode": "concat", + "set_davg_zero": True, + "type_one_side": True, + "use_three_body": False, + }, + "repformer": { + "rcut": 3.0, + "rcut_smth": 1.5, + "nsel": 10, + "nlayers": 1, + "g1_dim": 8, + "g2_dim": 5, + "axis_neuron": 4, + "update_g1_has_conv": True, + "update_g1_has_drrd": True, + "update_g1_has_grrg": True, + "update_g2_has_attn": True, + "attn1_hidden": 8, + "attn1_nhead": 2, + "attn2_hidden": 5, + "attn2_nhead": 1, + "update_style": "res_avg", + "set_davg_zero": True, + }, + "concat_output_tebd": True, + "precision": "float64", + "seed": 1, + }, + "fitting_net": {"neuron": [4, 4], "resnet_dt": True, "seed": 1}, + } + desc = get_model(copy.deepcopy(config)).atomic_model.descriptor + assert desc.has_message_passing() is True + assert desc.has_message_passing_across_ranks() is True diff --git a/source/tests/pt_expt/descriptor/test_dpa3.py b/source/tests/pt_expt/descriptor/test_dpa3.py index ef4b479724..3013f5cc65 100644 --- a/source/tests/pt_expt/descriptor/test_dpa3.py +++ b/source/tests/pt_expt/descriptor/test_dpa3.py @@ -311,3 +311,43 @@ def test_share_params(self, shared_level) -> None: # invalid level raises with pytest.raises(NotImplementedError): dd1.share_params(dd0, shared_level=2) + + +@pytest.mark.parametrize("use_loc_mapping", [True, False]) +def test_has_message_passing_across_ranks(use_loc_mapping) -> None: + """DPA3 always reports message passing; cross-rank only when + ``use_loc_mapping=False`` (so per-layer node embeddings must flow + via MPI ghost exchange instead of a local gather). 
+ """ + import copy + + from deepmd.dpmodel.model.model import ( + get_model, + ) + + config = { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 6, + "a_dim": 4, + "nlayers": 1, + "e_rcut": 4.0, + "e_rcut_smth": 0.5, + "e_sel": 8, + "a_rcut": 3.5, + "a_rcut_smth": 0.5, + "a_sel": 4, + "axis_neuron": 4, + "update_angle": False, + }, + "use_loc_mapping": use_loc_mapping, + }, + "fitting_net": {"neuron": [16, 16], "seed": 1}, + } + model = get_model(copy.deepcopy(config)) + desc = model.atomic_model.descriptor + assert desc.has_message_passing() is True + assert desc.has_message_passing_across_ranks() is (not use_loc_mapping) diff --git a/source/tests/pt_expt/descriptor/test_hybrid.py b/source/tests/pt_expt/descriptor/test_hybrid.py index 5fa8970bf1..b45a7bea19 100644 --- a/source/tests/pt_expt/descriptor/test_hybrid.py +++ b/source/tests/pt_expt/descriptor/test_hybrid.py @@ -284,3 +284,73 @@ def test_share_params(self) -> None: # invalid level raises with pytest.raises(NotImplementedError): dd1.share_params(dd0, shared_level=1) + + +def _se_e2_a_child() -> dict: + return { + "type": "se_e2_a", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": [20, 20], + "neuron": [2, 4], + "axis_neuron": 2, + "type_one_side": True, + "precision": "float64", + "seed": 1, + } + + +def _dpa3_child(use_loc_mapping: bool) -> dict: + return { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 6, + "a_dim": 4, + "nlayers": 1, + "e_rcut": 4.0, + "e_rcut_smth": 0.5, + "e_sel": 8, + "a_rcut": 3.5, + "a_rcut_smth": 0.5, + "a_sel": 4, + "axis_neuron": 4, + "update_angle": False, + }, + "use_loc_mapping": use_loc_mapping, + } + + +@pytest.mark.parametrize( + "child_factory,expected_hmp,expected_hmp_ar", + [ + (_se_e2_a_child, False, False), + (lambda: _dpa3_child(use_loc_mapping=True), True, False), + (lambda: _dpa3_child(use_loc_mapping=False), True, True), + ], + ids=["se_e2_a-only", "dpa3-ulm-true", "dpa3-ulm-false"], +) +def 
test_has_message_passing_across_ranks( + child_factory, expected_hmp, expected_hmp_ar +) -> None: + """Hybrid descriptor recurses into its children; cross-rank message + passing is required iff any child needs it. Closes the structural + side of catalog Tier-1 #1. + """ + import copy + + from deepmd.dpmodel.model.model import ( + get_model, + ) + + config = { + "type_map": ["O", "H"], + "descriptor": { + "type": "hybrid", + "list": [child_factory()], + }, + "fitting_net": {"neuron": [4, 4], "seed": 1}, + } + desc = get_model(copy.deepcopy(config)).atomic_model.descriptor + assert desc.has_message_passing() is expected_hmp + assert desc.has_message_passing_across_ranks() is expected_hmp_ar diff --git a/source/tests/pt_expt/descriptor/test_repflow_parallel.py b/source/tests/pt_expt/descriptor/test_repflow_parallel.py new file mode 100644 index 0000000000..f5b4d40bcd --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_repflow_parallel.py @@ -0,0 +1,419 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Eager parity test for the pt_expt RepFlow parallel-mode override. + +Verifies that ``DescrptBlockRepflows._exchange_ghosts`` (the pt_expt +override) produces output identical to the dpmodel default +``_exchange_ghosts`` when the supplied ``comm_dict`` describes a +single-rank, self-only MPI exchange whose effect equals the per-layer +gather that the default does via ``mapping``. + +This is a Phase 2.5 gate: it exercises the override code path *eagerly* +(no torch.export, no AOTInductor) before we attempt the export round +trip in Phase 3. End-to-end multi-rank validation is deferred to the +Phase 5 LAMMPS test (``test_lammps_dpa3_pt2_mpi``). + +Implementation note: the underlying ``torch.ops.deepmd.border_op`` +treats ``sendlist_tensor`` as a packed pointer-array (``int**``). We +build that pointer array using numpy contiguous int32 arrays and pack +their addresses into an int64 tensor. 
In single-rank mode (no MPI +init) the C++ op enters the ``sendproc == me`` self-send branch and +performs an in-process memcpy from the sendlist-indexed rows into the +ghost slots — no MPI runtime needed. +""" + +from __future__ import ( + annotations, +) + +import ctypes + +import numpy as np +import pytest +import torch + +# Trigger registration of the deepmd_export::border_op opaque wrapper. +import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.pt_expt.descriptor.dpa3 import ( + DescrptDPA3, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...common.test_mixins import ( + TestCaseSingleFrameWithNlist, + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + +# --------------------------------------------------------------------------- +# Helpers for building the comm_dict tensors + + +def _addr_of(np_arr: np.ndarray) -> int: + """Return the raw int address of a numpy array's data buffer.""" + return np_arr.ctypes.data_as(ctypes.c_void_p).value + + +def _build_self_comm_dict( + *, + nloc: int, + nghost: int, + sendlist_indices: np.ndarray, + device: torch.device, + keepalive: list, +) -> dict: + """Build a comm_dict for a single-rank self-exchange. + + Parameters + ---------- + nloc, nghost + Atom counts; ``nall = nloc + nghost``. + sendlist_indices + int32 array of length ``nghost`` giving local indices to copy + into successive ghost slots [nloc, nloc+1, ...]. + device + Target torch device for the data tensors. 
The control tensors + (send_proc / recv_proc / send_num / recv_num / send_list / + communicator / nlocal / nghost) are forced to CPU regardless of + ``device`` because the C++ ``border_op`` host-side code + dereferences ``data_ptr()`` directly — production builds them on + CPU in ``commonPTExpt.h::build_comm_tensors_positional`` and a + CUDA-built kernel will segfault if it tries to read CUDA memory + from the host. + keepalive + List into which we store numpy buffers that must outlive the + forward pass (their addresses are referenced by sendlist_tensor). + """ + del device # control tensors are always CPU; see docstring + sendlist_indices = np.ascontiguousarray(sendlist_indices, dtype=np.int32) + keepalive.append(sendlist_indices) + nswap = 1 + addr = _addr_of(sendlist_indices) + # int** packed as one int64 entry per swap. + sendlist_tensor = torch.tensor([addr], dtype=torch.int64, device="cpu") + sendproc = torch.zeros(nswap, dtype=torch.int32, device="cpu") + recvproc = torch.zeros(nswap, dtype=torch.int32, device="cpu") + sendnum = torch.tensor([nghost], dtype=torch.int32, device="cpu") + recvnum = torch.tensor([nghost], dtype=torch.int32, device="cpu") + communicator = torch.zeros(1, dtype=torch.int64, device="cpu") + nlocal_ts = torch.tensor(nloc, dtype=torch.int32, device="cpu") + nghost_ts = torch.tensor(nghost, dtype=torch.int32, device="cpu") + return { + "send_list": sendlist_tensor, + "send_proc": sendproc, + "recv_proc": recvproc, + "send_num": sendnum, + "recv_num": recvnum, + "communicator": communicator, + "nlocal": nlocal_ts, + "nghost": nghost_ts, + } + + +# --------------------------------------------------------------------------- + + +class TestRepflowParallel(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + # ``mapping_at_parallel`` toggles between two scenarios: + # - "with-mapping": parallel call still receives the mapping tensor + # (matches what 
pt's DeepPotPT.cc does in production). + # - "none-mapping": parallel call receives ``mapping=None`` so the + # dpmodel branches that gate on ``mapping is not None`` are + # exercised (the regular code path still uses mapping for the + # reference, which proves the comm_dict path's correctness + # does not depend on mapping when override consumes comm_dict). + @pytest.mark.parametrize("mapping_at_parallel", ["with-mapping", "none-mapping"]) + @pytest.mark.parametrize( + "prec", ["float64"] + ) # precision (single is enough for parity) + def test_parallel_matches_default( + self, + prec: str, + mapping_at_parallel: str, + ) -> None: + """Override with comm_dict matching mapping must match default path.""" + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + repflow = RepFlowArgs( + n_dim=8, + e_dim=6, + a_dim=4, + nlayers=2, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=False, + update_style="res_residual", + update_residual_init="const", + smooth_edge_update=True, + ) + + dd = DescrptDPA3( + self.nt, + repflow=repflow, + exclude_types=[], + precision=prec, + use_econf_tebd=False, + type_map=None, + seed=GLOBAL_SEED, + use_loc_mapping=False, # need extended-region indexing for parity + ).to(self.device) + dd.repflows.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + + # use only the first frame to keep the test simple — single rank, + # one frame, simple mapping ([0, 1, 2, 0]: ghost atom 3 mirrors local 0). 
+ coord_ext = torch.tensor( + self.coord_ext[:1], + dtype=dtype, + device=self.device, + ) + atype_ext = torch.tensor( + self.atype_ext[:1], + dtype=torch.int64, + device=self.device, + ) + nlist = torch.tensor(self.nlist[:1], dtype=torch.int64, device=self.device) + mapping = torch.tensor( + self.mapping[:1], + dtype=torch.int64, + device=self.device, + ) + nall = self.nall + + # Default path (comm_dict=None) — uses gather via mapping. + rd_default, _, _, _, _ = dd(coord_ext, atype_ext, nlist, mapping) + + # Parallel path: build a comm_dict whose sendlist mirrors the + # extended portion of mapping. For each ghost slot ii in + # [nloc, nall), border_op writes node_ebd[sendlist[ii - nloc]], + # so sendlist must match mapping[nloc:nall]. + keepalive: list = [] + ghost_sources = self.mapping[0, nloc:].astype(np.int32) + comm_dict = _build_self_comm_dict( + nloc=nloc, + nghost=nall - nloc, + sendlist_indices=ghost_sources, + device=self.device, + keepalive=keepalive, + ) + + mapping_for_parallel = ( + mapping if mapping_at_parallel == "with-mapping" else None + ) + rd_parallel, _, _, _, _ = dd( + coord_ext, + atype_ext, + nlist, + mapping_for_parallel, + comm_dict=comm_dict, + ) + + np.testing.assert_allclose( + rd_parallel.detach().cpu().numpy(), + rd_default.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + + def test_use_loc_mapping_with_comm_dict_raises(self) -> None: + """``use_loc_mapping=True`` + ``comm_dict`` is contradictory. + + The local-mapping codepath skips per-layer ghost exchange + entirely, so combining it with ``comm_dict`` would silently + drop the parallel behaviour. Verify the override raises a + clear error rather than producing wrong output. 
+ """ + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + repflow = RepFlowArgs( + n_dim=8, + e_dim=6, + a_dim=4, + nlayers=1, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=False, + update_style="res_residual", + update_residual_init="const", + smooth_edge_update=True, + ) + dd = DescrptDPA3( + self.nt, + repflow=repflow, + exclude_types=[], + precision="float64", + use_econf_tebd=False, + type_map=None, + seed=GLOBAL_SEED, + use_loc_mapping=True, # contradictory with comm_dict + ).to(self.device) + dd.repflows.mean = torch.tensor(davg, dtype=torch.float64, device=self.device) + dd.repflows.stddev = torch.tensor(dstd, dtype=torch.float64, device=self.device) + + coord_ext = torch.tensor( + self.coord_ext[:1], + dtype=torch.float64, + device=self.device, + ) + atype_ext = torch.tensor( + self.atype_ext[:1], + dtype=torch.int64, + device=self.device, + ) + nlist = torch.tensor(self.nlist[:1], dtype=torch.int64, device=self.device) + mapping = torch.tensor( + self.mapping[:1], + dtype=torch.int64, + device=self.device, + ) + + keepalive: list = [] + ghost_sources = self.mapping[0, nloc:].astype(np.int32) + comm_dict = _build_self_comm_dict( + nloc=nloc, + nghost=self.nall - nloc, + sendlist_indices=ghost_sources, + device=self.device, + keepalive=keepalive, + ) + + with pytest.raises(RuntimeError, match="use_loc_mapping=True"): + dd(coord_ext, atype_ext, nlist, mapping, comm_dict=comm_dict) + + def test_spin_branch_runs(self) -> None: + """Structural test for the ``has_spin`` branch of _exchange_ghosts. 
+ + Builds a synthetic input that satisfies the spin path's atom- + doubling invariant (``nloc`` and ``nall`` even), invokes the + override directly with ``comm_dict["has_spin"]`` set, and + verifies the output shape matches the input. This catches + regressions in the split-real-virtual + concat_switch_virtual + code path without requiring a full spin model. + """ + from deepmd.pt_expt.descriptor.repflows import ( + DescrptBlockRepflows, + ) + + # Build a small DescrptDPA3 and take its ``repflows`` block + # directly. We just need an instance to call + # the method on; method behaviour is independent of weights. + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + repflow = RepFlowArgs( + n_dim=8, + e_dim=6, + a_dim=4, + nlayers=1, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=False, + update_style="res_residual", + update_residual_init="const", + smooth_edge_update=True, + ) + dd = DescrptDPA3( + self.nt, + repflow=repflow, + exclude_types=[], + precision="float64", + use_econf_tebd=False, + type_map=None, + seed=GLOBAL_SEED, + use_loc_mapping=False, + ).to(self.device) + dd.repflows.mean = torch.tensor(davg, dtype=torch.float64, device=self.device) + dd.repflows.stddev = torch.tensor(dstd, dtype=torch.float64, device=self.device) + block = dd.repflows + assert isinstance(block, DescrptBlockRepflows) + + # Pseudo-spin shapes: nloc and nall are even; n_dim from the + # model. The spin path splits along dim 1 into real/virtual + # halves and concats along dim 2. 
+ n_dim = block.n_dim + nloc_spin, nghost_spin = 4, 2 + nall_spin = nloc_spin + nghost_spin + # node_ebd: (1, nloc_spin, n_dim) + node_ebd = torch.randn( + 1, + nloc_spin, + n_dim, + dtype=torch.float64, + device=self.device, + ) + + keepalive: list = [] + # sendlist mirrors local-to-ghost slot for one ghost rank. + # Real ghost slots are real_nall-real_nloc = 1 atoms -> sendlist + # has 1 entry. Self-send branch will copy local index 0. + sendlist_indices = np.array([0], dtype=np.int32) + comm_dict = _build_self_comm_dict( + nloc=nloc_spin // 2, + nghost=nghost_spin // 2, + sendlist_indices=sendlist_indices, + device=self.device, + keepalive=keepalive, + ) + comm_dict["has_spin"] = torch.tensor( + [1], + dtype=torch.int32, + device=self.device, + ) + + # Direct invocation of _exchange_ghosts on the block. + out = block._exchange_ghosts( + node_ebd, + mapping_tiled=None, + comm_dict=comm_dict, + nall=nall_spin, + nloc=nloc_spin, + ) + # concat_switch_virtual produces a tensor of shape + # (1, nall_spin, n_dim) — 2 real + 2 virtual + 1 ghost-real + + # 1 ghost-virtual interleaved per the helper's contract. + # The exact structure is: dim 1 of ``out`` is doubled relative to the + # real_nall (real_nloc + real_nghost = 3); for nloc_spin=4, + # nall_spin=6, the helper outputs 2*real_nall = 6 rows. + assert out.shape[0] == 1 + assert out.shape[2] == n_dim + # Spin path returns shape (1, 2*real_nall, n_dim) = (1, nall_spin, n_dim). + assert out.shape[1] == nall_spin diff --git a/source/tests/pt_expt/descriptor/test_repformer_parallel.py b/source/tests/pt_expt/descriptor/test_repformer_parallel.py new file mode 100644 index 0000000000..1a6413d08f --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_repformer_parallel.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Eager parity test for the pt_expt Repformer parallel-mode override. + +Mirror of ``test_repflow_parallel.py`` but for DPA2 (which uses +``DescrptBlockRepformers``). 
Same single-rank self-exchange trick: +``sendlist`` mirrors ``mapping[nloc:]`` so the C++ ``border_op``'s +self-send branch reproduces the gather that the dpmodel default does. +""" + +from __future__ import ( + annotations, +) + +import ctypes + +import numpy as np +import pytest +import torch + +# Trigger registration of the deepmd_export::border_op opaque wrapper. +import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.pt_expt.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...common.test_mixins import ( + TestCaseSingleFrameWithNlist, + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +def _addr_of(np_arr: np.ndarray) -> int: + return np_arr.ctypes.data_as(ctypes.c_void_p).value + + +def _build_self_comm_dict( + *, + nloc: int, + nghost: int, + sendlist_indices: np.ndarray, + device: torch.device, + keepalive: list, +) -> dict: + """Control tensors must live on CPU because the C++ ``border_op`` + host code dereferences ``data_ptr()`` directly. Production + builds them on CPU in + ``commonPTExpt.h::build_comm_tensors_positional``; on a CUDA build + a CUDA-device control tensor segfaults the host read. See + ``test_repflow_parallel.py::_build_self_comm_dict`` for the full + rationale. 
+ """ + del device # control tensors are always CPU + sendlist_indices = np.ascontiguousarray(sendlist_indices, dtype=np.int32) + keepalive.append(sendlist_indices) + nswap = 1 + addr = _addr_of(sendlist_indices) + sendlist_tensor = torch.tensor([addr], dtype=torch.int64, device="cpu") + sendproc = torch.zeros(nswap, dtype=torch.int32, device="cpu") + recvproc = torch.zeros(nswap, dtype=torch.int32, device="cpu") + sendnum = torch.tensor([nghost], dtype=torch.int32, device="cpu") + recvnum = torch.tensor([nghost], dtype=torch.int32, device="cpu") + communicator = torch.zeros(1, dtype=torch.int64, device="cpu") + nlocal_ts = torch.tensor(nloc, dtype=torch.int32, device="cpu") + nghost_ts = torch.tensor(nghost, dtype=torch.int32, device="cpu") + return { + "send_list": sendlist_tensor, + "send_proc": sendproc, + "recv_proc": recvproc, + "send_num": sendnum, + "recv_num": recvnum, + "communicator": communicator, + "nlocal": nlocal_ts, + "nghost": nghost_ts, + } + + +class TestRepformerParallel(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + # See test_repflow_parallel.py for rationale on the "none-mapping" + # variant — exercises dpa2's "skip pre-block gather" branch with + # mapping=None, which is the realistic LAMMPS multi-rank shape. 
+ @pytest.mark.parametrize("mapping_at_parallel", ["with-mapping", "none-mapping"]) + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_parallel_matches_default( + self, + prec: str, + mapping_at_parallel: str, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = 0.1 + np.abs(dstd_2) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode="concat", + set_davg_zero=True, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=2, + g1_dim=12, + g2_dim=8, + axis_neuron=4, + update_g1_has_conv=True, + update_g1_has_drrd=True, + update_g1_has_grrg=True, + update_g1_has_attn=True, + update_g2_has_g1g1=True, + update_g2_has_attn=True, + update_h2=False, + attn1_hidden=12, + attn1_nhead=2, + attn2_hidden=8, + attn2_nhead=2, + attn2_has_gate=False, + update_style="res_avg", + set_davg_zero=True, + use_sqrt_nnei=False, + g1_out_conv=False, + g1_out_mlp=False, + ) + + dd = DescrptDPA2( + self.nt, + repinit=repinit, + repformer=repformer, + smooth=True, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + use_econf_tebd=False, + type_map=None, + seed=GLOBAL_SEED, + ).to(self.device) + dd.repinit.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=self.device) + dd.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=self.device) + + coord_ext = torch.tensor( + self.coord_ext[:1], + dtype=dtype, + 
device=self.device, + ) + atype_ext = torch.tensor( + self.atype_ext[:1], + dtype=torch.int64, + device=self.device, + ) + nlist = torch.tensor(self.nlist[:1], dtype=torch.int64, device=self.device) + mapping = torch.tensor( + self.mapping[:1], + dtype=torch.int64, + device=self.device, + ) + nall = self.nall + + rd_default, _, _, _, _ = dd(coord_ext, atype_ext, nlist, mapping) + + keepalive: list = [] + ghost_sources = self.mapping[0, nloc:].astype(np.int32) + comm_dict = _build_self_comm_dict( + nloc=nloc, + nghost=nall - nloc, + sendlist_indices=ghost_sources, + device=self.device, + keepalive=keepalive, + ) + + mapping_for_parallel = ( + mapping if mapping_at_parallel == "with-mapping" else None + ) + rd_parallel, _, _, _, _ = dd( + coord_ext, + atype_ext, + nlist, + mapping_for_parallel, + comm_dict=comm_dict, + ) + + np.testing.assert_allclose( + rd_parallel.detach().cpu().numpy(), + rd_default.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py index e4bd1e385e..e3a8ca5c21 100644 --- a/source/tests/pt_expt/descriptor/test_se_e2_a.py +++ b/source/tests/pt_expt/descriptor/test_se_e2_a.py @@ -221,3 +221,38 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + +def test_has_message_passing_across_ranks() -> None: + """se_e2_a is a single-layer local descriptor: no message passing, + no cross-rank exchange ever needed. 
+ """ + import copy + + from deepmd.dpmodel.model.model import ( + get_model, + ) + + config = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": [20, 20], + "neuron": [2, 4], + "axis_neuron": 2, + "type_one_side": True, + "precision": "float64", + "seed": 1, + }, + "fitting_net": { + "neuron": [4, 4], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + desc = get_model(copy.deepcopy(config)).atomic_model.descriptor + assert desc.has_message_passing() is False + assert desc.has_message_passing_across_ranks() is False diff --git a/source/tests/pt_expt/model/test_export_with_comm.py b/source/tests/pt_expt/model/test_export_with_comm.py new file mode 100644 index 0000000000..dcbc628e53 --- /dev/null +++ b/source/tests/pt_expt/model/test_export_with_comm.py @@ -0,0 +1,357 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Phase 3 round-trip test for the with-comm AOTInductor artifact. + +For a GNN model (DPA3 here), ``deserialize_to_file`` produces a .pt2 +archive containing TWO compiled artifacts: + * the regular forward_lower (no comm), packed at the top of the ZIP. + * a ``forward_lower_with_comm`` variant nested at + ``extra/forward_lower_with_comm.pt2``. + +This test verifies: + 1. Both artifacts are present in the archive. + 2. ``metadata.json`` carries the ``has_comm_artifact`` flag. + 3. The with-comm artifact loads via ``aoti_load_package`` and runs + when fed valid comm-dict tensors built via the ctypes pointer + trick (see ``test_repflow_parallel.py``). + 4. The with-comm artifact's output matches the regular artifact's + output for a single-rank self-exchange whose effect is identity + (sendlist mirrors the extended-region mapping, which is what the + gather in the regular path produces). 
+""" + +from __future__ import ( + annotations, +) + +import ctypes +import json +import os +import tempfile +import zipfile + +import numpy as np +import pytest +import torch + +# Trigger registration of the deepmd_export::border_op opaque wrapper +# (needed by the with-comm artifact at runtime). +import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] +from deepmd.pt_expt.model.get_model import ( + get_model, +) +from deepmd.pt_expt.utils.serialization import ( + _make_sample_inputs, + deserialize_to_file, +) + +_DPA3_CONFIG = { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 6, + "a_dim": 4, + "nlayers": 1, + "e_rcut": 4.0, + "e_rcut_smth": 0.5, + "e_sel": 12, + "a_rcut": 3.5, + "a_rcut_smth": 0.5, + "a_sel": 8, + "axis_neuron": 4, + "update_angle": False, + }, + "use_loc_mapping": False, + }, + "fitting_net": {"neuron": [16, 16], "seed": 1}, +} + + +def _addr_of(np_arr: np.ndarray) -> int: + return np_arr.ctypes.data_as(ctypes.c_void_p).value + + +def _build_self_comm_inputs( + nloc: int, + nghost: int, + sendlist_indices: np.ndarray, + keepalive: list, +) -> tuple[torch.Tensor, ...]: + """Build runtime comm tensors for a single-rank self-send. + + Clamps the swap count to ``max(1, nghost)`` to mirror the trace-time + helper in ``serialization.py::_make_comm_sample_inputs``; that + avoids an empty sendlist pointer when a caller happens to construct + a fixture with no ghost atoms. 
+ """ + send_count = max(1, nghost) + sendlist_indices = np.ascontiguousarray(sendlist_indices, dtype=np.int32) + if sendlist_indices.size == 0: + sendlist_indices = np.zeros(send_count, dtype=np.int32) + keepalive.append(sendlist_indices) + nswap = 1 + addr = _addr_of(sendlist_indices) + send_list = torch.tensor([addr], dtype=torch.int64) + send_proc = torch.zeros(nswap, dtype=torch.int32) + recv_proc = torch.zeros(nswap, dtype=torch.int32) + send_num = torch.tensor([send_count], dtype=torch.int32) + recv_num = torch.tensor([send_count], dtype=torch.int32) + communicator = torch.zeros(1, dtype=torch.int64) + nlocal_ts = torch.tensor(nloc, dtype=torch.int32) + nghost_ts = torch.tensor(nghost, dtype=torch.int32) + return ( + send_list, + send_proc, + recv_proc, + send_num, + recv_num, + communicator, + nlocal_ts, + nghost_ts, + ) + + +@pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="AOTInductor compile is slow (~30s); run locally only by default.", +) +def test_pt2_dual_artifact_for_gnn(tmp_path) -> None: + """End-to-end: GNN model produces dual-artifact .pt2; both load.""" + model = get_model(_DPA3_CONFIG) + model.to("cpu") + model.eval() + + # Serialize → deserialize_to_file (compiles and packs both artifacts) + pt2_path = str(tmp_path / "test_dpa3.pt2") + data = {"model": model.serialize()} + deserialize_to_file(pt2_path, data) + assert os.path.exists(pt2_path) + + # 1. ZIP layout sanity. PyTorch 2.11 strict layout puts our sidecars + # under ``model/extra/`` (PT2_EXTRA_PREFIX); see serialization.py. + with zipfile.ZipFile(pt2_path, "r") as zf: + names = set(zf.namelist()) + meta = json.loads(zf.read("model/extra/metadata.json").decode("utf-8")) + assert "model/extra/forward_lower_with_comm.pt2" in names, ( + f"with-comm artifact missing; names={sorted(names)}" + ) + assert meta["has_comm_artifact"] is True + + # 2. Both artifacts load. 
+ from torch._inductor import ( + aoti_load_package, + ) + + regular = aoti_load_package(pt2_path) + + with tempfile.TemporaryDirectory() as td: + wc_path = os.path.join(td, "fl_wc.pt2") + with zipfile.ZipFile(pt2_path, "r") as zf: + with open(wc_path, "wb") as f: + f.write(zf.read("model/extra/forward_lower_with_comm.pt2")) + with_comm = aoti_load_package(wc_path) + + # 3. Run both artifacts with nframes=1 (matches what the with-comm + # artifact requires; LAMMPS always passes one frame anyway). + sample = _make_sample_inputs(model, nframes=1, has_spin=False) + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = sample + nloc = nlist_t.shape[1] + nall = ext_atype.shape[1] + nghost = nall - nloc + + out_regular = regular(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) + + # 4. Build runtime comm tensors mirroring the mapping (single-rank + # self-send: ghost slot ii receives node[mapping[ii]], identical to + # the gather in the regular path). + keepalive: list = [] + ghost_sources = mapping_t[0, nloc:].cpu().numpy().astype(np.int32) + comm_inputs = _build_self_comm_inputs( + nloc=nloc, + nghost=nghost, + sendlist_indices=ghost_sources, + keepalive=keepalive, + ) + + out_with_comm = with_comm( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam, + aparam, + *comm_inputs, + ) + + # 5. Outputs must match (parity gate, eager-mode equivalent). + for key in out_regular: + np.testing.assert_allclose( + out_with_comm[key].detach().cpu().numpy(), + out_regular[key].detach().cpu().numpy(), + rtol=0, + atol=1e-10, + err_msg=f"output[{key}] differs between regular and with-comm", + ) + + +# --------------------------------------------------------------------------- +# Coverage for previously-untested branches +# --------------------------------------------------------------------------- + + +def test_make_comm_sample_inputs_clamps_zero_nghost() -> None: + """``_make_comm_sample_inputs(nghost=0)`` must produce valid tensors. 
+ + The clamp ``send_count = max(1, nghost)`` ensures we never pass an + empty pointer-array to border_op. This test exercises the + ``nghost == 0`` branch (a model exported on a system whose entire + domain fits in one rank with no ghosts) — the trace must still + produce well-formed comm tensors of shape (1,). + """ + from deepmd.pt_expt.utils.serialization import ( + _make_comm_sample_inputs, + ) + + comm_inputs = _make_comm_sample_inputs( + nloc=4, + nghost=0, + device=torch.device("cpu"), + ) + assert len(comm_inputs) == 8 + ( + send_list, + send_proc, + recv_proc, + send_num, + recv_num, + communicator, + nlocal, + nghost_t, + ) = comm_inputs + # nswap stays at 1 (Phase 0: nswap=0 specializes during export). + assert send_list.shape == (1,) + assert send_proc.shape == (1,) + assert recv_proc.shape == (1,) + assert send_num.shape == (1,) + assert recv_num.shape == (1,) + # send_count is clamped to >=1, so send_num is also clamped. + assert send_num.item() == 1 + assert recv_num.item() == 1 + # Scalar metadata reports the original (un-clamped) values. + assert nlocal.item() == 4 + assert nghost_t.item() == 0 + + +def test_needs_with_comm_artifact_for_hybrid_with_gnn() -> None: + """``_needs_with_comm_artifact`` correctly reports True for hybrid + descriptors whose children include a GNN block needing cross-rank + message passing. + + The hybrid descriptor delegates ``has_message_passing_across_ranks()`` + to its children — if any child needs cross-rank message passing, + the hybrid does too. ``_deserialize_to_file_pt2`` uses this gate + to decide whether to compile the with-comm artifact, so the + hybrid case must route correctly. + """ + from deepmd.pt_expt.model.get_model import get_model as get_pt_expt_model + from deepmd.pt_expt.utils.serialization import ( + _needs_with_comm_artifact, + ) + + config = { + "type_map": ["O", "H"], + "descriptor": { + "type": "hybrid", + "list": [ + # Non-GNN child. 
+ { + "type": "se_e2_a", + "sel": [12, 12], + "rcut": 4.0, + "rcut_smth": 0.5, + "neuron": [4, 8], + "axis_neuron": 4, + "seed": 1, + }, + # GNN child (DPA3). + { + "type": "dpa3", + "repflow": { + "n_dim": 4, + "e_dim": 4, + "a_dim": 4, + "nlayers": 1, + "e_rcut": 4.0, + "e_rcut_smth": 0.5, + "e_sel": 8, + "a_rcut": 3.5, + "a_rcut_smth": 0.5, + "a_sel": 4, + "axis_neuron": 4, + "update_angle": False, + }, + "use_loc_mapping": False, + }, + ], + }, + "fitting_net": {"neuron": [8, 8], "seed": 1}, + } + model = get_pt_expt_model(config) + model.to("cpu") + model.eval() + assert _needs_with_comm_artifact(model) is True, ( + "hybrid model with a use_loc_mapping=False GNN child must " + "report has_message_passing_across_ranks=True so a with-comm " + "artifact gets compiled" + ) + + +def test_pte_with_comm_dict_traces_and_loads(tmp_path) -> None: + """``_trace_and_export(with_comm_dict=True)`` produces a valid + ExportedProgram that can be saved as .pte and loaded back. + + .pte is Python-only (the multi-rank consumer is C++/LAMMPS via + .pt2), so production has no business calling this path. But the + trace machinery is the same as the .pt2 path, so .pte serves as + a cheap (no AOTI compile) round-trip test for the with-comm + export pipeline. + """ + from deepmd.pt_expt.utils.serialization import ( + _trace_and_export, + ) + + model = get_model(_DPA3_CONFIG) + model.to("cpu") + model.eval() + data = {"model": model.serialize()} + + exported, metadata, _data_for_json, output_keys = _trace_and_export( + data, + model_json_override=None, + with_comm_dict=True, + ) + # ``_trace_and_export(with_comm_dict=True)`` is the with-comm path + # by construction; metadata at this layer no longer carries the + # has_message_passing flag (only ``has_comm_artifact``, written + # later in _deserialize_to_file_pt2). Sanity-check via output_keys + # that the trace produced energy outputs. 
+ # output_keys mirrors what the regular trace would produce; at + # least one energy-related key must be present. + assert any(k.startswith("energy") for k in output_keys), ( + f"expected an 'energy*' output key; got {output_keys}" + ) + + # Save as .pte and reload — verifies the ExportedProgram is + # structurally valid (no broken graph or missing constants). + pte_path = str(tmp_path / "fl_with_comm.pte") + torch.export.save(exported, pte_path) + assert os.path.exists(pte_path) + loaded = torch.export.load(pte_path) + # Sanity: the loaded program has the expected number of inputs + # (6 base + 8 comm = 14). + spec = loaded.module().graph.find_nodes(op="placeholder") + assert len(spec) == 14, ( + f"with-comm exported program must accept 14 positional inputs " + f"(6 base + 8 comm); got {len(spec)}" + ) diff --git a/source/tests/pt_expt/model/test_spin_export_with_comm.py b/source/tests/pt_expt/model/test_spin_export_with_comm.py new file mode 100644 index 0000000000..0e403d2b42 --- /dev/null +++ b/source/tests/pt_expt/model/test_spin_export_with_comm.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for SpinModel + comm_dict end-to-end. + +Two coverage levels: + +1. ``test_spin_forward_common_lower_exportable_with_comm_traces``: + verifies the trace machinery (positional comm-tensor plumbing, + has_spin injection, make_fx symbolic mode) on a spin model with a + non-GNN descriptor (se_e2_a). The non-GNN case is the cheapest + smoke test since se_e2_a's `call` accepts and drops comm_dict — + exercising the wrapper/spin model layers without paying for GNN + compile cost. + +2. ``test_spin_dpa3_eager_parity``: end-to-end value-correctness for + a spin DPA3 model running through ``call_common_lower`` in eager + mode, with a comm_dict whose self-exchange mirrors the mapping. + Asserts the result matches the no-comm reference. 
This proves + ``SpinModel.call_common_lower`` correctly forwards comm_dict + through to the GNN repflow, AND that the spin branch of + ``_exchange_ghosts`` (real/virtual split + concat_switch_virtual) + reproduces the regular gather path on real values. +""" + +from __future__ import ( + annotations, +) + +import ctypes + +import numpy as np +import torch + +import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] - opaque op registration +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.pt_expt.model.spin_ener_model import ( + SpinEnergyModel, +) + +SPIN_GNN_DATA = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20, 20], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [3, 6], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + "spin": { + "use_spin": [True, False, False], + "virtual_scale": [0.3140], + }, +} + + +def _addr_of(np_arr: np.ndarray) -> int: + return np_arr.ctypes.data_as(ctypes.c_void_p).value + + +def _build_self_comm_inputs(nloc: int, nghost: int): + """Build trivial-but-valid comm tensors for tracing.""" + keepalive: list[np.ndarray] = [] + indices = np.zeros(max(1, nghost), dtype=np.int32) + keepalive.append(indices) + addr = _addr_of(indices) + nswap = 1 + return ( + torch.tensor([addr], dtype=torch.int64), # send_list + torch.zeros(nswap, dtype=torch.int32), # send_proc + torch.zeros(nswap, dtype=torch.int32), # recv_proc + torch.tensor([max(1, nghost)], dtype=torch.int32), # send_num + torch.tensor([max(1, nghost)], dtype=torch.int32), # recv_num + torch.zeros(1, dtype=torch.int64), # communicator + torch.tensor(nloc, dtype=torch.int32), # nlocal + torch.tensor(nghost, dtype=torch.int32), # nghost + ), keepalive + + +def test_spin_forward_common_lower_exportable_with_comm_traces() -> None: + """The 
spin variant of forward_common_lower_exportable_with_comm + produces a callable traced GraphModule. + """ + dp_model = get_model_dp(SPIN_GNN_DATA) + model = SpinEnergyModel.deserialize(dp_model.serialize()).to("cpu") + model.eval() + + # Build sample inputs (nframes=1 to match the override's nb=1 + # constraint; spin doubles natoms). nlist width must match the + # model's sum(sel); the descriptor's _format_nlist asserts this. + nloc = 6 # 3 real + 3 virtual + nall = 8 # 1 ghost on each side + n_dim_coord = 3 + nnei = sum(SPIN_GNN_DATA["descriptor"]["sel"]) + ext_coord = torch.zeros(1, nall, n_dim_coord, dtype=torch.float64) + ext_atype = torch.zeros(1, nall, dtype=torch.int64) + ext_spin = torch.zeros(1, nall, n_dim_coord, dtype=torch.float64) + nlist = torch.zeros(1, nloc, nnei, dtype=torch.int64) + mapping = torch.zeros(1, nall, dtype=torch.int64) + fparam = None + aparam = None + + comm_inputs, _keepalive = _build_self_comm_inputs(nloc=nloc, nghost=nall - nloc) + + # The trace should succeed without raising. We do NOT verify + # numerical correctness here — that would require a real spin GNN + # model + live MPI (deferred to Phase 5 LAMMPS). This test only + # checks the trace-time machinery: positional arg plumbing, + # has_spin injection, and that make_fx symbolic mode produces a + # valid GraphModule. + traced = model.forward_common_lower_exportable_with_comm( + ext_coord, + ext_atype, + ext_spin, + nlist, + mapping, + fparam, + aparam, + *comm_inputs, + do_atomic_virial=True, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + # The traced module must be a torch.nn.Module that can be invoked. + assert isinstance(traced, torch.nn.Module) + # And calling it with the same inputs returns a dict with the + # expected keys. 
+ out = traced( + ext_coord, + ext_atype, + ext_spin, + nlist, + mapping, + fparam, + aparam, + *comm_inputs, + ) + assert isinstance(out, dict) + # forward_common_lower internal output names; specifics depend on + # the model's output def, just check at least one is present. + assert any(k.startswith("energy") for k in out), ( + f"expected an 'energy*' key in trace output; got {list(out.keys())}" + ) + + +# --------------------------------------------------------------------------- +# 2. End-to-end value parity for spin DPA3 in eager mode +# --------------------------------------------------------------------------- + + +SPIN_DPA3_DATA = { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 6, + "a_dim": 4, + "nlayers": 1, + "e_rcut": 4.0, + "e_rcut_smth": 0.5, + "e_sel": 8, + "a_rcut": 3.5, + "a_rcut_smth": 0.5, + "a_sel": 4, + "axis_neuron": 4, + "update_angle": False, + }, + "use_loc_mapping": False, + }, + "fitting_net": {"neuron": [16, 16], "seed": 1}, + "spin": {"use_spin": [True, False], "virtual_scale": [0.314]}, +} + + +def test_spin_dpa3_eager_parity() -> None: + """SpinModel.call_common_lower with comm_dict (self-exchange) must + match the no-comm reference for a spin DPA3 model. + + Setup mirrors the per-block parity tests but at the SpinModel + level so it exercises the full plumbing chain: + ``SpinModel.call_common_lower(comm_dict=...)`` + → process_spin_input_lower (atom-doubling) + → backbone EnergyModel.call_common_lower(comm_dict=...) + → atomic_model.forward_common_atomic(comm_dict=...) + → DescrptDPA3.call(comm_dict=...) + → DescrptBlockRepflows.call(comm_dict=...) + → DescrptBlockRepflows._exchange_ghosts (pt_expt override, + spin branch via has_spin in comm_dict) + + The comm_dict has has_spin=tensor([1]) and a sendlist that + mirrors the real-atom portion of the mapping. 
The override's + spin branch splits node_ebd into real/virtual halves, stacks + along feature dim, exchanges, then de-interleaves with + concat_switch_virtual. When the exchange produces the same + result as the gather (which it should for a self-mirror + sendlist), the spin model output must equal the no-comm output + to within an absolute tolerance of 1e-10 (float64). + """ + dp_model = get_model_dp(SPIN_DPA3_DATA) + model = SpinEnergyModel.deserialize(dp_model.serialize()).to("cpu") + model.eval() + + # Build a small system: 2 local + 2 ghost real atoms, all type 0, + # plus the same in spin (use_spin=[True, False] means type 0 is + # spin-doubled, type 1 is not). After atom-doubling the model + # processes 2 real + 2 virtual = 4 atoms locally and 4 ghost + # slots. We use minimal nloc to keep the test fast. + nframes = 1 + nloc_real = 2 # 2 real atoms (both type 0 to keep simple) + nghost_real = 2 # 2 ghost real atoms + nall_real = nloc_real + nghost_real + rng = np.random.default_rng(42) + + # Coordinates and types (real only — spin model doubles internally). + coord_real = rng.uniform(0, 4.0, size=(nframes, nall_real, 3)).astype(np.float64) + atype_real = np.zeros((nframes, nall_real), dtype=np.int64) # all type 0 + spin_real = rng.uniform(-0.1, 0.1, size=(nframes, nall_real, 3)).astype(np.float64) + # mapping: ghost atoms mirror local atoms (ghost 0 → local 0, ghost 1 → local 1) + mapping_real = np.array( + [[0, 1, 0, 1]], + dtype=np.int64, + ) # nframes=1, nall_real=4 + + # Build extended-region nlist for the real atoms. Each real atom's + # neighbour list points to the other 3 atoms (within rcut by + # construction of small box). We don't need physically meaningful + # values — just well-formed nlist so the model runs. 
+ nnei = 8 # matches e_sel + nlist_real = np.full((nframes, nloc_real, nnei), -1, dtype=np.int64) + for ii in range(nloc_real): + # neighbours = all other atoms (real + ghost) up to nnei + others = [j for j in range(nall_real) if j != ii][:nnei] + nlist_real[0, ii, : len(others)] = others + + # ``call_common_lower`` runs through ``transform_output`` which + # calls ``torch.autograd.grad`` on coord, so coord must require + # grad in eager mode. + ext_coord = torch.tensor(coord_real, dtype=torch.float64, requires_grad=True) + ext_atype = torch.tensor(atype_real, dtype=torch.int64) + ext_spin = torch.tensor(spin_real, dtype=torch.float64) + nlist_t = torch.tensor(nlist_real, dtype=torch.int64) + mapping_t = torch.tensor(mapping_real, dtype=torch.int64) + + # 1. No-comm reference. + out_ref = model.call_common_lower( + ext_coord, + ext_atype, + ext_spin, + nlist_t, + mapping_t, + fparam=None, + aparam=None, + do_atomic_virial=False, + ) + + # 2. With comm_dict. The SpinModel internally doubles atoms to + # nloc=2*nloc_real=4 and nall=2*nall_real=8. The override's spin + # branch peels back to real_nloc=nloc_real and real_nall=nall_real. + # Sendlist must point to REAL local indices for each real ghost + # slot (mapping_real[nloc_real:nall_real]). + keepalive: list = [] + sendlist_indices = mapping_real[0, nloc_real:].astype(np.int32) + keepalive.append(sendlist_indices) + addr = sendlist_indices.ctypes.data_as(ctypes.c_void_p).value + nswap = 1 + nghost_real_count = nall_real - nloc_real + comm_dict = { + "send_list": torch.tensor([addr], dtype=torch.int64), + "send_proc": torch.zeros(nswap, dtype=torch.int32), + "recv_proc": torch.zeros(nswap, dtype=torch.int32), + "send_num": torch.tensor([nghost_real_count], dtype=torch.int32), + "recv_num": torch.tensor([nghost_real_count], dtype=torch.int32), + "communicator": torch.zeros(1, dtype=torch.int64), + # nlocal/nghost are the REAL counts (the override's spin branch + # halves nloc/nall internally). 
In production C++ side passes + # real counts here too — see DeepSpinPT.cc. + "nlocal": torch.tensor(nloc_real, dtype=torch.int32), + "nghost": torch.tensor(nghost_real_count, dtype=torch.int32), + # Triggers spin branch in the override. + "has_spin": torch.tensor([1], dtype=torch.int32), + } + + # Fresh coord tensor (the first call's backward graph would otherwise + # be reused / cause double-backward errors). + ext_coord_2 = torch.tensor(coord_real, dtype=torch.float64, requires_grad=True) + out_parallel = model.call_common_lower( + ext_coord_2, + ext_atype, + ext_spin, + nlist_t, + mapping_t, + fparam=None, + aparam=None, + do_atomic_virial=False, + comm_dict=comm_dict, + ) + + # 3. Compare every output key. + for key in out_ref: + ref = out_ref[key].detach().cpu().numpy() + par = out_parallel[key].detach().cpu().numpy() + np.testing.assert_allclose( + par, + ref, + atol=1e-10, + rtol=0, + err_msg=f"output[{key}] mismatch between no-comm and comm_dict path", + ) diff --git a/source/tests/pt_expt/utils/test_border_op_backward.py b/source/tests/pt_expt/utils/test_border_op_backward.py new file mode 100644 index 0000000000..b33e575f1a --- /dev/null +++ b/source/tests/pt_expt/utils/test_border_op_backward.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the new C++ symbol ``deepmd::border_op_backward`` and +the pt_expt autograd path that dispatches to it. + +Tests two distinct surfaces: + +1. **Direct op call** — invokes ``torch.ops.deepmd.border_op_backward`` + with hand-built comm tensors (single-rank self-exchange via ctypes + pointer trick). Verifies the symbol is registered, accepts the + expected positional args, and produces a correctly-shaped output + for both ``float32`` and ``float64`` (covers the ``backward_t`` + template's two specializations). + +2. **Through the opaque wrapper** — exercises + ``torch.ops.deepmd_export.border_op``'s ``register_autograd`` + pathway. 
Calls the wrapper inside an autograd context, asks for + ``grad`` w.r.t. the ``g1`` input, and verifies the gradient flows + through (matches the gradient produced by an equivalent + ``index_select`` + ``index_add_`` Python implementation, which is + the reference for the symmetric MPI exchange in single-rank). +""" + +from __future__ import ( + annotations, +) + +import ctypes + +import numpy as np +import pytest +import torch + +# comm self-bootstraps the underlying libdeepmd_op_pt.so when needed, so +# this single side-effect import is enough to register both the C++ +# ops (deepmd::border_op_backward) and their fake/autograd metadata. +import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] - registers deepmd_export::border_op + + +def _addr_of(np_arr: np.ndarray) -> int: + return np_arr.ctypes.data_as(ctypes.c_void_p).value + + +def _build_self_swap( + nloc: int, + nghost: int, + sendlist_indices: np.ndarray, + keepalive: list, + dtype: torch.dtype, +): + """Build comm tensors for a single self-exchange swap.""" + sendlist_indices = np.ascontiguousarray(sendlist_indices, dtype=np.int32) + keepalive.append(sendlist_indices) + nswap = 1 + addr = _addr_of(sendlist_indices) + sendlist = torch.tensor([addr], dtype=torch.int64) + sendproc = torch.zeros(nswap, dtype=torch.int32) + recvproc = torch.zeros(nswap, dtype=torch.int32) + sendnum = torch.tensor([nghost], dtype=torch.int32) + recvnum = torch.tensor([nghost], dtype=torch.int32) + communicator = torch.zeros(1, dtype=torch.int64) + nlocal_ts = torch.tensor(nloc, dtype=torch.int32) + nghost_ts = torch.tensor(nghost, dtype=torch.int32) + return ( + sendlist, + sendproc, + recvproc, + sendnum, + recvnum, + communicator, + nlocal_ts, + nghost_ts, + ) + + +# --------------------------------------------------------------------------- +# 1. 
Direct op call: border_op_backward as a standalone op +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_border_op_backward_direct(dtype: torch.dtype) -> None: + """``torch.ops.deepmd.border_op_backward`` is callable for both + float32 and float64 inputs and returns a tensor of the expected + shape on the input's device. + """ + assert hasattr(torch.ops.deepmd, "border_op_backward"), ( + "Symbol not registered; rebuild libdeepmd_op_pt.so." + ) + nloc, nghost = 5, 3 + nall = nloc + nghost + n_dim = 4 + + keepalive: list = [] + sendlist_indices = np.array([0, 1, 2], dtype=np.int32) + comm = _build_self_swap(nloc, nghost, sendlist_indices, keepalive, dtype) + + grad_g1 = torch.ones(nall, n_dim, dtype=dtype) + + grad_in = torch.ops.deepmd.border_op_backward( + comm[0], + comm[1], + comm[2], + comm[3], + comm[4], + grad_g1, + comm[5], + comm[6], + comm[7], + ) + + # backward must preserve dtype and shape, and run on the same device. + assert grad_in.dtype == grad_g1.dtype + assert tuple(grad_in.shape) == tuple(grad_g1.shape) + assert grad_in.device == grad_g1.device + + +def test_border_op_backward_accumulation_semantics() -> None: + """Single-rank self-exchange backward: each ghost slot's grad is + accumulated into the local atom whose index sendlist points to. + + Reference: for forward ``g_ext[nloc + i] = g[sendlist[i]]``, the + reverse is ``grad_g[sendlist[i]] += grad_g_ext[nloc + i]``. + """ + nloc, nghost = 4, 4 + nall = nloc + nghost + n_dim = 3 + + # Each ghost slot maps back to a local atom: ghost 0->local 0, ghost + # 1->local 1, etc. So backward should add grad_g_ext[nloc+i] into + # grad_g[i] for i in [0, nghost). 
+ keepalive: list = [] + sendlist_indices = np.array([0, 1, 2, 3], dtype=np.int32) + comm = _build_self_swap( + nloc, + nghost, + sendlist_indices, + keepalive, + torch.float64, + ) + + # Distinct values per ghost slot so we can identify the routing. + grad_g1 = torch.zeros(nall, n_dim, dtype=torch.float64) + grad_g1[nloc + 0, 0] = 7.0 + grad_g1[nloc + 1, 1] = 11.0 + grad_g1[nloc + 2, 2] = 13.0 + grad_g1[nloc + 3, 0] = 17.0 + # Local part has its own grad too — must pass through unchanged. + grad_g1[0, 1] = 1.0 + grad_g1[2, 2] = 2.0 + # Capture the input BEFORE the call: the C++ op writes + # ``index_add_`` into the same tensor and returns it, so once + # we've called the op the ``grad_g1`` reference points to the + # modified buffer. Snapshot first. + grad_g1_orig = grad_g1.clone() + grad_in = torch.ops.deepmd.border_op_backward( + comm[0], + comm[1], + comm[2], + comm[3], + comm[4], + grad_g1, + comm[5], + comm[6], + comm[7], + ) + + # Expected: grad_g_local += grad_g_ext[nloc:] indexed by sendlist. + # Ghost rows pass through unchanged (the C++ backward does not + # zero them; the wrapper's autograd consumer is F.pad whose + # backward drops them anyway). + expected = grad_g1_orig.clone() + for i, src_local_idx in enumerate(sendlist_indices.tolist()): + expected[src_local_idx] += grad_g1_orig[nloc + i] + np.testing.assert_allclose( + grad_in.numpy(), + expected.numpy(), + atol=1e-12, + rtol=0, + ) + + +# --------------------------------------------------------------------------- +# 2. Autograd path through the deepmd_export::border_op opaque wrapper +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_border_op_export_autograd(dtype: torch.dtype) -> None: + """End-to-end autograd through the opaque wrapper. + + Builds an inputs tensor with ``requires_grad=True``, calls the + wrapper, sums the output, and asks for ``grad`` w.r.t. the input. 
+ The reported gradient must match a hand-computed reference based + on the same self-exchange routing. + """ + nloc, nghost = 3, 2 + nall = nloc + nghost + n_dim = 4 + + keepalive: list = [] + sendlist_indices = np.array([0, 1], dtype=np.int32) # ghosts mirror locals 0,1 + comm = _build_self_swap(nloc, nghost, sendlist_indices, keepalive, dtype) + + # g1 is full nall-shape pre-padded; ghosts initialised to zero + # (mirroring how repflows.forward feeds the wrapper). + rng = np.random.default_rng(123) + g1_np = rng.normal(size=(nall, n_dim)).astype( + np.float32 if dtype == torch.float32 else np.float64, + ) + g1_np[nloc:] = 0.0 + g1 = torch.tensor(g1_np, dtype=dtype, requires_grad=True) + + out = torch.ops.deepmd_export.border_op( + comm[0], + comm[1], + comm[2], + comm[3], + comm[4], + g1, + comm[5], + comm[6], + comm[7], + ) + # Sum so the upstream grad is all-ones at every position. + loss = out.sum() + (grad_in,) = torch.autograd.grad(loss, g1, create_graph=False) + + # Reference for LOCAL rows only: forward sets + # ``out[nloc + i] = g1[sendlist[i]]`` for each ghost slot i and + # passes local rows through. With ``loss = out.sum()`` the + # upstream grad is ones everywhere, so each local row k receives + # 1 (from ``out[k] = g1[k]``) plus 1 for every ghost slot that + # references k via ``sendlist``. + expected_local = torch.ones(nloc, n_dim, dtype=dtype) + for s in sendlist_indices: + expected_local[int(s)] += 1.0 + rtol, atol = (0.0, 1e-5) if dtype == torch.float32 else (0.0, 1e-12) + np.testing.assert_allclose( + grad_in[:nloc].numpy(), + expected_local.numpy(), + atol=atol, + rtol=rtol, + ) + # Ghost rows of grad_in are not semantically meaningful: in + # production the wrapper's input is ``F.pad(node_ebd, value=0)`` + # so the ghost-row gradient is consumed by ``F.pad``'s backward + # (which drops it). The C++ backward leaves them as the upstream + # grad (here, ones), but we don't assert on it.