Skip to content

Commit

Permalink
Update before_all_runners in ESolver (deepmodeling#4334)
Browse files Browse the repository at this point in the history
* add explanations in esolver_ks_lcao

* add comments in esolver_ks_lcao

* add notes in esolver

* organize the functions in esolver_ks

* add notes in esolver_ks_lcao

* update some function names in ESolver, change init() to before_runner(), change run() to runner(), and change post_process() to after_all_runners()

* update esolver

* fix tests in md

* delete some GlobalC in esolver_ks_pw.cpp

* delete some GlobalC::ppcell

* update esolver

* update wavefunc, this class is messy.

* update small places

* get clear of the bfore_all_runners subroutines

* fix fixed_weights in elecstate tests

* fix hsolver with fixed_weights

* fix a bug left by previous commit...
  • Loading branch information
mohanchen authored Jun 9, 2024
1 parent ed1eaf9 commit 332c875
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 138 deletions.
25 changes: 18 additions & 7 deletions source/module_elecstate/elecstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,49 @@ const double* ElecState::getRho(int spin) const
return &(this->charge->rho[spin][0]);
}

void ElecState::fixed_weights(const std::vector<double>& ocp_kb)
void ElecState::fixed_weights(
const std::vector<double>& ocp_kb,
const int &nbands,
const double &nelec)
{

int num = 0;
num = this->klist->get_nks() * GlobalV::NBANDS;
assert(nbands>0);
assert(nelec>0.0);

const double ne_thr = 1.0e-5;

const int num = this->klist->get_nks() * nbands;
if (num != ocp_kb.size())
{
ModuleBase::WARNING_QUIT("ElecState::fixed_weights",
"size of occupation array is wrong , please check ocp_set");
}

double num_elec = 0.0;
for (int i = 0; i < ocp_kb.size(); i++)
for (int i = 0; i < ocp_kb.size(); ++i)
{
num_elec += ocp_kb[i];
}
if (std::abs(num_elec - GlobalV::nelec) > 1.0e-5)

if (std::abs(num_elec - nelec) > ne_thr)
{
ModuleBase::WARNING_QUIT("ElecState::fixed_weights",
"total number of occupations is wrong , please check ocp_set");
}

for (int ik = 0; ik < this->wg.nr; ik++)
for (int ik = 0; ik < this->wg.nr; ++ik)
{
for (int ib = 0; ib < this->wg.nc; ib++)
for (int ib = 0; ib < this->wg.nc; ++ib)
{
this->wg(ik, ib) = ocp_kb[ik * this->wg.nc + ib];
}
}
this->skip_weights = true;

return;
}


void ElecState::init_nelec_spin()
{
this->nelec_spin.resize(GlobalV::NSPIN);
Expand Down
8 changes: 7 additions & 1 deletion source/module_elecstate/elecstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,14 @@ class ElecState

// calculate wg from ekb
virtual void calculate_weights();

// use occupied weights from INPUT and skip calculate_weights
void fixed_weights(const std::vector<double>& ocp_kb);
// mohan updated on 2024-06-08
void fixed_weights(
const std::vector<double>& ocp_kb,
const int &nbands,
const double &nelec);

// if nupdown is not 0(TWO_EFERMI case),
// nelec_spin will be fixed and weights will be constrained
void init_nelec_spin();
Expand Down
8 changes: 4 additions & 4 deletions source/module_elecstate/test/elecstate_base_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ TEST_F(ElecStateTest,FixedWeights)
{
ocp_kb[i] = 1.0;
}
elecstate->fixed_weights(ocp_kb);
elecstate->fixed_weights(ocp_kb, GlobalV::NBANDS, GlobalV::nelec);
EXPECT_EQ(elecstate->wg(0, 0), 1.0);
EXPECT_EQ(elecstate->wg(klist->get_nks()-1, GlobalV::NBANDS-1), 1.0);
EXPECT_TRUE(elecstate->skip_weights);
Expand All @@ -428,7 +428,7 @@ TEST_F(ElecStateDeathTest,FixedWeightsWarning1)
ocp_kb[i] = 1.0;
}
testing::internal::CaptureStdout();
EXPECT_EXIT(elecstate->fixed_weights(ocp_kb), ::testing::ExitedWithCode(0), "");
EXPECT_EXIT(elecstate->fixed_weights(ocp_kb, GlobalV::NBANDS, GlobalV::nelec), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("size of occupation array is wrong , please check ocp_set"));
}
Expand All @@ -448,7 +448,7 @@ TEST_F(ElecStateDeathTest,FixedWeightsWarning2)
ocp_kb[i] = 1.0;
}
testing::internal::CaptureStdout();
EXPECT_EXIT(elecstate->fixed_weights(ocp_kb), ::testing::ExitedWithCode(0), "");
EXPECT_EXIT(elecstate->fixed_weights(ocp_kb, GlobalV::NBANDS, GlobalV::nelec), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("total number of occupations is wrong , please check ocp_set"));
}
Expand Down Expand Up @@ -703,4 +703,4 @@ TEST_F(ElecStateTest, CalculateWeightsGWeightsTwoFermi)


#undef protected
#undef private
#undef private
6 changes: 5 additions & 1 deletion source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ void ESolver_FP::before_all_runners(Input& inp, UnitCell& cell)
{
ModuleBase::TITLE("ESolver_FP", "before_all_runners");

//! 1) read pseudopotentials
if(!GlobalV::use_paw)
{
cell.read_pseudo(GlobalV::ofs_running);
}

//! 2) initialie the plane wave basis for rho
#ifdef __MPI
this->pw_rho->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
#endif
Expand All @@ -57,7 +59,6 @@ void ESolver_FP::before_all_runners(Input& inp, UnitCell& cell)
this->pw_rho->setfullpw(inp.of_full_pw, inp.of_full_pw_dim);
}

//! Initalize the plane wave basis set
if (inp.nx * inp.ny * inp.nz == 0)
{
this->pw_rho->initgrids(inp.ref_cell_factor * cell.lat0, cell.latvec, 4.0 * inp.ecutwfc);
Expand All @@ -66,12 +67,14 @@ void ESolver_FP::before_all_runners(Input& inp, UnitCell& cell)
{
this->pw_rho->initgrids(inp.ref_cell_factor * cell.lat0, cell.latvec, inp.nx, inp.ny, inp.nz);
}

this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc);
this->pw_rho->ft.fft_mode = inp.fft_mode;
this->pw_rho->setuptransform();
this->pw_rho->collect_local_pw();
this->pw_rho->collect_uniqgg();

//! 3) initialize the double grid (for uspp) if necessary
if (GlobalV::double_grid)
{
ModulePW::PW_Basis_Sup* pw_rhod_sup = static_cast<ModulePW::PW_Basis_Sup*>(pw_rhod);
Expand All @@ -97,6 +100,7 @@ void ESolver_FP::before_all_runners(Input& inp, UnitCell& cell)
this->pw_rhod->collect_uniqgg();
}

//! 4) print some information
this->print_rhofft(inp, GlobalV::ofs_running);

return;
Expand Down
30 changes: 16 additions & 14 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
{
ModuleBase::TITLE("ESolver_KS", "before_all_runners");

//! 1) initialize "before_all_runniers" in ESolver_FP
ESolver_FP::before_all_runners(inp,ucell);

//------------------Charge Mixing------------------
//! 2) setup the charge mixing parameters
p_chgmix->set_mixing(GlobalV::MIXING_MODE,
GlobalV::MIXING_BETA,
GlobalV::MIXING_NDIM,
Expand All @@ -108,7 +109,6 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
GlobalV::MIXING_DMR);

/// PAW Section

#ifdef USE_PAW
if(GlobalV::use_paw)
{
Expand Down Expand Up @@ -180,11 +180,12 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
#endif
/// End PAW

//! 3) calculate the electron number
ucell.cal_nelec(GlobalV::nelec);

/* it has been established that
xc_func is same for all elements, therefore
only the first one if used*/
//! 4) it has been established that
// xc_func is same for all elements, therefore
// only the first one if used
if(GlobalV::use_paw)
{
XC_Functional::set_xc_type(GlobalV::DFT_FUNCTIONAL);
Expand All @@ -195,23 +196,22 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
}
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SETUP UNITCELL");

// ESolver depends on the Symmetry module
//! 5) ESolver depends on the Symmetry module
// symmetry analysis should be performed every time the cell is changed
if (ModuleSymmetry::Symmetry::symm_flag == 1)
{
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
}

// Setup the k points according to symmetry.
//! 6) Setup the k points according to symmetry.
this->kv.set(ucell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, ucell.G, ucell.latvec);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");

// print information
// mohan add 2021-01-30
//! 7) print information
Print_Info::setup_parameters(ucell, this->kv);

//new plane wave basis
//! 8) new plane wave basis, fft grids, etc.
#ifdef __MPI
this->pw_wfc->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
#endif
Expand All @@ -221,6 +221,7 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
this->pw_rho->nx,
this->pw_rho->ny,
this->pw_rho->nz);

this->pw_wfc->initparameters(false, inp.ecutwfc, this->kv.get_nks(), this->kv.kvec_d.data());

// the MPI allreduce should not be here, mohan 2024-05-12
Expand All @@ -236,6 +237,7 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)

this->pw_wfc->setuptransform();

//! 9) initialize the number of plane waves for each k point
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
{
this->kv.ngk[ik] = this->pw_wfc->npwk[ik];
Expand All @@ -245,20 +247,20 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)

this->print_wfcfft(inp, GlobalV::ofs_running);

//! initialize the real-space uniform grid for FFT and parallel
//! 10) initialize the real-space uniform grid for FFT and parallel
//! distribution of plane waves
GlobalC::Pgrid.init(this->pw_rhod->nx,
this->pw_rhod->ny,
this->pw_rhod->nz,
this->pw_rhod->nplane,
this->pw_rhod->nrxx,
pw_big->nbz,
pw_big->bz); // mohan add 2010-07-22, update 2011-05-04
pw_big->bz);

// Calculate Structure factor
//! 11) calculate the structure factor
this->sf.setup_structure_factor(&ucell, this->pw_rhod);

// Initialize charge extrapolation
//! 12) initialize the charge extrapolation method if necessary
CE.Init_CE(ucell.nat);

#ifdef USE_PAW
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ESolver_KS : public ESolver_FP

protected:
//! Something to do before SCF iterations.
virtual void before_scf(int istep) {};
virtual void before_scf(const int istep) {};

//! Something to do before hamilt2density function in each iter loop.
virtual void iter_init(const int istep, const int iter) {};
Expand Down
23 changes: 9 additions & 14 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,11 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
this->LOWF.ParaV = &(this->orb_con.ParaV);
this->LM.ParaV = &(this->orb_con.ParaV);

// 5) initialize Density Matrix
// 5) initialize density matrix
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->orb_con.ParaV), GlobalV::NSPIN);


// this function should be removed outside of the function
if (GlobalV::CALCULATION == "get_S")
{
ModuleBase::timer::tick("ESolver_KS_LCAO", "init");
Expand Down Expand Up @@ -235,30 +236,27 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
}
#endif

// 8) Quxin added for DFT+U
// 8) initialize DFT+U
if (GlobalV::dft_plus_u)
{
GlobalC::dftu.init(ucell, this->LM, this->kv.get_nks());
}

// 9) ppcell
// output is GlobalC::ppcell.vloc 3D local pseudopotentials
// without structure factors
// this function belongs to cell LOOP
// 9) initialize ppcell
GlobalC::ppcell.init_vloc(GlobalC::ppcell.vloc, this->pw_rho);

// 10) init HSolver
// 10) initialize the HSolver
if (this->phsol == nullptr)
{
this->phsol = new hsolver::HSolverLCAO<TK>(&(this->orb_con.ParaV));
this->phsol->method = GlobalV::KS_SOLVER;
}

// 11) inititlize the charge density.
// 11) inititlize the charge density
this->pelec->charge->allocate(GlobalV::NSPIN);
this->pelec->omega = GlobalC::ucell.omega;

// 12) initialize the potential.
// 12) initialize the potential
if (this->pelec->pot == nullptr)
{
this->pelec->pot = new elecstate::Potential(this->pw_rhod,
Expand All @@ -272,20 +270,17 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)

#ifdef __DEEPKS
// 13) initialize deepks
// wenfei 2021-12-19
// if we are performing DeePKS calculations, we need to load a model
if (GlobalV::deepks_scf)
{
// load the DeePKS model from deep neural network
GlobalC::ld.load_model(INPUT.deepks_model);
}
#endif

// 14) set occupations?
// Fix this->pelec->wg by ocp_kb
// 14) set occupations
if (GlobalV::ocp)
{
this->pelec->fixed_weights(GlobalV::ocp_kb);
this->pelec->fixed_weights(GlobalV::ocp_kb, GlobalV::NBANDS, GlobalV::nelec);
}

ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners");
Expand Down
Loading

0 comments on commit 332c875

Please sign in to comment.