Skip to content

Commit

Permalink
Merge branch 'develop' into esolver
Browse files Browse the repository at this point in the history
  • Loading branch information
mohanchen committed Jul 27, 2024
2 parents 417c865 + fdd8837 commit 2e42278
Show file tree
Hide file tree
Showing 16 changed files with 421 additions and 64 deletions.
12 changes: 12 additions & 0 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,12 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep,
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,

hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::need_subspace,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,

false);


Expand Down Expand Up @@ -1080,6 +1086,12 @@ void ESolver_KS_PW<T, Device>::hamilt2estates(const double ethr) {
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,

hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::need_subspace,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,

true);
} else {
ModuleBase::WARNING_QUIT("ESolver_KS_PW",
Expand Down
16 changes: 15 additions & 1 deletion source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,21 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)

hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX;

this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec, pw_wfc, this->stowf, istep, iter, GlobalV::KS_SOLVER);
this->phsol->solve(this->p_hamilt,
this->psi[0],
this->pelec,
pw_wfc,
this->stowf,
istep,
iter,
GlobalV::KS_SOLVER,

hsolver::DiagoIterAssist<std::complex<double>>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR,

false);

if (GlobalV::MY_STOGROUP == 0)
{
Expand Down
31 changes: 15 additions & 16 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
// lagrange_matrix(nband, nband); // for orthogonalization
resmem_complex_op()(this->ctx, this->lagrange_matrix, nband * nband);
setmem_complex_op()(this->ctx, this->lagrange_matrix, 0, nband * nband);

#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, this->d_precondition, dim);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, dim);
}
#endif
}

/**
Expand All @@ -130,6 +138,13 @@ DiagoDavid<T, Device>::~DiagoDavid()
delmem_complex_op()(this->ctx, this->vcc);
delmem_complex_op()(this->ctx, this->lagrange_matrix);
base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>()(this->cpu_ctx, this->eigenvalue);
// If the device is a GPU device, free the d_precondition array.
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->d_precondition);
}
#endif
}

template <typename T, typename Device>
Expand Down Expand Up @@ -1135,14 +1150,6 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
int ntry = 0;
this->notconv = 0;

#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, this->d_precondition, ldPsi);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, ldPsi);
}
#endif

int sum_dav_iter = 0;
do
{
Expand All @@ -1155,14 +1162,6 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
std::cout << "\n notconv = " << this->notconv;
std::cout << "\n DiagoDavid::diag', too many bands are not converged! \n";
}
// If the device is a GPU device, free the d_precondition array.
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->d_precondition);
}
#endif

return sum_dav_iter;
}

Expand Down
15 changes: 13 additions & 2 deletions source/module_hsolver/hsolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ class HSolver
const int rank_in_pool_in,
const int nproc_in_pool_in,

const bool skip_charge = false)
const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge)
{
return;
}
Expand All @@ -85,7 +90,13 @@ class HSolver
const int istep,
const int iter,
const std::string method,
const bool skip_charge = false)

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge)
{
return;
}
Expand Down
32 changes: 21 additions & 11 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
const int rank_in_pool_in,
const int nproc_in_pool_in,

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge)
{
ModuleBase::TITLE("HSolverPW", "solve");
Expand All @@ -271,6 +276,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
this->rank_in_pool = rank_in_pool_in;
this->nproc_in_pool = nproc_in_pool_in;

this->scf_iter = scf_iter_in;
this->need_subspace = need_subspace_in;
this->diag_iter_max = diag_iter_max_in;
this->pw_diag_thr = pw_diag_thr_in;

// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
Expand All @@ -286,7 +296,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
{
this->set_isOccupied(is_occupied,
pes,
DiagoIterAssist<T, Device>::SCF_ITER,
this->scf_iter,
psi.get_nk(),
psi.get_nbands(),
this->diago_full_acc);
Expand Down Expand Up @@ -318,7 +328,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
{
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR
<< " ; where current threshold is: " << this->pw_diag_thr
<< " . " << std::endl;
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
Expand Down Expand Up @@ -409,10 +419,10 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
};
DiagoCG<T, Device> cg(this->basis_type,
this->calculation_type,
DiagoIterAssist<T, Device>::need_subspace,
this->need_subspace,
subspace_func,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
this->pw_diag_thr,
this->diag_iter_max,
this->nproc_in_pool);

// warp the hpsi_func and spsi_func into a lambda function
Expand Down Expand Up @@ -521,9 +531,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
GlobalV::PW_DIAG_NDIM,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
this->pw_diag_thr,
this->diag_iter_max,
this->need_subspace,
comm_info);

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
Expand All @@ -539,9 +549,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
/// allow 5 eigenvecs to be NOT converged.
const int notconv_max = ("nscf" == this->calculation_type) ? 0 : 5;
/// convergence threshold
const Real david_diag_thr = DiagoIterAssist<T, Device>::PW_DIAG_THR;
const Real david_diag_thr = this->pw_diag_thr;
/// maximum iterations
const int david_maxiter = DiagoIterAssist<T, Device>::PW_DIAG_NMAX;
const int david_maxiter = this->diag_iter_max;

// dimensions of matrix to be solved
const int dim = psi.get_current_nbas(); /// dimension of matrix
Expand Down Expand Up @@ -660,7 +670,7 @@ void HSolverPW<T, Device>::output_iterInfo()
{
GlobalV::ofs_running << "Average iterative diagonalization steps: "
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR << " . "
<< " ; where current threshold is: " << this->pw_diag_thr << " . "
<< std::endl;
// reset avg_iter
DiagoIterAssist<T, Device>::avg_iter = 0.0;
Expand Down
10 changes: 10 additions & 0 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class HSolverPW : public HSolver<T, Device>
const int rank_in_pool_in,
const int nproc_in_pool_in,

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge) override;

virtual Real cal_hsolerror(const Real diag_ethr_in) override;
Expand Down Expand Up @@ -78,6 +83,11 @@ class HSolverPW : public HSolver<T, Device>

wavefunc* pwf = nullptr;

int scf_iter = 1; // Start from 1
bool need_subspace = false;
int diag_iter_max = 50;
double pw_diag_thr = 1.0e-2;

private:
Device* ctx = {};

Expand Down
15 changes: 14 additions & 1 deletion source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,26 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt<std::complex<double>>* pHamilt,
const int istep,
const int iter,
const std::string method_in,
const bool skip_charge) {

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge)
{
ModuleBase::TITLE(this->classname, "solve");
ModuleBase::timer::tick(this->classname, "solve");

const int npwx = psi.get_nbasis();
const int nbands = psi.get_nbands();
const int nks = psi.get_nk();

this->scf_iter = scf_iter_in;
this->need_subspace = need_subspace_in;
this->diag_iter_max = diag_iter_max_in;
this->pw_diag_thr = pw_diag_thr_in;

// prepare for the precondition of diagonalization
std::vector<double> precondition(psi.get_nbasis(), 0.0);

Expand Down
6 changes: 6 additions & 0 deletions source/module_hsolver/hsolver_pw_sdft.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ namespace hsolver
const int istep,
const int iter,
const std::string method_in,

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge) override;

virtual double set_diagethr(double diag_ethr_in, const int istep, const int iter, const double drho) override;
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/test/test_hsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ TEST_F(TestHSolver, solve)
hs_f.solve(&hamilt_test_f, psi_test_f, &elecstate_test, method_test, true);
hs_cd.solve(&hamilt_test_cd, psi_test_cd, &elecstate_test, method_test, true);
hs_d.solve(&hamilt_test_d, psi_test_d, &elecstate_test, method_test, true);
hs_cf.solve(&hamilt_test_cf, psi_test_cf, &elecstate_test, wfcpw, stowf_test, 0, 0, method_test, true);
hs_cd.solve(&hamilt_test_cd, psi_test_cd, &elecstate_test, wfcpw, stowf_test, 0, 0, method_test, true);
// hs_cf.solve(&hamilt_test_cf, psi_test_cf, &elecstate_test, wfcpw, stowf_test, 0, 0, method_test, true);
// hs_cd.solve(&hamilt_test_cd, psi_test_cd, &elecstate_test, wfcpw, stowf_test, 0, 0, method_test, true);
EXPECT_EQ(hs_f.classname, "none");
EXPECT_EQ(hs_d.classname, "none");
EXPECT_EQ(hs_f.method, "none");
Expand Down
12 changes: 12 additions & 0 deletions source/module_hsolver/test/test_hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ TEST_F(TestHSolverPW, solve) {
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,

hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,

true);
EXPECT_EQ(this->hs_f.initialed_psi, true);
for (int i = 0; i < psi_test_cf.size(); i++) {
Expand All @@ -107,6 +113,12 @@ TEST_F(TestHSolverPW, solve) {
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,

hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,

true);
EXPECT_EQ(this->hs_d.initialed_psi, true);
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<double>>::avg_iter,
Expand Down
Loading

0 comments on commit 2e42278

Please sign in to comment.