Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ elseif(USE_ROCM)
)
endif()

# base library uses symbols from device library (memory_op, math_ylm_op)
target_link_libraries(base PUBLIC device)

if(ENABLE_COVERAGE)
add_coverage(driver)
endif()
3 changes: 3 additions & 0 deletions source/source_base/kernels/cuda/math_kernel_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,9 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ch
{
cublasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
if (cublas_handle == nullptr) {
CHECK_CUBLAS(cublasCreate(&cublas_handle));
}
CHECK_CUBLAS(cublasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (double2*)alpha, (double2*)a , lda, (double2*)b, ldb, (double2*)beta, (double2*)c, ldc));
}

Expand Down
2 changes: 2 additions & 0 deletions source/source_base/module_container/base/macros/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ struct GetTypeCuda<double>
{
static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_64F;
};
#if CUDA_VERSION >= 11000
// Specialization that maps int64_t to the CUDA_R_64I cudaDataType tag.
// NOTE(review): the enclosing #if limits this to CUDA_VERSION >= 11000,
// presumably because CUDA_R_64I is not declared by older toolkits — confirm.
template <>
struct GetTypeCuda<int64_t>
{
static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_64I;
};
#endif
template <>
struct GetTypeCuda<std::complex<float>>
{
Expand Down
59 changes: 56 additions & 3 deletions source/source_base/module_container/base/third_party/cusolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
namespace container {
namespace cuSolverConnector {

#if CUDA_VERSION >= 11000
// Generic API (CUDA 11.0+)
template <typename T>
static inline
void trtri (cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& diag, const int& n, T* A, const int& lda)
Expand All @@ -37,7 +39,7 @@ void trtri (cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& d
int h_info = 0;
int* d_info = nullptr;
CHECK_CUDA(cudaMalloc((void**)&d_info, sizeof(int)));
// Perform Cholesky decomposition
// Perform triangular matrix inversion
CHECK_CUSOLVER(cusolverDnXtrtri(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, GetTypeCuda<T>::cuda_data_type, reinterpret_cast<Type*>(A), n, d_work, d_lwork, h_work, h_lwork, d_info));
CHECK_CUDA(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost));
if (h_info != 0) {
Expand All @@ -47,6 +49,57 @@ void trtri (cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& d
CHECK_CUDA(cudaFree(d_work));
CHECK_CUDA(cudaFree(d_info));
}
#else
// Legacy API fallback (CUDA < 11.0)
// Legacy (CUDA < 11.0) triangular-matrix inversion for single-precision real data.
//
// cusolver_handle : cuSOLVER dense handle used for the workspace query and the inversion.
// uplo, diag      : characters translated via cublas_fill_mode() / cublas_diag_type().
// n, A, lda       : order, matrix pointer and leading dimension forwarded to cusolverDnStrtri,
//                   which works on A in place.
// Throws std::runtime_error when the solver reports a non-zero info value.
static inline void trtri(cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& diag, const int& n, float* A, const int& lda)
{
    // Ask cuSOLVER how many workspace elements the inversion needs.
    int lwork = 0;
    CHECK_CUSOLVER(cusolverDnStrtri_bufferSize(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, A, lda, &lwork));

    float* d_work = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_work, static_cast<size_t>(lwork) * sizeof(float)));
    int* d_info = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_info, sizeof(int)));

    CHECK_CUSOLVER(cusolverDnStrtri(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, A, lda, d_work, lwork, d_info));

    // Read the solver status back to the host; without this a failed inversion
    // would go unnoticed (the CUDA >= 11 branch performs the same check).
    int h_info = 0;
    CHECK_CUDA(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost));

    // Release the scratch buffers before reporting, so the error path does not leak them.
    CHECK_CUDA(cudaFree(d_work));
    CHECK_CUDA(cudaFree(d_info));

    if (h_info != 0) {
        throw std::runtime_error("trtri: failed to invert matrix");
    }
}
// Legacy (CUDA < 11.0) triangular-matrix inversion for double-precision real data.
//
// cusolver_handle : cuSOLVER dense handle used for the workspace query and the inversion.
// uplo, diag      : characters translated via cublas_fill_mode() / cublas_diag_type().
// n, A, lda       : order, matrix pointer and leading dimension forwarded to cusolverDnDtrtri,
//                   which works on A in place.
// Throws std::runtime_error when the solver reports a non-zero info value.
static inline void trtri(cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& diag, const int& n, double* A, const int& lda)
{
    // Ask cuSOLVER how many workspace elements the inversion needs.
    int lwork = 0;
    CHECK_CUSOLVER(cusolverDnDtrtri_bufferSize(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, A, lda, &lwork));

    double* d_work = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_work, static_cast<size_t>(lwork) * sizeof(double)));
    int* d_info = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_info, sizeof(int)));

    CHECK_CUSOLVER(cusolverDnDtrtri(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, A, lda, d_work, lwork, d_info));

    // Read the solver status back to the host; without this a failed inversion
    // would go unnoticed (the CUDA >= 11 branch performs the same check).
    int h_info = 0;
    CHECK_CUDA(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost));

    // Release the scratch buffers before reporting, so the error path does not leak them.
    CHECK_CUDA(cudaFree(d_work));
    CHECK_CUDA(cudaFree(d_info));

    if (h_info != 0) {
        throw std::runtime_error("trtri: failed to invert matrix");
    }
}
// Legacy (CUDA < 11.0) triangular-matrix inversion for single-precision complex data.
//
// cusolver_handle : cuSOLVER dense handle used for the workspace query and the inversion.
// uplo, diag      : characters translated via cublas_fill_mode() / cublas_diag_type().
// n, A, lda       : order, matrix pointer and leading dimension forwarded to cusolverDnCtrtri,
//                   which works on A in place. A is reinterpreted as cuComplex*.
// Throws std::runtime_error when the solver reports a non-zero info value.
static inline void trtri(cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& diag, const int& n, std::complex<float>* A, const int& lda)
{
    cuComplex* a_dev = reinterpret_cast<cuComplex*>(A);

    // Ask cuSOLVER how many workspace elements the inversion needs.
    int lwork = 0;
    CHECK_CUSOLVER(cusolverDnCtrtri_bufferSize(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, a_dev, lda, &lwork));

    cuComplex* d_work = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_work, static_cast<size_t>(lwork) * sizeof(cuComplex)));
    int* d_info = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_info, sizeof(int)));

    CHECK_CUSOLVER(cusolverDnCtrtri(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, a_dev, lda, d_work, lwork, d_info));

    // Read the solver status back to the host; without this a failed inversion
    // would go unnoticed (the CUDA >= 11 branch performs the same check).
    int h_info = 0;
    CHECK_CUDA(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost));

    // Release the scratch buffers before reporting, so the error path does not leak them.
    CHECK_CUDA(cudaFree(d_work));
    CHECK_CUDA(cudaFree(d_info));

    if (h_info != 0) {
        throw std::runtime_error("trtri: failed to invert matrix");
    }
}
// Legacy (CUDA < 11.0) triangular-matrix inversion for double-precision complex data.
//
// cusolver_handle : cuSOLVER dense handle used for the workspace query and the inversion.
// uplo, diag      : characters translated via cublas_fill_mode() / cublas_diag_type().
// n, A, lda       : order, matrix pointer and leading dimension forwarded to cusolverDnZtrtri,
//                   which works on A in place. A is reinterpreted as cuDoubleComplex*.
// Throws std::runtime_error when the solver reports a non-zero info value.
static inline void trtri(cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& diag, const int& n, std::complex<double>* A, const int& lda)
{
    cuDoubleComplex* a_dev = reinterpret_cast<cuDoubleComplex*>(A);

    // Ask cuSOLVER how many workspace elements the inversion needs.
    int lwork = 0;
    CHECK_CUSOLVER(cusolverDnZtrtri_bufferSize(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, a_dev, lda, &lwork));

    cuDoubleComplex* d_work = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_work, static_cast<size_t>(lwork) * sizeof(cuDoubleComplex)));
    int* d_info = nullptr;
    CHECK_CUDA(cudaMalloc((void**)&d_info, sizeof(int)));

    CHECK_CUSOLVER(cusolverDnZtrtri(cusolver_handle, cublas_fill_mode(uplo), cublas_diag_type(diag), n, a_dev, lda, d_work, lwork, d_info));

    // Read the solver status back to the host; without this a failed inversion
    // would go unnoticed (the CUDA >= 11 branch performs the same check).
    int h_info = 0;
    CHECK_CUDA(cudaMemcpy(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost));

    // Release the scratch buffers before reporting, so the error path does not leak them.
    CHECK_CUDA(cudaFree(d_work));
    CHECK_CUDA(cudaFree(d_info));

    if (h_info != 0) {
        throw std::runtime_error("trtri: failed to invert matrix");
    }
}
#endif

static inline
void potri (cusolverDnHandle_t& cusolver_handle, const char& uplo, const char& diag, const int& n, float * A, const int& lda)
Expand Down Expand Up @@ -1327,7 +1380,7 @@ static inline void geqrf(
cusolver_handle, m, n,
reinterpret_cast<cuComplex*>(d_A),
lda,
&lwork // ← 这里才是 lwork 的地址!
&lwork // ← correct: pass address of lwork
));

cuComplex* d_work = nullptr;
Expand All @@ -1342,7 +1395,7 @@ static inline void geqrf(
cusolver_handle, m, n,
reinterpret_cast<cuComplex*>(d_A),
lda,
reinterpret_cast<cuComplex*>(d_tau), // ← 这里才是 d_tau
reinterpret_cast<cuComplex*>(d_tau), // ← correct: d_tau
d_work, lwork, d_info));

int h_info = 0;
Expand Down
2 changes: 2 additions & 0 deletions source/source_base/module_device/device_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ static const char* _cusolverGetErrorString(cusolverStatus_t error)
return "CUSOLVER_STATUS_ZERO_PIVOT";
case CUSOLVER_STATUS_INVALID_LICENSE:
return "CUSOLVER_STATUS_INVALID_LICENSE";
#if CUDA_VERSION >= 11000
case CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED:
return "CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED";
case CUSOLVER_STATUS_IRS_PARAMS_INVALID:
Expand All @@ -93,6 +94,7 @@ static const char* _cusolverGetErrorString(cusolverStatus_t error)
return "CUSOLVER_STATUS_IRS_MATRIX_SINGULAR";
case CUSOLVER_STATUS_INVALID_WORKSPACE:
return "CUSOLVER_STATUS_INVALID_WORKSPACE";
#endif
default:
return "<unknown>";
}
Expand Down
27 changes: 21 additions & 6 deletions source/source_base/parallel_global.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,30 @@ void Parallel_Global::read_pal_param(int argc,
#ifdef __MPI
void Parallel_Global::finalize_mpi()
{
MPI_Comm_free(&POOL_WORLD);
if (KP_WORLD != MPI_COMM_NULL)
if (POOL_WORLD != MPI_COMM_NULL && POOL_WORLD != MPI_COMM_WORLD)
{
MPI_Comm_free(&POOL_WORLD);
}
if (KP_WORLD != MPI_COMM_NULL && KP_WORLD != MPI_COMM_WORLD)
{
MPI_Comm_free(&KP_WORLD);
}
MPI_Comm_free(&INT_BGROUP);
MPI_Comm_free(&BP_WORLD);
MPI_Comm_free(&GRID_WORLD);
MPI_Comm_free(&DIAG_WORLD);
if (INT_BGROUP != MPI_COMM_NULL && INT_BGROUP != MPI_COMM_WORLD)
{
MPI_Comm_free(&INT_BGROUP);
}
if (BP_WORLD != MPI_COMM_NULL && BP_WORLD != MPI_COMM_WORLD)
{
MPI_Comm_free(&BP_WORLD);
}
if (GRID_WORLD != MPI_COMM_NULL && GRID_WORLD != MPI_COMM_WORLD)
{
MPI_Comm_free(&GRID_WORLD);
}
if (DIAG_WORLD != MPI_COMM_NULL && DIAG_WORLD != MPI_COMM_WORLD)
{
MPI_Comm_free(&DIAG_WORLD);
}
MPI_Finalize();
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion source/source_base/timer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void timer::finish(std::ofstream &ofs, const bool print_flag, const bool check_e
//----------------------------------------------------------
void timer::start()
{
// first init ,then we can use tick
// first init, then we can use start/end
if(timer_pool[""]["total"].start_flag)
{ timer::start("","total"); }
}
Expand Down
3 changes: 3 additions & 0 deletions source/source_cell/read_atoms_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ bool parse_atom_properties(std::ifstream& ifpos,
atom.lambda[ia].x /= ModuleBase::Ry_to_eV;
atom.lambda[ia].y /= ModuleBase::Ry_to_eV;
atom.lambda[ia].z /= ModuleBase::Ry_to_eV;
std::cout << "[DS-DIAG] STRU parse: lambda[" << ia << "]=("
<< atom.lambda[ia].x << ", " << atom.lambda[ia].y << ", "
<< atom.lambda[ia].z << ") Ry/uB (converted from eV/uB)" << std::endl;
}
else if ( tmpid == "sc")
{
Expand Down
6 changes: 3 additions & 3 deletions source/source_cell/test/read_sep_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST_F(ReadSepTest, PrintSep)
if (GlobalV::MY_RANK == 0)
{
#endif
// 设置测试数据
// Set up test data
read_sep->label = "F";
read_sep->xc_type = "pbe";
read_sep->orbital = "p";
Expand All @@ -78,13 +78,13 @@ TEST_F(ReadSepTest, PrintSep)
read_sep->r = new double[2]{0.1, 0.2};
read_sep->rv = new double[2]{1.0, 2.0};

// 测试打印功能
// Test print functionality
std::ofstream ofs("test_sep.out");
read_sep->print_sep_info(ofs);
read_sep->print_sep_vsep(ofs);
ofs.close();

// 验证输出文件
// Verify output file
std::ifstream ifs("test_sep.out");
std::string line;
std::vector<std::string> lines;
Expand Down
28 changes: 27 additions & 1 deletion source/source_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,27 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2rho_single(UnitCell& ucell, int istep, int
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

// 2) run the inner lambda loop to constrain atomic moments with the DeltaSpin method
bool skip_solve = run_deltaspin_lambda_loop_lcao<TK>(iter - 1, this->drho, PARAM.inp);
bool skip_solve = false;
if (PARAM.inp.sc_mag_switch)
{
spinconstrain::SpinConstrain<TK>& sc = spinconstrain::SpinConstrain<TK>::getScInstance();
if (PARAM.inp.sc_lambda_strategy == "linear_scan")
{
sc.run_lambda_linear_scan(iter - 1);
skip_solve = true;
}
else if (!sc.mag_converged() && this->drho > 0 && this->drho < PARAM.inp.sc_scf_thr)
{
sc.run_lambda_loop(iter - 1);
sc.set_mag_converged(true);
skip_solve = true;
}
else if (sc.mag_converged())
{
sc.run_lambda_loop(iter - 1);
skip_solve = true;
}
}

// 3) run Hsolver
if (!skip_solve)
Expand All @@ -407,6 +427,12 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2rho_single(UnitCell& ucell, int istep, int
hsolver_lcao_obj.solve(static_cast<hamilt::Hamilt<TK>*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm,
this->chr, PARAM.inp.nspin, skip_charge);
}
else
{
// Lambda loop updated the density matrix (DM) but not the real-space charge density.
// HSolver was skipped, so we need to sync rho from DM manually.
LCAO_domain::dm2rho(this->dmat.dm->get_DMR_vector(), PARAM.inp.nspin, &this->chr);
}

// 4) EXX
#ifdef __EXX
Expand Down
2 changes: 1 addition & 1 deletion source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const

// update local occupations for DFT+U
// this should be done before the lambda loop in DeltaSpin
pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.template get_psi_t<T, Device>(), this->pelec->wg, ucell, PARAM.inp);
pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.template get_psi_t<T, Device>(), this->pelec->wg, ucell, this->p_chgmix, this->kv.isk.data());
}

// Temporary, it should be replaced by hsolver later.
Expand Down
4 changes: 2 additions & 2 deletions source/source_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ void ESolver_SDFT_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, int istep, i
this->p_hamilt_sto,
PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.ks_solver,
PARAM.globalv.use_uspp,
PARAM.inp.ks_solver,
PARAM.globalv.use_uspp,
PARAM.inp.nspin,
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
Expand Down
1 change: 1 addition & 0 deletions source/source_esolver/lcao_others.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
PARAM.inp.sccut,
PARAM.inp.sc_drop_thr,
ucell,
PARAM.inp.sc_direction_only,
&(this->pv),
PARAM.inp.nspin,
this->kv,
Expand Down
7 changes: 7 additions & 0 deletions source/source_esolver/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
remove_definitions(-D__MPI)
remove_definitions(-D__LCAO)
remove_definitions(-D__CUDA)

install(DIRECTORY support DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

Expand All @@ -8,3 +9,9 @@ AddTest(
LIBS parameter ${math_libs} base device
SOURCES esolver_dp_test.cpp ../esolver_dp.cpp ../../source_io/module_output/cif_io.cpp ../../source_io/module_output/output_log.cpp
)

AddTest(
TARGET MODULE_ESOLVER_nscf_utils_test
LIBS parameter ${math_libs} base device
SOURCES nscf_utils_test.cpp
)
Loading
Loading