Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions source/source_estate/module_pot/pot_ml_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ ML_EXX::ML_EXX()

ML_EXX::~ML_EXX(){}

void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in)
void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running)
{
torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble));
auto output = torch::get_default_dtype();
std::cout << "Default type: " << output << std::endl;
ofs_running << " Default type: " << output << std::endl;

this->set_device(inp.of_ml_device);
this->set_device(inp.of_ml_device, ofs_running);

this->nx = rho_basis_in->nrxx;
this->nx_tot = rho_basis_in->nrxx;
Expand All @@ -46,17 +46,18 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod
inp.of_ml_tanh_pnl,
inp.of_ml_tanh_qnl,
inp.of_ml_tanhp_nl,
inp.of_ml_tanhq_nl);
inp.of_ml_tanhq_nl,
ofs_running);

std::cout << "ninput = " << this->ninput << std::endl;
ofs_running << "ninput = " << this->ninput << std::endl;

if (PARAM.inp.ml_exx)
{
int nnode = 100;
int nlayer = 3;
this->nn = std::make_shared<NN_OFImpl>(this->nx, 0, this->ninput, nnode, nlayer, this->device);
this->nn = std::make_shared<NN_OFImpl>(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running);
torch::load(this->nn, "net.pt", this->device_type);
std::cout << "load net done" << std::endl;
ofs_running << "load net done" << std::endl;
if (PARAM.inp.of_ml_feg != 0)
{
torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type);
Expand All @@ -74,7 +75,7 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod
this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr<double>()[0];
}

std::cout << "feg_net_F = " << this->feg_net_F << std::endl;
ofs_running << "feg_net_F = " << this->feg_net_F << std::endl;
}
}

Expand All @@ -88,8 +89,24 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod
this->chi_pnl = inp.of_ml_chi_pnl;
this->chi_qnl = inp.of_ml_chi_qnl;

this->cal_tool->set_para(this->nx, inp.nelec, inp.of_tf_weight, inp.of_vw_weight, this->chi_p, this->chi_q,
this->chi_xi, this->chi_pnl, this->chi_qnl, this->nkernel, inp.of_ml_kernel, inp.of_ml_kernel_scaling, inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, this->dV * rho_basis_in->nxyz, rho_basis_in);
this->cal_tool->set_para(
this->nx,
inp.nelec,
inp.of_tf_weight,
inp.of_vw_weight,
this->chi_p,
this->chi_q,
this->chi_xi,
this->chi_pnl,
this->chi_qnl,
this->nkernel,
inp.of_ml_kernel,
inp.of_ml_kernel_scaling,
inp.of_ml_yukawa_alpha,
inp.of_ml_kernel_file,
this->dV * rho_basis_in->nxyz,
rho_basis_in,
ofs_running);
}
}

Expand Down Expand Up @@ -177,7 +194,7 @@ void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Ba
* @param prho charge density
* @param pw_rho PW_Basis
*/
void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho)
void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running)
{
// for test =====================
std::vector<long unsigned int> cshape = {(long unsigned) this->nx};
Expand All @@ -192,7 +209,7 @@ void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw
for (int ir = 0; ir < this->nx; ++ir)
{
if (prho[0][ir] == 0.){
std::cout << "WARNING: rho = 0" << std::endl;
ofs_running << "WARNING: rho = 0" << std::endl;
}
};
// ==============================
Expand Down
11 changes: 6 additions & 5 deletions source/source_estate/module_pot/pot_ml_exx.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class ML_EXX : public ML_Base
ML_EXX();
virtual ~ML_EXX();

void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in);
void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running);

void ml_potential(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential);

// output all parameters
void generateTrainData(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff);
void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho);
void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running);

void init_data(
const int &nkernel,
Expand All @@ -40,7 +40,8 @@ class ML_EXX : public ML_Base
const std::vector<int> &of_ml_tanh_pnl,
const std::vector<int> &of_ml_tanh_qnl,
const std::vector<int> &of_ml_tanhp_nl,
const std::vector<int> &of_ml_tanhq_nl
const std::vector<int> &of_ml_tanhq_nl,
std::ostream& ofs_running
);

double ml_exx_energy = 0.0;
Expand All @@ -56,13 +57,13 @@ class PotML_EXX : public PotBase
this->dynamic_mode = true;
this->fixed_mode = false;

this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in);
this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in, GlobalV::ofs_running);
}
~PotML_EXX() {};

void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override
{
if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_);
if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_, GlobalV::ofs_running);
this->ml_exx.ml_potential(chg->rho, this->rho_basis_, v_eff);
}

Expand Down
3 changes: 2 additions & 1 deletion source/source_estate/module_pot/pot_ml_exx_label.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ void ML_EXX::init_data(
const std::vector<int> &of_ml_tanh_pnl,
const std::vector<int> &of_ml_tanh_qnl,
const std::vector<int> &of_ml_tanhp_nl,
const std::vector<int> &of_ml_tanhq_nl
const std::vector<int> &of_ml_tanhq_nl,
std::ostream& ofs_running
)
{

Expand Down
3 changes: 2 additions & 1 deletion source/source_io/module_ctrl/ctrl_output_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell,
inp.of_ml_yukawa_alpha,
inp.of_ml_kernel_file,
ucell.omega,
pw_rho);
pw_rho,
GlobalV::ofs_running);

write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir,
stp.template get_psi_t<T, Device>(),
Expand Down
5 changes: 3 additions & 2 deletions source/source_io/module_ml/cal_mlkedf_descriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ void Cal_MLKEDF_Descriptors::set_para(
const std::vector<double> &yukawa_alpha,
const std::vector<std::string> &kernel_file,
const double &omega,
const ModulePW::PW_Basis *pw_rho
const ModulePW::PW_Basis *pw_rho,
std::ostream& ofs_running
)
{
this->nx = nx;
Expand All @@ -34,7 +35,7 @@ void Cal_MLKEDF_Descriptors::set_para(
this->kernel_scaling = kernel_scaling;
this->yukawa_alpha = yukawa_alpha;
this->kernel_file = kernel_file;
std::cout << "nkernel = " << nkernel << std::endl;
ofs_running << "nkernel = " << nkernel << std::endl;

if (PARAM.inp.of_wt_rho0 != 0)
{
Expand Down
3 changes: 2 additions & 1 deletion source/source_io/module_ml/cal_mlkedf_descriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Cal_MLKEDF_Descriptors
const std::vector<double> &yukawa_alpha,
const std::vector<std::string> &kernel_file,
const double &omega,
const ModulePW::PW_Basis *pw_rho);
const ModulePW::PW_Basis *pw_rho,
std::ostream& ofs_running);
// get input parameters
void getGamma(const double * const *prho, std::vector<double> &rgamma);
void getP(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::vector<std::vector<double>> &pnablaRho, std::vector<double> &rp);
Expand Down
44 changes: 36 additions & 8 deletions source/source_pw/module_ofdft/kedf_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,40 @@ void KEDF_Manager::init(
{
if (this->ml_ == nullptr)
this->ml_ = new KEDF_ML();
this->ml_->set_para(pw_rho->nrxx, dV, nelec, inp.of_tf_weight, inp.of_vw_weight,
inp.of_ml_chi_p, inp.of_ml_chi_q, inp.of_ml_chi_xi, inp.of_ml_chi_pnl, inp.of_ml_chi_qnl,
inp.of_ml_nkernel, inp.of_ml_kernel, inp.of_ml_kernel_scaling,
inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, inp.of_ml_gamma, inp.of_ml_p, inp.of_ml_q, inp.of_ml_tanhp, inp.of_ml_tanhq,
inp.of_ml_gammanl, inp.of_ml_pnl, inp.of_ml_qnl, inp.of_ml_xi, inp.of_ml_tanhxi,
inp.of_ml_tanhxi_nl, inp.of_ml_tanh_pnl, inp.of_ml_tanh_qnl, inp.of_ml_tanhp_nl, inp.of_ml_tanhq_nl, inp.of_ml_device, pw_rho);
this->ml_->set_para(
pw_rho->nrxx,
dV,
nelec,
inp.of_tf_weight,
inp.of_vw_weight,
inp.of_ml_chi_p,
inp.of_ml_chi_q,
inp.of_ml_chi_xi,
inp.of_ml_chi_pnl,
inp.of_ml_chi_qnl,
inp.of_ml_nkernel,
inp.of_ml_kernel,
inp.of_ml_kernel_scaling,
inp.of_ml_yukawa_alpha,
inp.of_ml_kernel_file,
inp.of_ml_gamma,
inp.of_ml_p,
inp.of_ml_q,
inp.of_ml_tanhp,
inp.of_ml_tanhq,
inp.of_ml_gammanl,
inp.of_ml_pnl,
inp.of_ml_qnl,
inp.of_ml_xi,
inp.of_ml_tanhxi,
inp.of_ml_tanhxi_nl,
inp.of_ml_tanh_pnl,
inp.of_ml_tanh_qnl,
inp.of_ml_tanhp_nl,
inp.of_ml_tanhq_nl,
inp.of_ml_device,
pw_rho,
GlobalV::ofs_running);
}
#endif
}
Expand Down Expand Up @@ -239,8 +267,8 @@ double KEDF_Manager::get_energy() const
kinetic_energy += this->ml_->ml_energy;
if (this->ml_->ml_energy >= this->tf_->tf_energy)
{
std::cout << "WARNING: ML >= TF" << std::endl;
std::cout << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl;
GlobalV::ofs_running << "WARNING: ML >= TF" << std::endl;
GlobalV::ofs_running << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl;
}
}
#endif
Expand Down
20 changes: 11 additions & 9 deletions source/source_pw/module_ofdft/kedf_ml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ void KEDF_ML::set_para(
const std::vector<int> &of_ml_tanhp_nl,
const std::vector<int> &of_ml_tanhq_nl,
const std::string device_inpt,
ModulePW::PW_Basis *pw_rho
ModulePW::PW_Basis *pw_rho,
std::ostream& ofs_running
)
{
torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble));
auto output = torch::get_default_dtype();
std::cout << "Default type: " << output << std::endl;
ofs_running << " Default type: " << output << std::endl;

this->set_device(device_inpt);
this->set_device(device_inpt, ofs_running);

this->nx = nx;
this->nx_tot = nx;
Expand All @@ -68,17 +69,18 @@ void KEDF_ML::set_para(
of_ml_tanh_pnl,
of_ml_tanh_qnl,
of_ml_tanhp_nl,
of_ml_tanhq_nl);
of_ml_tanhq_nl,
ofs_running);

std::cout << "ninput = " << ninput << std::endl;
ofs_running << "ninput = " << ninput << std::endl;

if (PARAM.inp.of_kinetic == "ml")
{
int nnode = 100;
int nlayer = 3;
this->nn = std::make_shared<NN_OFImpl>(this->nx, 0, this->ninput, nnode, nlayer, this->device);
this->nn = std::make_shared<NN_OFImpl>(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running);
torch::load(this->nn, "net.pt", this->device_type);
std::cout << "load net done" << std::endl;
ofs_running << "load net done" << std::endl;
if (PARAM.inp.of_ml_feg != 0)
{
torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type);
Expand All @@ -96,7 +98,7 @@ void KEDF_ML::set_para(
this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr<double>()[0];
}

std::cout << "feg_net_F = " << this->feg_net_F << std::endl;
ofs_running << "feg_net_F = " << this->feg_net_F << std::endl;
}
}

Expand All @@ -111,7 +113,7 @@ void KEDF_ML::set_para(
this->chi_qnl = chi_qnl;

this->cal_tool->set_para(nx, nelec, tf_weight, vw_weight, chi_p, chi_q,
chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho);
chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho, ofs_running);
}
}

Expand Down
6 changes: 4 additions & 2 deletions source/source_pw/module_ofdft/kedf_ml.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class KEDF_ML : public ML_Base
const std::vector<int> &of_ml_tanhp_nl,
const std::vector<int> &of_ml_tanhq_nl,
const std::string device_inpt,
ModulePW::PW_Basis *pw_rho);
ModulePW::PW_Basis *pw_rho,
std::ostream& ofs_running);

double get_energy(const double * const * prho, ModulePW::PW_Basis *pw_rho);
// double get_energy_density(const double * const *prho, int is, int ir, ModulePW::PW_Basis *pw_rho);
Expand Down Expand Up @@ -78,7 +79,8 @@ class KEDF_ML : public ML_Base
const std::vector<int> &of_ml_tanh_pnl,
const std::vector<int> &of_ml_tanh_qnl,
const std::vector<int> &of_ml_tanhp_nl,
const std::vector<int> &of_ml_tanhq_nl
const std::vector<int> &of_ml_tanhq_nl,
std::ostream& ofs_running
);
};

Expand Down
7 changes: 4 additions & 3 deletions source/source_pw/module_ofdft/kedf_ml_label.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ void KEDF_ML::init_data(
const std::vector<int> &of_ml_tanh_pnl,
const std::vector<int> &of_ml_tanh_qnl,
const std::vector<int> &of_ml_tanhp_nl,
const std::vector<int> &of_ml_tanhq_nl
const std::vector<int> &of_ml_tanhq_nl,
std::ostream& ofs_running
)
{

Expand Down Expand Up @@ -152,8 +153,8 @@ void KEDF_ML::init_data(
this->descriptor2kernel[descriptor_type[i]].push_back(kernel_index[i]);
this->descriptor2index[descriptor_type[i]].push_back(i);
}
std::cout << "descriptor2index " << descriptor2index << std::endl;
std::cout << "descriptor2kernel " << descriptor2kernel << std::endl;
ofs_running << "descriptor2index " << descriptor2index << std::endl;
ofs_running << "descriptor2kernel " << descriptor2kernel << std::endl;

this->ml_gamma = this->descriptor2index["gamma"].size() > 0;
this->ml_p = this->descriptor2index["p"].size() > 0;
Expand Down
10 changes: 5 additions & 5 deletions source/source_pw/module_ofdft/ml_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,24 @@ ML_Base::~ML_Base()
if (this->cal_tool) delete this->cal_tool;
}

void ML_Base::set_device(std::string device_inpt)
void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_running)
{
if (device_inpt == "cpu")
{
std::cout << "------------------- Running NN on CPU -------------------" << std::endl;
ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl;
this->device_type = torch::kCPU;
}
else if (device_inpt == "gpu")
{
if (torch::cuda::cudnn_is_available())
{
std::cout << "------------------- Running NN on GPU -------------------" << std::endl;
ofs_running << "------------------- Running Neural Network on GPU -------------------" << std::endl;
this->device_type = torch::kCUDA;
}
else
{
std::cout << "--------------- Warning: GPU is unaviable ---------------" << std::endl;
std::cout << "------------------- Running NN on CPU -------------------" << std::endl;
ofs_running << "--------------- Warning: GPU is unaviable ---------------" << std::endl;
ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl;
this->device_type = torch::kCPU;
}
}
Expand Down
2 changes: 1 addition & 1 deletion source/source_pw/module_ofdft/ml_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ML_Base
~ML_Base();

// Common Interface
void set_device(std::string device_inpt);
void set_device(const std::string& device_inpt, std::ostream& ofs_running);

// Tools
void loadVector(std::string filename, std::vector<double> &data);
Expand Down
Loading
Loading