diff --git a/source/source_estate/module_pot/pot_ml_exx.cpp b/source/source_estate/module_pot/pot_ml_exx.cpp index 53393b1e33f..b5c1fbcd598 100644 --- a/source/source_estate/module_pot/pot_ml_exx.cpp +++ b/source/source_estate/module_pot/pot_ml_exx.cpp @@ -17,13 +17,13 @@ ML_EXX::ML_EXX() ML_EXX::~ML_EXX(){} -void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in) +void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running) { torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble)); auto output = torch::get_default_dtype(); - std::cout << "Default type: " << output << std::endl; + ofs_running << " Default type: " << output << std::endl; - this->set_device(inp.of_ml_device); + this->set_device(inp.of_ml_device, ofs_running); this->nx = rho_basis_in->nrxx; this->nx_tot = rho_basis_in->nrxx; @@ -46,17 +46,18 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod inp.of_ml_tanh_pnl, inp.of_ml_tanh_qnl, inp.of_ml_tanhp_nl, - inp.of_ml_tanhq_nl); + inp.of_ml_tanhq_nl, + ofs_running); - std::cout << "ninput = " << this->ninput << std::endl; + ofs_running << "ninput = " << this->ninput << std::endl; if (PARAM.inp.ml_exx) { int nnode = 100; int nlayer = 3; - this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device); + this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running); torch::load(this->nn, "net.pt", this->device_type); - std::cout << "load net done" << std::endl; + ofs_running << "load net done" << std::endl; if (PARAM.inp.of_ml_feg != 0) { torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type); @@ -74,7 +75,7 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr()[0]; } - std::cout << "feg_net_F = " << this->feg_net_F << std::endl; + ofs_running << "feg_net_F = " << this->feg_net_F << std::endl; } } @@ -88,8 +89,24 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod this->chi_pnl = inp.of_ml_chi_pnl; this->chi_qnl = inp.of_ml_chi_qnl; - this->cal_tool->set_para(this->nx, inp.nelec, inp.of_tf_weight, inp.of_vw_weight, this->chi_p, this->chi_q, - this->chi_xi, this->chi_pnl, this->chi_qnl, this->nkernel, inp.of_ml_kernel, inp.of_ml_kernel_scaling, inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, this->dV * rho_basis_in->nxyz, rho_basis_in); + this->cal_tool->set_para( + this->nx, + inp.nelec, + inp.of_tf_weight, + inp.of_vw_weight, + this->chi_p, + this->chi_q, + this->chi_xi, + this->chi_pnl, + this->chi_qnl, + this->nkernel, + inp.of_ml_kernel, + inp.of_ml_kernel_scaling, + inp.of_ml_yukawa_alpha, + inp.of_ml_kernel_file, + this->dV * rho_basis_in->nxyz, + rho_basis_in, + ofs_running); } } @@ -177,7 +194,7 @@ void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Ba * @param prho charge density * @param pw_rho PW_Basis */ -void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho) +void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running) { // for test ===================== std::vector cshape = {(long unsigned) this->nx}; @@ -192,7 +209,7 @@ void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw for (int ir = 0; ir < this->nx; ++ir) { if (prho[0][ir] == 0.){ - std::cout << "WARNING: rho = 0" << std::endl; + ofs_running << "WARNING: rho = 0" << std::endl; } }; // ============================== diff --git a/source/source_estate/module_pot/pot_ml_exx.h b/source/source_estate/module_pot/pot_ml_exx.h index 5936add9050..a2a866801c2 100644 --- a/source/source_estate/module_pot/pot_ml_exx.h +++ b/source/source_estate/module_pot/pot_ml_exx.h @@ -16,13 +16,13 @@ class ML_EXX : public ML_Base ML_EXX(); virtual ~ML_EXX(); - void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in); + void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running); void ml_potential(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential); // output all parameters void generateTrainData(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff); - void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho); + void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running); void init_data( const int &nkernel, @@ -40,7 +40,8 @@ class ML_EXX : public ML_Base const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ); double ml_exx_energy = 0.0; @@ -56,13 +57,13 @@ class PotML_EXX : public PotBase this->dynamic_mode = true; this->fixed_mode = false; - this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in); + this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in, GlobalV::ofs_running); } ~PotML_EXX() {}; void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override { - if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_); + if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_, GlobalV::ofs_running); this->ml_exx.ml_potential(chg->rho, this->rho_basis_, v_eff); } diff --git a/source/source_estate/module_pot/pot_ml_exx_label.cpp b/source/source_estate/module_pot/pot_ml_exx_label.cpp index 3908b7c5ef9..a073ce58832 100644 --- a/source/source_estate/module_pot/pot_ml_exx_label.cpp +++ b/source/source_estate/module_pot/pot_ml_exx_label.cpp @@ -40,7 +40,8 @@ void ML_EXX::init_data( const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ) { diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index f084854238b..580bc314efb 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -361,7 +361,8 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, ucell.omega, - pw_rho); + pw_rho, + GlobalV::ofs_running); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, stp.template get_psi_t(), diff --git a/source/source_io/module_ml/cal_mlkedf_descriptors.cpp b/source/source_io/module_ml/cal_mlkedf_descriptors.cpp index b0b627fec23..4a1e3c86da6 100644 --- a/source/source_io/module_ml/cal_mlkedf_descriptors.cpp +++ b/source/source_io/module_ml/cal_mlkedf_descriptors.cpp @@ -19,7 +19,8 @@ void Cal_MLKEDF_Descriptors::set_para( const std::vector &yukawa_alpha, const std::vector &kernel_file, const double &omega, - const ModulePW::PW_Basis *pw_rho + const ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running ) { this->nx = nx; @@ -34,7 +35,7 @@ void Cal_MLKEDF_Descriptors::set_para( this->kernel_scaling = kernel_scaling; this->yukawa_alpha = yukawa_alpha; this->kernel_file = kernel_file; - std::cout << "nkernel = " << nkernel << std::endl; + ofs_running << "nkernel = " << nkernel << std::endl; if (PARAM.inp.of_wt_rho0 != 0) { diff --git a/source/source_io/module_ml/cal_mlkedf_descriptors.h b/source/source_io/module_ml/cal_mlkedf_descriptors.h index 2569d091c8a..7a0b70b69f3 100644 --- a/source/source_io/module_ml/cal_mlkedf_descriptors.h +++ b/source/source_io/module_ml/cal_mlkedf_descriptors.h @@ -38,7 +38,8 @@ class Cal_MLKEDF_Descriptors const std::vector &yukawa_alpha, const std::vector &kernel_file, const double &omega, - const ModulePW::PW_Basis *pw_rho); + const ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running); // get input parameters void getGamma(const double * const *prho, std::vector &rgamma); void getP(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::vector> &pnablaRho, std::vector &rp); diff --git a/source/source_pw/module_ofdft/kedf_manager.cpp b/source/source_pw/module_ofdft/kedf_manager.cpp index 8aaec2a8a4f..a2f270da50a 100644 --- a/source/source_pw/module_ofdft/kedf_manager.cpp +++ b/source/source_pw/module_ofdft/kedf_manager.cpp @@ -109,12 +109,40 @@ void KEDF_Manager::init( { if (this->ml_ == nullptr) this->ml_ = new KEDF_ML(); - this->ml_->set_para(pw_rho->nrxx, dV, nelec, inp.of_tf_weight, inp.of_vw_weight, - inp.of_ml_chi_p, inp.of_ml_chi_q, inp.of_ml_chi_xi, inp.of_ml_chi_pnl, inp.of_ml_chi_qnl, - inp.of_ml_nkernel, inp.of_ml_kernel, inp.of_ml_kernel_scaling, - inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, inp.of_ml_gamma, inp.of_ml_p, inp.of_ml_q, inp.of_ml_tanhp, inp.of_ml_tanhq, - inp.of_ml_gammanl, inp.of_ml_pnl, inp.of_ml_qnl, inp.of_ml_xi, inp.of_ml_tanhxi, - inp.of_ml_tanhxi_nl, inp.of_ml_tanh_pnl, inp.of_ml_tanh_qnl, inp.of_ml_tanhp_nl, inp.of_ml_tanhq_nl, inp.of_ml_device, pw_rho); + this->ml_->set_para( + pw_rho->nrxx, + dV, + nelec, + inp.of_tf_weight, + inp.of_vw_weight, + inp.of_ml_chi_p, + inp.of_ml_chi_q, + inp.of_ml_chi_xi, + inp.of_ml_chi_pnl, + inp.of_ml_chi_qnl, + inp.of_ml_nkernel, + inp.of_ml_kernel, + inp.of_ml_kernel_scaling, + inp.of_ml_yukawa_alpha, + inp.of_ml_kernel_file, + inp.of_ml_gamma, + inp.of_ml_p, + inp.of_ml_q, + inp.of_ml_tanhp, + inp.of_ml_tanhq, + inp.of_ml_gammanl, + inp.of_ml_pnl, + inp.of_ml_qnl, + inp.of_ml_xi, + inp.of_ml_tanhxi, + inp.of_ml_tanhxi_nl, + inp.of_ml_tanh_pnl, + inp.of_ml_tanh_qnl, + inp.of_ml_tanhp_nl, + inp.of_ml_tanhq_nl, + inp.of_ml_device, + pw_rho, + GlobalV::ofs_running); } #endif } @@ -239,8 +267,8 @@ double KEDF_Manager::get_energy() const kinetic_energy += this->ml_->ml_energy; if (this->ml_->ml_energy >= this->tf_->tf_energy) { - std::cout << "WARNING: ML >= TF" << std::endl; - std::cout << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl; + GlobalV::ofs_running << "WARNING: ML >= TF" << std::endl; + GlobalV::ofs_running << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl; } } #endif diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index 6b2643a7b58..ec3f9c1ac1d 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -38,14 +38,15 @@ void KEDF_ML::set_para( const std::vector &of_ml_tanhp_nl, const std::vector &of_ml_tanhq_nl, const std::string device_inpt, - ModulePW::PW_Basis *pw_rho + ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running ) { torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble)); auto output = torch::get_default_dtype(); - std::cout << "Default type: " << output << std::endl; + ofs_running << " Default type: " << output << std::endl; - this->set_device(device_inpt); + this->set_device(device_inpt, ofs_running); this->nx = nx; this->nx_tot = nx; @@ -68,17 +69,18 @@ void KEDF_ML::set_para( of_ml_tanh_pnl, of_ml_tanh_qnl, of_ml_tanhp_nl, - of_ml_tanhq_nl); + of_ml_tanhq_nl, + ofs_running); - std::cout << "ninput = " << ninput << std::endl; + ofs_running << "ninput = " << ninput << std::endl; if (PARAM.inp.of_kinetic == "ml") { int nnode = 100; int nlayer = 3; - this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device); + this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running); torch::load(this->nn, "net.pt", this->device_type); - std::cout << "load net done" << std::endl; + ofs_running << "load net done" << std::endl; if (PARAM.inp.of_ml_feg != 0) { torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type); @@ -96,7 +98,7 @@ void KEDF_ML::set_para( this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr()[0]; } - std::cout << "feg_net_F = " << this->feg_net_F << std::endl; + ofs_running << "feg_net_F = " << this->feg_net_F << std::endl; } } @@ -111,7 +113,7 @@ void KEDF_ML::set_para( this->chi_qnl = chi_qnl; this->cal_tool->set_para(nx, nelec, tf_weight, vw_weight, chi_p, chi_q, - chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho); + chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho, ofs_running); } } diff --git a/source/source_pw/module_ofdft/kedf_ml.h b/source/source_pw/module_ofdft/kedf_ml.h index 202c6958f0c..d7acfb734ba 100644 --- a/source/source_pw/module_ofdft/kedf_ml.h +++ b/source/source_pw/module_ofdft/kedf_ml.h @@ -47,7 +47,8 @@ class KEDF_ML : public ML_Base const std::vector &of_ml_tanhp_nl, const std::vector &of_ml_tanhq_nl, const std::string device_inpt, - ModulePW::PW_Basis *pw_rho); + ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running); double get_energy(const double * const * prho, ModulePW::PW_Basis *pw_rho); // double get_energy_density(const double * const *prho, int is, int ir, ModulePW::PW_Basis *pw_rho); @@ -78,7 +79,8 @@ class KEDF_ML : public ML_Base const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ); }; diff --git a/source/source_pw/module_ofdft/kedf_ml_label.cpp b/source/source_pw/module_ofdft/kedf_ml_label.cpp index 100ad4c387e..0cce3d0cec3 100644 --- a/source/source_pw/module_ofdft/kedf_ml_label.cpp +++ b/source/source_pw/module_ofdft/kedf_ml_label.cpp @@ -38,7 +38,8 @@ void KEDF_ML::init_data( const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ) { @@ -152,8 +153,8 @@ void KEDF_ML::init_data( this->descriptor2kernel[descriptor_type[i]].push_back(kernel_index[i]); this->descriptor2index[descriptor_type[i]].push_back(i); } - std::cout << "descriptor2index " << descriptor2index << std::endl; - std::cout << "descriptor2kernel " << descriptor2kernel << std::endl; + ofs_running << "descriptor2index " << descriptor2index << std::endl; + ofs_running << "descriptor2kernel " << descriptor2kernel << std::endl; this->ml_gamma = this->descriptor2index["gamma"].size() > 0; this->ml_p = this->descriptor2index["p"].size() > 0; diff --git a/source/source_pw/module_ofdft/ml_base.cpp b/source/source_pw/module_ofdft/ml_base.cpp index 0bb09b441ac..54feecb261a 100644 --- a/source/source_pw/module_ofdft/ml_base.cpp +++ b/source/source_pw/module_ofdft/ml_base.cpp @@ -10,24 +10,24 @@ ML_Base::~ML_Base() if (this->cal_tool) delete this->cal_tool; } -void ML_Base::set_device(std::string device_inpt) +void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_running) { if (device_inpt == "cpu") { - std::cout << "------------------- Running NN on CPU -------------------" << std::endl; + ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl; this->device_type = torch::kCPU; } else if (device_inpt == "gpu") { if (torch::cuda::cudnn_is_available()) { - std::cout << "------------------- Running NN on GPU -------------------" << std::endl; + ofs_running << "------------------- Running Neural Network on GPU -------------------" << std::endl; this->device_type = torch::kCUDA; } else { - std::cout << "--------------- Warning: GPU is unaviable ---------------" << std::endl; - std::cout << "------------------- Running NN on CPU -------------------" << std::endl; + ofs_running << "--------------- Warning: GPU is unaviable ---------------" << std::endl; + ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl; this->device_type = torch::kCPU; } } diff --git a/source/source_pw/module_ofdft/ml_base.h b/source/source_pw/module_ofdft/ml_base.h index 41c08d22327..3f7b04d582c 100644 --- a/source/source_pw/module_ofdft/ml_base.h +++ b/source/source_pw/module_ofdft/ml_base.h @@ -20,7 +20,7 @@ class ML_Base ~ML_Base(); // Common Interface - void set_device(std::string device_inpt); + void set_device(const std::string& device_inpt, std::ostream& ofs_running); // Tools void loadVector(std::string filename, std::vector &data); diff --git a/source/source_pw/module_ofdft/nn_of.cpp b/source/source_pw/module_ofdft/nn_of.cpp index 5aa81069587..cfe48bf1d86 100644 --- a/source/source_pw/module_ofdft/nn_of.cpp +++ b/source/source_pw/module_ofdft/nn_of.cpp @@ -1,14 +1,14 @@ #include "nn_of.h" -NN_OFImpl::NN_OFImpl(int nrxx, int nrxx_vali, int ninpt, int nnode, int nlayer, torch::Device device) +NN_OFImpl::NN_OFImpl(int nrxx, int nrxx_vali, int ninpt, int nnode, int nlayer, torch::Device device, std::ostream& ofs_running) { this->nrxx = nrxx; this->nrxx_vali = nrxx_vali; this->ninpt = ninpt; this->nnode = nnode; - std::cout << "nnode = " << this->nnode << std::endl; + ofs_running << "nnode = " << this->nnode << std::endl; this->nlayer = nlayer; - std::cout << "nlayer = " << this->nlayer << std::endl; + ofs_running << "nlayer = " << this->nlayer << std::endl; this->nfc = nlayer + 1; this->inputs = torch::zeros({this->nrxx, this->ninpt}).to(device); diff --git a/source/source_pw/module_ofdft/nn_of.h b/source/source_pw/module_ofdft/nn_of.h index 6623ddcfc80..6566dcae191 100644 --- a/source/source_pw/module_ofdft/nn_of.h +++ b/source/source_pw/module_ofdft/nn_of.h @@ -11,7 +11,8 @@ struct NN_OFImpl:torch::nn::Module{ int ninpt, int nnode, int nlayer, - torch::Device device + torch::Device device, + std::ostream& ofs_running ); ~NN_OFImpl() { diff --git a/tests/integrate/Autotest.sh b/tests/integrate/Autotest.sh index dc466a096fc..b7e260afcd9 100755 --- a/tests/integrate/Autotest.sh +++ b/tests/integrate/Autotest.sh @@ -1,7 +1,7 @@ #!/bin/bash # ABACUS executable path -abacus=abacus +abacus=/home/510Group/2_abacus/abacus-mc/build_ml_para/abacus_ml_para # number of MPI processes np=4 nt=$OMP_NUM_THREADS # number of OpenMP threads, default is $OMP_NUM_THREADS