Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,4 @@ frozen_model.*

# Test system directories
system/
*.expected
17 changes: 9 additions & 8 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,16 @@ def __init__(
# only use_timestep when skip connection is established.
use_timestep = use_timestep and (num_out == num_in or num_out == num_in * 2)
rng = np.random.default_rng(seed)
scale_factor = 1.0 / np.sqrt(num_out + num_in)
self.w = rng.normal(size=(num_in, num_out), scale=scale_factor).astype(prec)
self.b = (
rng.normal(size=(num_out,), scale=scale_factor).astype(prec)
if bias
else None
)
# Match deepmd/pt MLPLayer._default_normal_init: w uses Glorot scaling,
# but b is N(0, 1) (unscaled) and idt is N(0.1, 0.001). Using Glorot
# scaling for b/idt empirically slows convergence on ResNet-style
# fitting nets because residual layers start near identity (idt~0).
Comment thread
wanghan-iapcm marked this conversation as resolved.
self.w = rng.normal(
size=(num_in, num_out), scale=1.0 / np.sqrt(num_out + num_in)
).astype(prec)
self.b = rng.normal(size=(num_out,), scale=1.0).astype(prec) if bias else None
self.idt = (
rng.normal(size=(num_out,), scale=scale_factor).astype(prec)
(rng.normal(size=(num_out,), scale=0.001) + 0.1).astype(prec)
if use_timestep
else None
)
Expand Down
93 changes: 93 additions & 0 deletions source/api_cc/tests/expected_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <fstream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace deepmd_test {

// Loader for sidecar reference files written by
// `gen_common.write_expected_ref`.
//
// File format:
// # auto-generated -- do not edit
// [case_name_1]
// array_name_1 N
// v0
// v1
// ...
// array_name_2 M
// ...
//
// [case_name_2]
// ...
//
// Lines beginning with '#' or empty lines are ignored.
class ExpectedRef {
public:
// Parse `path`. Throws std::runtime_error on malformed input.
void load(const std::string& path) {
std::ifstream in(path);
if (!in) {
throw std::runtime_error("ExpectedRef: cannot open " + path);
}
sections_.clear();
std::string line;
std::string current_section;
while (std::getline(in, line)) {
if (line.empty() || line[0] == '#') {
continue;
}
if (line.front() == '[' && line.back() == ']') {
current_section = line.substr(1, line.size() - 2);
continue;
}
// "<key> <count>" header — followed by `count` numeric lines.
if (current_section.empty()) {
throw std::runtime_error("ExpectedRef: array '" + line +
"' before any [section]");
}
std::istringstream iss(line);
std::string key;
std::size_t n = 0;
if (!(iss >> key >> n)) {
throw std::runtime_error("ExpectedRef: bad header line: " + line);
}
std::vector<double> values;
values.reserve(n);
for (std::size_t i = 0; i < n; ++i) {
if (!std::getline(in, line)) {
throw std::runtime_error("ExpectedRef: unexpected EOF in '" + key +
"'");
}
values.push_back(std::stod(line));
}
sections_[current_section][key] = std::move(values);
}
}

// Get array of `key` from `case_name`. Throws if missing.
template <typename T = double>
std::vector<T> get(const std::string& case_name,
const std::string& key) const {
auto sit = sections_.find(case_name);
if (sit == sections_.end()) {
throw std::runtime_error("ExpectedRef: missing case '" + case_name + "'");
}
auto kit = sit->second.find(key);
if (kit == sit->second.end()) {
throw std::runtime_error("ExpectedRef: missing array '" + key +
"' in case '" + case_name + "'");
}
return std::vector<T>(kit->second.begin(), kit->second.end());
}

private:
std::map<std::string, std::map<std::string, std::vector<double>>> sections_;
};

} // namespace deepmd_test
113 changes: 26 additions & 87 deletions source/api_cc/tests/test_deeppot_a_fparam_aparam_nframes_ptexpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@

#include "DeepPot.h"
#include "DeepPotPTExpt.h"
#include "expected_ref.h"
#include "test_utils.h"

namespace {
constexpr const char* kRefPath = "../../tests/infer/fparam_aparam.expected";
constexpr const char* kModelPath = "../../tests/infer/fparam_aparam.pt2";
} // namespace

template <class VALUETYPE>
class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
protected:
Expand All @@ -27,88 +33,9 @@ class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
std::vector<VALUETYPE> aparam = {
0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028,
0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028};
// Same reference values as single-frame, duplicated for 2 frames
std::vector<VALUETYPE> expected_e = {
-1.038271223729636539e-01, -7.285433579124989123e-02,
-9.467600492266425860e-02, -1.467050207422957442e-01,
-7.660561676973243195e-02, -7.277296000253175023e-02,
-1.038271223729636539e-01, -7.285433579124989123e-02,
-9.467600492266425860e-02, -1.467050207422957442e-01,
-7.660561676973243195e-02, -7.277296000253175023e-02};
std::vector<VALUETYPE> expected_f = {
6.622266941151369601e-02, 5.278739714221529489e-02,
2.265728009692277028e-02, -2.606048291367509331e-02,
-4.538812303131847109e-02, 1.058247419681241676e-02,
1.679392617013223121e-01, -2.257826240741929533e-03,
-4.490146347357203138e-02, -1.148364179422036724e-01,
-1.169790528013799069e-02, 6.140403441496700837e-02,
-8.078778123309421355e-02, -5.838879041789352825e-02,
6.773641084621376263e-02, -1.247724902386305318e-02,
6.494524782787665373e-02, -1.174787360813439457e-01,
6.622266941151369601e-02, 5.278739714221529489e-02,
2.265728009692277028e-02, -2.606048291367509331e-02,
-4.538812303131847109e-02, 1.058247419681241676e-02,
1.679392617013223121e-01, -2.257826240741929533e-03,
-4.490146347357203138e-02, -1.148364179422036724e-01,
-1.169790528013799069e-02, 6.140403441496700837e-02,
-8.078778123309421355e-02, -5.838879041789352825e-02,
6.773641084621376263e-02, -1.247724902386305318e-02,
6.494524782787665373e-02, -1.174787360813439457e-01};
std::vector<VALUETYPE> expected_v = {
-1.589185601903579381e-01, 2.586167090689088510e-03,
-1.575150812458056548e-04, -1.855360549216640564e-02,
1.949822308966445150e-02, -1.006552178977542650e-02,
3.177030388421490936e-02, 1.714350280402104215e-03,
-1.290389705296313833e-03, -8.553511587973079699e-02,
-5.654638208496251539e-03, -1.286955066237439882e-02,
2.464156699303176462e-02, -2.398203243424212178e-02,
-1.957110698882909630e-02, 2.233493653505165544e-02,
6.107843889444162372e-03, 1.707076397717688723e-03,
-1.653994136896924094e-01, 3.894358809712639147e-02,
-2.169596032233910010e-02, 6.819702786556020371e-03,
-5.018240707559744503e-03, 2.640663592968431426e-03,
-1.985295554050418160e-03, -3.638422207618969423e-02,
2.342932709960221863e-02, -8.501331666888653493e-02,
-2.181253119706856591e-03, 4.311299629418858387e-03,
-1.910329576491436726e-03, -1.808810428459609043e-03,
-1.540075460017477360e-03, -1.173703527688202929e-02,
-2.596307050960845741e-03, 6.705026635782097323e-03,
-9.038454847872562370e-02, 3.011717694088476838e-02,
-5.083053967307901710e-02, -2.951212926932282599e-03,
2.342446057919112673e-02, -4.091208178777860222e-02,
-1.648470670751139844e-02, -2.872262362355524484e-02,
4.763925761561256522e-02, -8.300037376164930147e-02,
1.020429200603871836e-03, -1.026734257188876599e-03,
5.678534821710372327e-02, 1.273635858276599142e-02,
-1.530143401888291177e-02, -1.061672032476311256e-01,
-2.486859787145567074e-02, 2.875323543588798395e-02,
-1.589185601903579381e-01, 2.586167090689088510e-03,
-1.575150812458056548e-04, -1.855360549216640564e-02,
1.949822308966445150e-02, -1.006552178977542650e-02,
3.177030388421490936e-02, 1.714350280402104215e-03,
-1.290389705296313833e-03, -8.553511587973079699e-02,
-5.654638208496251539e-03, -1.286955066237439882e-02,
2.464156699303176462e-02, -2.398203243424212178e-02,
-1.957110698882909630e-02, 2.233493653505165544e-02,
6.107843889444162372e-03, 1.707076397717688723e-03,
-1.653994136896924094e-01, 3.894358809712639147e-02,
-2.169596032233910010e-02, 6.819702786556020371e-03,
-5.018240707559744503e-03, 2.640663592968431426e-03,
-1.985295554050418160e-03, -3.638422207618969423e-02,
2.342932709960221863e-02, -8.501331666888653493e-02,
-2.181253119706856591e-03, 4.311299629418858387e-03,
-1.910329576491436726e-03, -1.808810428459609043e-03,
-1.540075460017477360e-03, -1.173703527688202929e-02,
-2.596307050960845741e-03, 6.705026635782097323e-03,
-9.038454847872562370e-02, 3.011717694088476838e-02,
-5.083053967307901710e-02, -2.951212926932282599e-03,
2.342446057919112673e-02, -4.091208178777860222e-02,
-1.648470670751139844e-02, -2.872262362355524484e-02,
4.763925761561256522e-02, -8.300037376164930147e-02,
1.020429200603871836e-03, -1.026734257188876599e-03,
5.678534821710372327e-02, 1.273635858276599142e-02,
-1.530143401888291177e-02, -1.061672032476311256e-01,
-2.486859787145567074e-02, 2.875323543588798395e-02};
std::vector<VALUETYPE> expected_e;
std::vector<VALUETYPE> expected_f;
std::vector<VALUETYPE> expected_v;
int natoms;
int nframes = 2;
std::vector<double> expected_tot_e;
Expand All @@ -118,22 +45,34 @@ class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {

static void SetUpTestSuite() {
#if defined(BUILD_PYTORCH) && BUILD_PT_EXPT
dp.init("../../tests/infer/fparam_aparam.pt2");
dp.init(kModelPath);
#endif
}

void SetUp() override {
#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT
GTEST_SKIP() << "Skip because PyTorch support is not enabled.";
#endif
deepmd_test::ExpectedRef ref;
ref.load(kRefPath);
auto e_single = ref.get<VALUETYPE>("default", "expected_e");
auto f_single = ref.get<VALUETYPE>("default", "expected_f");
auto v_single = ref.get<VALUETYPE>("default", "expected_v");
// Replicate single-frame reference for nframes batched inference.
expected_e.reserve(nframes * e_single.size());
expected_f.reserve(nframes * f_single.size());
expected_v.reserve(nframes * v_single.size());
for (int kk = 0; kk < nframes; ++kk) {
expected_e.insert(expected_e.end(), e_single.begin(), e_single.end());
expected_f.insert(expected_f.end(), f_single.begin(), f_single.end());
expected_v.insert(expected_v.end(), v_single.begin(), v_single.end());
}

natoms = expected_e.size() / nframes;
EXPECT_EQ(nframes * natoms * 3, expected_f.size());
EXPECT_EQ(nframes * natoms * 9, expected_v.size());
expected_tot_e.resize(nframes);
expected_tot_v.resize(static_cast<size_t>(nframes) * 9);
std::fill(expected_tot_e.begin(), expected_tot_e.end(), 0.);
std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
expected_tot_e.assign(nframes, 0.);
expected_tot_v.assign(static_cast<size_t>(nframes) * 9, 0.);
for (int kk = 0; kk < nframes; ++kk) {
for (int ii = 0; ii < natoms; ++ii) {
expected_tot_e[kk] += expected_e[kk * natoms + ii];
Expand Down
64 changes: 17 additions & 47 deletions source/api_cc/tests/test_deeppot_a_fparam_aparam_pt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@
#include <vector>

#include "DeepPot.h"
#include "expected_ref.h"
#include "neighbor_list.h"
#include "test_utils.h"

// 1e-10 cannot pass; unclear bug or not
#undef EPSILON
#define EPSILON (std::is_same<VALUETYPE, double>::value ? 1e-7 : 1e-4)

namespace {
constexpr const char* kRefPath = "../../tests/infer/fparam_aparam.expected";
constexpr const char* kModelPath = "../../tests/infer/fparam_aparam.pth";
} // namespace

template <class VALUETYPE>
class TestInferDeepPotAFParamAParamPt : public ::testing::Test {
protected:
Expand All @@ -28,50 +34,9 @@ class TestInferDeepPotAFParamAParamPt : public ::testing::Test {
std::vector<VALUETYPE> fparam = {0.25852028};
std::vector<VALUETYPE> aparam = {0.25852028, 0.25852028, 0.25852028,
0.25852028, 0.25852028, 0.25852028};
// Generated by source/tests/infer/gen_fparam_aparam.py
// (from pre-committed fparam_aparam_default.pth, type_one_side=True)
std::vector<VALUETYPE> expected_e = {
-1.038271223729636539e-01, -7.285433579124989123e-02,
-9.467600492266425860e-02, -1.467050207422957442e-01,
-7.660561676973243195e-02, -7.277296000253175023e-02};
std::vector<VALUETYPE> expected_f = {
6.622266941151369601e-02, 5.278739714221529489e-02,
2.265728009692277028e-02, -2.606048291367509331e-02,
-4.538812303131847109e-02, 1.058247419681241676e-02,
1.679392617013223121e-01, -2.257826240741929533e-03,
-4.490146347357203138e-02, -1.148364179422036724e-01,
-1.169790528013799069e-02, 6.140403441496700837e-02,
-8.078778123309421355e-02, -5.838879041789352825e-02,
6.773641084621376263e-02, -1.247724902386305318e-02,
6.494524782787665373e-02, -1.174787360813439457e-01};
std::vector<VALUETYPE> expected_v = {
-1.589185601903579381e-01, 2.586167090689088510e-03,
-1.575150812458056548e-04, -1.855360549216640564e-02,
1.949822308966445150e-02, -1.006552178977542650e-02,
3.177030388421490936e-02, 1.714350280402104215e-03,
-1.290389705296313833e-03, -8.553511587973079699e-02,
-5.654638208496251539e-03, -1.286955066237439882e-02,
2.464156699303176462e-02, -2.398203243424212178e-02,
-1.957110698882909630e-02, 2.233493653505165544e-02,
6.107843889444162372e-03, 1.707076397717688723e-03,
-1.653994136896924094e-01, 3.894358809712639147e-02,
-2.169596032233910010e-02, 6.819702786556020371e-03,
-5.018240707559744503e-03, 2.640663592968431426e-03,
-1.985295554050418160e-03, -3.638422207618969423e-02,
2.342932709960221863e-02, -8.501331666888653493e-02,
-2.181253119706856591e-03, 4.311299629418858387e-03,
-1.910329576491436726e-03, -1.808810428459609043e-03,
-1.540075460017477360e-03, -1.173703527688202929e-02,
-2.596307050960845741e-03, 6.705026635782097323e-03,
-9.038454847872562370e-02, 3.011717694088476838e-02,
-5.083053967307901710e-02, -2.951212926932282599e-03,
2.342446057919112673e-02, -4.091208178777860222e-02,
-1.648470670751139844e-02, -2.872262362355524484e-02,
4.763925761561256522e-02, -8.300037376164930147e-02,
1.020429200603871836e-03, -1.026734257188876599e-03,
5.678534821710372327e-02, 1.273635858276599142e-02,
-1.530143401888291177e-02, -1.061672032476311256e-01,
-2.486859787145567074e-02, 2.875323543588798395e-02};
std::vector<VALUETYPE> expected_e;
std::vector<VALUETYPE> expected_f;
std::vector<VALUETYPE> expected_v;
int natoms;
double expected_tot_e;
std::vector<VALUETYPE> expected_tot_v;
Expand All @@ -82,14 +47,19 @@ class TestInferDeepPotAFParamAParamPt : public ::testing::Test {
#ifndef BUILD_PYTORCH
GTEST_SKIP() << "Skip because PyTorch support is not enabled.";
#endif
dp.init("../../tests/infer/fparam_aparam.pth");
deepmd_test::ExpectedRef ref;
ref.load(kRefPath);
expected_e = ref.get<VALUETYPE>("default", "expected_e");
expected_f = ref.get<VALUETYPE>("default", "expected_f");
expected_v = ref.get<VALUETYPE>("default", "expected_v");

dp.init(kModelPath);

natoms = expected_e.size();
EXPECT_EQ(natoms * 3, expected_f.size());
EXPECT_EQ(natoms * 9, expected_v.size());
expected_tot_e = 0.;
expected_tot_v.resize(9);
std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
expected_tot_v.assign(9, 0.);
for (int ii = 0; ii < natoms; ++ii) {
expected_tot_e += expected_e[ii];
}
Expand Down
Loading
Loading