Skip to content

Commit

Permalink
Refactor: add XC_Functional::get_has_kedf() (#5879)
Browse files Browse the repository at this point in the history
* Refactor: remove elecstate_getters

* Refactor: add XC_Functional::get_has_kedf()

* Refactor: update unit tests

* [pre-commit.ci lite] apply automatic fixes

* Refactor: rename has_kedf to ked_flag

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
YuLiu98 and pre-commit-ci-lite[bot] authored Jan 24, 2025
1 parent 602ebe5 commit 46797ad
Show file tree
Hide file tree
Showing 42 changed files with 234 additions and 287 deletions.
1 change: 0 additions & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,6 @@ OBJS_SRCPW=H_Ewald_pw.o\
vnl_op.o\
global.o\
magnetism.o\
elecstate_getters.o\
occupy.o\
structure_factor.o\
structure_factor_k.o\
Expand Down
1 change: 0 additions & 1 deletion source/module_elecstate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
list(APPEND objects
elecstate.cpp
elecstate_getters.cpp
elecstate_energy_terms.cpp
elecstate_energy.cpp
elecstate_exx.cpp
Expand Down
12 changes: 6 additions & 6 deletions source/module_elecstate/elecstate_energy.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "elecstate.h"
#include "elecstate_getters.h"
#include "module_base/global_variable.h"
#include "module_base/parallel_reduce.h"
#include "module_hamilt_general/module_xc/xc_functional.h"
#include "module_parameter/parameter.h"

#include <cmath>
Expand Down Expand Up @@ -103,7 +103,7 @@ double ElecState::cal_delta_eband(const UnitCell& ucell) const
const double* v_eff = this->pot->get_effective_v(0);
const double* v_fixed = this->pot->get_fixed_v();
const double* v_ofk = nullptr;
const bool v_ofk_flag = (get_xc_func_type() == 3 || get_xc_func_type() == 5);
const bool v_ofk_flag = (XC_Functional::get_ked_flag());
#ifdef USE_PAW
if (PARAM.inp.use_paw)
{
Expand Down Expand Up @@ -208,14 +208,14 @@ double ElecState::cal_delta_escf() const
const double* v_fixed = this->pot->get_fixed_v();
const double* v_ofk = nullptr;

if (get_xc_func_type() == 3 || get_xc_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
v_ofk = this->pot->get_effective_vofk(0);
}
for (int ir = 0; ir < this->charge->rhopw->nrxx; ir++)
{
descf -= (this->charge->rho[0][ir] - this->charge->rho_save[0][ir]) * (v_eff[ir] - v_fixed[ir]);
if (get_xc_func_type() == 3 || get_xc_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
// cause in the get_effective_vofk, the func will return nullptr
assert(v_ofk != nullptr);
Expand All @@ -226,14 +226,14 @@ double ElecState::cal_delta_escf() const
if (PARAM.inp.nspin == 2)
{
v_eff = this->pot->get_effective_v(1);
if (get_xc_func_type() == 3 || get_xc_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
v_ofk = this->pot->get_effective_vofk(1);
}
for (int ir = 0; ir < this->charge->rhopw->nrxx; ir++)
{
descf -= (this->charge->rho[1][ir] - this->charge->rho_save[1][ir]) * (v_eff[ir] - v_fixed[ir]);
if (get_xc_func_type() == 3 || get_xc_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
descf -= (this->charge->kin_r[1][ir] - this->charge->kin_r_save[1][ir]) * v_ofk[ir];
}
Expand Down
17 changes: 0 additions & 17 deletions source/module_elecstate/elecstate_getters.cpp

This file was deleted.

16 changes: 0 additions & 16 deletions source/module_elecstate/elecstate_getters.h

This file was deleted.

6 changes: 3 additions & 3 deletions source/module_elecstate/elecstate_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void ElecStateLCAO<std::complex<double>>::psiToRho(const psi::Psi<std::complex<d
Gint_inout inout(this->charge->rho, Gint_Tools::job_type::rho, PARAM.inp.nspin);
this->gint_k->cal_gint(&inout);

if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
elecstate::lcao_cal_tau_k(gint_k, this->charge);
}
Expand Down Expand Up @@ -98,7 +98,7 @@ void ElecStateLCAO<double>::psiToRho(const psi::Psi<double>& psi)

this->gint_gamma->cal_gint(&inout);

if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
elecstate::lcao_cal_tau_gamma(gint_gamma, this->charge);
}
Expand Down Expand Up @@ -161,7 +161,7 @@ void ElecStateLCAO<double>::dmToRho(std::vector<double*> pexsi_DM, std::vector<d
this->gint_gamma->transfer_DM2DtoGrid(this->DM->get_DMR_vector()); // transfer DM2D to DM_grid in gint
Gint_inout inout(this->charge->rho, Gint_Tools::job_type::rho, PARAM.inp.nspin);
this->gint_gamma->cal_gint(&inout);
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
for (int is = 0; is < PARAM.inp.nspin; is++)
{
Expand Down
3 changes: 1 addition & 2 deletions source/module_elecstate/elecstate_print.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "elecstate.h"
#include "elecstate_getters.h"
#include "module_base/formatter.h"
#include "module_base/global_variable.h"
#include "module_base/parallel_common.h"
Expand Down Expand Up @@ -461,7 +460,7 @@ void ElecState::print_etot(const Magnetism& magnet,
break;
}
std::vector<double> drho = {scf_thr};
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
drho.push_back(scf_thr_kin);
}
Expand Down
19 changes: 10 additions & 9 deletions source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#include "elecstate_pw.h"
#include "module_parameter/parameter.h"
#include "elecstate_getters.h"

#include "module_base/constants.h"
#include "module_base/libm/libm.h"
#include "module_base/math_ylmreal.h"
#include "module_base/module_device/device.h"
#include "module_base/parallel_reduce.h"
#include "module_base/timer.h"
#include "module_base/module_device/device.h"
#include "module_hamilt_general/module_xc/xc_functional.h"
#include "module_parameter/parameter.h"

namespace elecstate {

Expand Down Expand Up @@ -41,7 +42,7 @@ ElecStatePW<T, Device>::~ElecStatePW()
delmem_complex_op()(this->rhog_data);
delete[] this->rhog;
}
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
{
delmem_var_op()(this->kin_r_data);
delete[] this->kin_r;
Expand Down Expand Up @@ -80,7 +81,7 @@ void ElecStatePW<T, Device>::init_rho_data()
this->rhog[ii] = this->rhog_data + ii * this->charge->rhopw->npw;
}
}
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
{
this->kin_r = new Real*[this->charge->nspin];
resmem_var_op()(this->kin_r_data, this->charge->nspin * this->charge->nrxx);
Expand All @@ -96,7 +97,7 @@ void ElecStatePW<T, Device>::init_rho_data()
{
this->rhog = reinterpret_cast<T**>(this->charge->rhog);
}
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
{
this->kin_r = reinterpret_cast<Real **>(this->charge->kin_r);
}
Expand All @@ -119,7 +120,7 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
// denghui replaced at 20221110
// ModuleBase::GlobalFunc::ZEROS(this->rho[is], this->charge->nrxx);
setmem_var_op()(this->rho[is], 0, this->charge->nrxx);
if (get_xc_func_type() == 3)
if (XC_Functional::get_func_type() == 3)
{
// ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
setmem_var_op()(this->kin_r[is], 0, this->charge->nrxx);
Expand All @@ -143,7 +144,7 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
for (int ii = 0; ii < PARAM.inp.nspin; ii++)
{
castmem_var_d2h_op()(this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
if (get_xc_func_type() == 3)
if (XC_Functional::get_func_type() == 3)
{
castmem_var_d2h_op()(this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx);
}
Expand Down Expand Up @@ -240,7 +241,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
}

// kinetic energy density
if (get_xc_func_type() == 3)
if (XC_Functional::get_func_type() == 3)
{
for (int j = 0; j < 3; j++)
{
Expand Down
1 change: 0 additions & 1 deletion source/module_elecstate/elecstate_pw_cal_tau.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "elecstate_pw.h"
#include "elecstate_getters.h"

namespace elecstate {

Expand Down
4 changes: 2 additions & 2 deletions source/module_elecstate/magnetism.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "magnetism.h"
#include "elecstate_getters.h"
#include "module_parameter/parameter.h"

#include "module_base/parallel_reduce.h"
#include "module_parameter/parameter.h"

Magnetism::Magnetism()
{
Expand Down
24 changes: 13 additions & 11 deletions source/module_elecstate/module_charge/charge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@
// even in a LSDA calculation.
//----------------------------------------------------------
#include "charge.h"
#include <vector>
#include "module_parameter/parameter.h"

#include "module_base/global_function.h"
#include "module_base/global_variable.h"
#include "module_parameter/parameter.h"
#include "module_base/libm/libm.h"
#include "module_base/math_integral.h"
#include "module_base/memory.h"
#include "module_base/parallel_reduce.h"
#include "module_base/timer.h"
#include "module_base/tool_threading.h"
#include "module_cell/unitcell.h"
#include "module_elecstate/elecstate_getters.h"
#include "module_elecstate/magnetism.h"
#include "module_hamilt_general/module_xc/xc_functional.h"
#include "module_parameter/parameter.h"

#include <vector>

#ifdef USE_PAW
#include "module_cell/module_paw/paw_cell.h"
Expand Down Expand Up @@ -80,7 +81,7 @@ void Charge::destroy()
delete[] _space_rhog_save;
delete[] _space_kin_r;
delete[] _space_kin_r_save;
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0)
{
delete[] kin_r;
delete[] kin_r_save;
Expand Down Expand Up @@ -121,7 +122,7 @@ void Charge::allocate(const int& nspin_in)
_space_rho_save = new double[nspin * nrxx];
_space_rhog = new std::complex<double>[nspin * ngmc];
_space_rhog_save = new std::complex<double>[nspin * ngmc];
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0)
{
_space_kin_r = new double[nspin * nrxx];
_space_kin_r_save = new double[nspin * nrxx];
Expand All @@ -130,7 +131,7 @@ void Charge::allocate(const int& nspin_in)
rhog = new std::complex<double>*[nspin];
rho_save = new double*[nspin];
rhog_save = new std::complex<double>*[nspin];
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0)
{
kin_r = new double*[nspin];
kin_r_save = new double*[nspin];
Expand All @@ -151,7 +152,7 @@ void Charge::allocate(const int& nspin_in)
ModuleBase::GlobalFunc::ZEROS(rhog[is], ngmc);
ModuleBase::GlobalFunc::ZEROS(rho_save[is], nrxx);
ModuleBase::GlobalFunc::ZEROS(rhog_save[is], ngmc);
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0)
{
kin_r[is] = _space_kin_r + is * nrxx;
ModuleBase::GlobalFunc::ZEROS(kin_r[is], nrxx);
Expand All @@ -171,7 +172,7 @@ void Charge::allocate(const int& nspin_in)
ModuleBase::Memory::record("Chg::rho_save", sizeof(double) * nspin * nrxx);
ModuleBase::Memory::record("Chg::rhog", sizeof(double) * nspin * ngmc);
ModuleBase::Memory::record("Chg::rhog_save", sizeof(double) * nspin * ngmc);
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5 || PARAM.inp.out_elf[0] > 0)
if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0)
{
ModuleBase::Memory::record("Chg::kin_r", sizeof(double) * nspin * ngmc);
ModuleBase::Memory::record("Chg::kin_r_save", sizeof(double) * nspin * ngmc);
Expand Down Expand Up @@ -701,9 +702,10 @@ void Charge::save_rho_before_sum_band()
for (int is = 0; is < PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::DCOPY(rho[is], rho_save[is], this->rhopw->nrxx);
if (elecstate::get_xc_func_type() == 3 || elecstate::get_xc_func_type() == 5) {
if (XC_Functional::get_ked_flag())
{
ModuleBase::GlobalFunc::DCOPY(kin_r[is], kin_r_save[is], this->rhopw->nrxx);
}
}
#ifdef USE_PAW
if(PARAM.inp.use_paw) {
ModuleBase::GlobalFunc::DCOPY(nhat[is], nhat_save[is], this->rhopw->nrxx);
Expand Down
4 changes: 2 additions & 2 deletions source/module_elecstate/module_charge/charge_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout,
}
}

if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
// If the charge density is not read in, then the kinetic energy density is not read in either
if (!read_error)
Expand Down Expand Up @@ -188,7 +188,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout,
}

// wenfei 2021-7-29 : initial tau = 3/5 rho^2/3, Thomas-Fermi
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
if (XC_Functional::get_ked_flag())
{
if (PARAM.inp.init_chg == "atomic" || read_kin_error)
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_elecstate/module_charge/charge_mixing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ void Charge_Mixing::init_mixing()
}

// initailize tau_mdata
if ((XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) && mixing_tau)
if ((XC_Functional::get_ked_flag()) && mixing_tau)
{
if (PARAM.inp.scf_thr_type == 1)
{
Expand Down Expand Up @@ -180,7 +180,7 @@ void Charge_Mixing::mix_reset()
this->mixing->reset();
this->rho_mdata.reset();
// initailize tau_mdata
if ((XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) && mixing_tau)
if ((XC_Functional::get_ked_flag()) && mixing_tau)
{
this->tau_mdata.reset();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ double Charge_Mixing::get_drho(Charge* chr, const double nelec)

double Charge_Mixing::get_dkin(Charge* chr, const double nelec)
{
if (!(XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5))
if (!(XC_Functional::get_ked_flag()))
{
return 0.0;
};
Expand Down
8 changes: 4 additions & 4 deletions source/module_elecstate/module_charge/charge_mixing_rho.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr)
}

// For kinetic energy density
if ((XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) && mixing_tau)
if ((XC_Functional::get_ked_flag()) && mixing_tau)
{
std::vector<std::complex<double>> kin_g(PARAM.inp.nspin * rhodpw->npw);
std::vector<std::complex<double>> kin_g_save(PARAM.inp.nspin * rhodpw->npw);
Expand Down Expand Up @@ -485,7 +485,7 @@ void Charge_Mixing::mix_rho_real(Charge* chr)
}

double *taur_out, *taur_in;
if ((XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) && mixing_tau)
if ((XC_Functional::get_ked_flag()) && mixing_tau)
{
taur_in = chr->kin_r_save[0];
taur_out = chr->kin_r[0];
Expand Down Expand Up @@ -521,7 +521,7 @@ void Charge_Mixing::mix_rho(Charge* chr)
}
}
std::vector<double> kin_r123;
if ((XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) && mixing_tau)
if ((XC_Functional::get_ked_flag()) && mixing_tau)
{
kin_r123.resize(PARAM.inp.nspin * nrxx);
for (int is = 0; is < PARAM.inp.nspin; ++is)
Expand Down Expand Up @@ -581,7 +581,7 @@ void Charge_Mixing::mix_rho(Charge* chr)
}
}

if ((XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) && mixing_tau)
if ((XC_Functional::get_ked_flag()) && mixing_tau)
{
for (int is = 0; is < PARAM.inp.nspin; ++is)
{
Expand Down
Loading

0 comments on commit 46797ad

Please sign in to comment.