Skip to content
32 changes: 27 additions & 5 deletions deepmd/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,43 @@ def get_backends_by_feature(
if backend.features & feature
}

@classmethod
def match_filename(cls, filename: str) -> int:
"""Specificity score of this backend's claim on ``filename``.

Returns a positive integer if this backend can handle the file
(higher = stronger / more specific claim), or 0 otherwise.

The default implementation returns 1 when ``filename`` ends with
one of ``cls.suffixes``. Backends with overlapping suffixes can
override this to disambiguate (e.g. by inspecting file content)
and return a higher score so they win the tie.
"""
fname = str(filename).lower()
return 1 if any(fname.endswith(s) for s in cls.suffixes) else 0

@staticmethod
def detect_backend_by_model(filename: str) -> type["Backend"]:
"""Detect the backend of the given model file.

Calls ``match_filename`` on every registered backend and returns
the one with the highest specificity score (>0).

Parameters
----------
filename : str
The model file name
"""
filename = str(filename).lower()
best: type[Backend] | None = None
best_score = 0
for backend in Backend.get_backends().values():
for suffix in backend.suffixes:
if filename.endswith(suffix):
return backend
raise ValueError(f"Cannot detect the backend of the model file {filename}.")
score = backend.match_filename(filename)
if score > best_score:
best_score = score
best = backend
if best is None:
raise ValueError(f"Cannot detect the backend of the model file {filename}.")
return best

class Feature(Flag):
"""Feature flag to indicate whether the backend supports certain features."""
Expand Down
41 changes: 41 additions & 0 deletions deepmd/backend/pt_expt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,47 @@ class PyTorchExportableBackend(Backend):
suffixes: ClassVar[list[str]] = [".pte", ".pt2"]
"""The suffixes of the backend."""

@classmethod
def match_filename(cls, filename: str) -> int:
"""Recognise pt_expt-trained `.pt` checkpoints in addition to `.pt2`/`.pte`.

Returns
-------
- 1 for the regular `.pte` / `.pt2` suffixes (default behaviour).
- 2 for `.pt` files whose state-dict uses pt_expt's dpmodel
parameter naming (`.w`/`.b`); this outranks the legacy pt
backend's default suffix score (1) so pt_expt-trained `.pt`
checkpoints route here, while genuine pt-trained `.pt` files
(which use `.matrix`/`.bias`) keep going to the pt backend.
- 0 otherwise.
"""
score = super().match_filename(filename)
if score:
return score
fname = str(filename).lower()
if not fname.endswith(".pt"):
return 0
try:
import torch

# weights_only=True avoids unpickling arbitrary code from an
# untrusted .pt — sniffing only needs the dict keys.
sd = torch.load(filename, map_location="cpu", weights_only=True)
except Exception:
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
# Not a valid torch archive (corrupt file, wrong format, or a
# weights_only=True restriction trip). Surrender the claim so
# the dispatcher falls back to the default suffix match — pt's
# default score (1) will pick up the file under `dp --pt`.
return 0
if isinstance(sd, dict) and "model" in sd:
sd = sd["model"]
keys = list(sd.keys()) if hasattr(sd, "keys") else []
has_pt_expt = any(k.endswith(".w") or k.endswith(".b") for k in keys)
has_pt = any(k.endswith(".matrix") or k.endswith(".bias") for k in keys)
if has_pt_expt and not has_pt:
return 2
return 0

def is_available(self) -> bool:
"""Check if the backend is available.

Expand Down
193 changes: 192 additions & 1 deletion deepmd/pt_expt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,16 @@ def __init__(

if self._is_pt2:
self._load_pt2(model_file)
else:
elif model_file.endswith(".pte"):
self._load_pte(model_file)
elif model_file.endswith(".pt"):
self._load_pt(model_file, head=kwargs.get("head"))
else:
raise ValueError(
f"Unsupported model file '{model_file}' for the pt_expt "
"backend: expected `.pt2` / `.pte` (deployable archives) or "
"`.pt` (training checkpoint)."
)

if isinstance(auto_batch_size, bool):
if auto_batch_size:
Expand Down Expand Up @@ -206,6 +214,189 @@ def _load_pt2(self, model_file: str) -> None:
self._pt2_runner = aoti_load_package(model_file)
self.exported_module = None

def _load_pt(self, model_file: str, head: str | None = None) -> None:
"""Load a `.pt` training checkpoint (eager mode, no torch.export)."""
from copy import (
deepcopy,
)

from deepmd.pt_expt.model import (
get_model,
)
from deepmd.pt_expt.utils.env import (
DEVICE,
)

# Match the training resume path (training.py:712) — weights_only=True
# avoids unpickling arbitrary code from untrusted checkpoints.
state_dict = torch.load(model_file, map_location=DEVICE, weights_only=True)
if isinstance(state_dict, dict) and "model" in state_dict:
state_dict = state_dict["model"]
extra = state_dict.get("_extra_state") if isinstance(state_dict, dict) else None
if not (isinstance(extra, dict) and "model_params" in extra):
raise ValueError(
f"Invalid .pt file '{model_file}': expected a pt_expt training "
"checkpoint containing '_extra_state' with nested "
"'model_params'. If this is a legacy pt-trained checkpoint, "
"load it with `dp --pt` instead. If this is an exported model, "
"use a `.pte` or `.pt2` artifact."
)
model_params = deepcopy(extra["model_params"])

if "model_dict" in model_params:
# Multi-task: pick the requested head (defaults to "Default" if present).
heads = list(model_params["model_dict"].keys())
if head is None:
if "Default" in heads:
head = "Default"
else:
raise ValueError(
f"Multi-task checkpoint '{model_file}' has heads "
f"{heads}; pass --head to select one."
)
if head not in heads:
raise ValueError(
f"Head '{head}' not found in checkpoint '{model_file}'. "
f"Available heads: {heads}."
)
head_params = model_params["model_dict"][head]
# Restrict state_dict to the chosen head and rename to "Default".
# No tensor cloning needed: load_state_dict copies into the
# destination parameters and does not mutate the input dict.
head_state = {"_extra_state": state_dict["_extra_state"]}
prefix = f"model.{head}."
for key, value in state_dict.items():
if key.startswith(prefix):
head_state[key.replace(prefix, "model.Default.")] = value
state_dict = head_state
model_params = head_params

model = get_model(deepcopy(model_params)).to(DEVICE)

# Strip the `_CompiledModel` wrapper that pt_expt training applies
# after compilation (training.py:996). The saved state_dict has
# `model.Default.original_model.X` keys (the real weights) plus
# `model.Default.compiled_forward_lower._orig_mod._param_constant*`
# / `_tensor_constant*` keys (graph constants baked into the
# compiled forward — duplicates of the real weights, useless for
# eager inference). Drop the latter and unwrap the former.
cleaned: dict[str, Any] = {}
compiled_marker = ".compiled_forward_lower."
wrapper_infix = ".original_model."
for key, value in state_dict.items():
if compiled_marker in key:
continue
if wrapper_infix in key:
key = key.replace(wrapper_infix, ".", 1)
cleaned[key] = value
state_dict = cleaned

# Load weights into a {"Default": model} wrapper to match the
# `model.Default.*` key prefix used in the saved state_dict.
from deepmd.pt_expt.train.wrapper import (
ModelWrapper,
)

wrapper = ModelWrapper(model)
wrapper.load_state_dict(state_dict)
model = wrapper.model["Default"].eval()

self._dpmodel = model
self._is_spin = (
model_params.get("type") == "spin_ener" or "spin" in model_params
)
self.rcut = model.get_rcut()
self.type_map = model.get_type_map()
if self._is_spin:
self._model_output_def = ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
"energy",
shape=[1],
reducible=True,
r_differentiable=True,
c_differentiable=True,
atomic=True,
magnetic=True,
)
]
)
)
else:
self._model_output_def = ModelOutputDef(model.atomic_output_def())
self._model_def_script = model_params
# Populate metadata so eval helpers (e.g. default_fparam fallback)
# behave the same as the .pt2/.pte path. Mirrors the fields that
# `_collect_metadata` writes into metadata.json.
self.metadata = {
"type_map": model.get_type_map(),
"rcut": model.get_rcut(),
"sel": model.get_sel(),
"dim_fparam": model.get_dim_fparam(),
"dim_aparam": model.get_dim_aparam(),
"mixed_types": model.mixed_types(),
"has_default_fparam": model.has_default_fparam(),
"default_fparam": model.get_default_fparam(),
"is_spin": self._is_spin,
}
if self._is_spin:
self.metadata["ntypes_spin"] = model.spin.get_ntypes_spin()
self.metadata["use_spin"] = [bool(v) for v in model.spin.use_spin]

# Eager runner with the same signature as the .pt2/.pte exported module.
# Use forward_common_lower (not forward_lower) to match the export-time
# output keys ("energy", "energy_redu", "energy_derv_r", ...) that
# communicate_extended_output downstream consumes.
# Non-spin: (ext_coord, ext_atype, nlist, mapping, fparam, aparam)
# Spin: (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam)
if self._is_spin:

def _eager_runner_spin(
ext_coord: torch.Tensor,
ext_atype: torch.Tensor,
ext_spin: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None,
fparam: torch.Tensor | None,
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
ext_coord = ext_coord.detach().requires_grad_(True)
return model.forward_common_lower(
ext_coord,
ext_atype,
ext_spin,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=True,
)

self.exported_module = _eager_runner_spin
else:

def _eager_runner(
ext_coord: torch.Tensor,
ext_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None,
fparam: torch.Tensor | None,
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
ext_coord = ext_coord.detach().requires_grad_(True)
return model.forward_common_lower(
ext_coord,
ext_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=True,
)

self.exported_module = _eager_runner

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
return self.rcut
Expand Down
38 changes: 38 additions & 0 deletions deepmd/pt_expt/model/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
from deepmd.pt_expt.model.property_model import (
PropertyModel,
)
from deepmd.pt_expt.model.spin_ener_model import (
SpinEnergyModel,
)
from deepmd.utils.spin import (
Spin,
)


def _get_standard_model_components(
Expand Down Expand Up @@ -162,6 +168,36 @@ def get_linear_model(model_params: dict) -> BaseModel:
)


def get_spin_model(data: dict) -> SpinEnergyModel:
"""Build a pt_expt spin energy model from a config dictionary.

Mirrors :func:`deepmd.dpmodel.model.model.get_spin_model`: expands the
type map and descriptor sel for virtual spin atoms, then wraps the
backbone EnergyModel as a :class:`SpinEnergyModel`.
"""
data = copy.deepcopy(data)
data["type_map"] += [item + "_spin" for item in data["type_map"]]
spin = Spin(
use_spin=data["spin"]["use_spin"],
virtual_scale=data["spin"]["virtual_scale"],
)
pair_exclude_types = spin.get_pair_exclude_types(
exclude_types=data.get("pair_exclude_types", None)
)
data["pair_exclude_types"] = pair_exclude_types
data["descriptor"]["exclude_types"] = pair_exclude_types
atom_exclude_types = spin.get_atom_exclude_types(
exclude_types=data.get("atom_exclude_types", None)
)
data["atom_exclude_types"] = atom_exclude_types
if "env_protection" not in data["descriptor"]:
data["descriptor"]["env_protection"] = 1e-6
if data["descriptor"]["type"] in ["se_e2_a"]:
data["descriptor"]["sel"] += data["descriptor"]["sel"]
backbone_model = get_standard_model(data)
return SpinEnergyModel(backbone_model=backbone_model, spin=spin)


def get_model(data: dict) -> BaseModel:
"""Get a model from a config dictionary.

Expand All @@ -172,6 +208,8 @@ def get_model(data: dict) -> BaseModel:
"""
model_type = data.get("type", "standard")
if model_type == "standard":
if "spin" in data:
return get_spin_model(data)
return get_standard_model(data)
elif model_type == "linear_ener":
return get_linear_model(data)
Expand Down
Loading
Loading