Skip to content

Commit

Permalink
Refactor: remove GlobalC::Pkpoint (#5846)
Browse files Browse the repository at this point in the history
* Refactor: remove GlobalC::Pkpoint
  • Loading branch information
Qianruipku authored Jan 10, 2025
1 parent 16714c6 commit 4484536
Show file tree
Hide file tree
Showing 22 changed files with 61 additions and 101 deletions.
16 changes: 8 additions & 8 deletions source/module_cell/klist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ void K_Vectors::set(const UnitCell& ucell,
// It's very important in parallel case,
// firstly do the mpi_k() and then
// do set_kup_and_kdw()
GlobalC::Pkpoints.kinfo(nkstot,
GlobalV::KPAR,
GlobalV::MY_POOL,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC,
nspin_in); // assign k points to several process pools
this->para_k.kinfo(nkstot,
GlobalV::KPAR,
GlobalV::MY_POOL,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC,
nspin_in); // assign k points to several process pools
#ifdef __MPI
// distribute K point data to the corresponding process
this->mpi_k(); // 2008-4-29
Expand Down Expand Up @@ -1163,7 +1163,7 @@ void K_Vectors::mpi_k()

Parallel_Common::bcast_double(koffset, 3);

this->nks = GlobalC::Pkpoints.nks_pool[GlobalV::MY_POOL];
this->nks = this->para_k.nks_pool[GlobalV::MY_POOL];

GlobalV::ofs_running << std::endl;
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "k-point number in this process", nks);
Expand Down Expand Up @@ -1217,7 +1217,7 @@ void K_Vectors::mpi_k()
for (int i = 0; i < nks; i++)
{
// 3 is because each k point has three value:kx, ky, kz
k_index = i + GlobalC::Pkpoints.startk_pool[GlobalV::MY_POOL];
k_index = i + this->para_k.startk_pool[GlobalV::MY_POOL];
kvec_c[i].x = kvec_c_aux[k_index * 3];
kvec_c[i].y = kvec_c_aux[k_index * 3 + 1];
kvec_c[i].z = kvec_c_aux[k_index * 3 + 2];
Expand Down
5 changes: 4 additions & 1 deletion source/module_cell/klist.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "module_base/global_variable.h"
#include "module_base/matrix3.h"
#include "module_cell/unitcell.h"

#include "parallel_kpoints.h"
#include <vector>

class K_Vectors
Expand All @@ -31,6 +31,9 @@ class K_Vectors
K_Vectors& operator=(const K_Vectors&) = default;
K_Vectors& operator=(K_Vectors&& rhs) = default;

Parallel_Kpoints para_k; ///< parallel for kpoints


/**
* @brief Set up the k-points for the system.
*
Expand Down
12 changes: 2 additions & 10 deletions source/module_cell/parallel_kpoints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@
#include "module_base/parallel_common.h"
#include "module_base/parallel_global.h"

Parallel_Kpoints::Parallel_Kpoints()
{
}

Parallel_Kpoints::~Parallel_Kpoints()
{
}

// the kpoints here are reduced after symmetry applied.
void Parallel_Kpoints::kinfo(int& nkstot_in,
const int& kpar_in,
Expand Down Expand Up @@ -227,7 +219,7 @@ void Parallel_Kpoints::pool_collection(double* value_re,
return;
}

void Parallel_Kpoints::pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik)
void Parallel_Kpoints::pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik) const
{
const int dim2 = w.getBound2();
const int dim3 = w.getBound3();
Expand All @@ -237,7 +229,7 @@ void Parallel_Kpoints::pool_collection(std::complex<double>* value, const Module
}

template <class T, class V>
void Parallel_Kpoints::pool_collection_aux(T* value, const V& w, const int& dim, const int& ik)
void Parallel_Kpoints::pool_collection_aux(T* value, const V& w, const int& dim, const int& ik) const
{
#ifdef __MPI
const int ik_now = ik - this->startk_pool[this->my_pool];
Expand Down
12 changes: 6 additions & 6 deletions source/module_cell/parallel_kpoints.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
class Parallel_Kpoints
{
public:
Parallel_Kpoints();
~Parallel_Kpoints();
Parallel_Kpoints(){};
~Parallel_Kpoints(){};

void kinfo(int& nkstot_in,
const int& kpar_in,
Expand All @@ -28,9 +28,9 @@ class Parallel_Kpoints
const ModuleBase::realArray& a,
const ModuleBase::realArray& b,
const int& ik);
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik);
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik) const;
template <class T, class V>
void pool_collection_aux(T* value, const V& w, const int& dim, const int& ik);
void pool_collection_aux(T* value, const V& w, const int& dim, const int& ik) const;
#ifdef __MPI
/**
* @brief gather kpoints from all processors
Expand All @@ -46,8 +46,8 @@ class Parallel_Kpoints
// int* nproc_pool = nullptr; it is not used

// inforamation about kpoints, dim: KPAR
std::vector<int> nks_pool; // number of k-points in each pool
std::vector<int> startk_pool; // the first k-point in each pool
std::vector<int> nks_pool; // number of k-points in each pool, here use k-points without spin
std::vector<int> startk_pool; // the first k-point in each pool, here use k-points without spin

// information about which pool each k-point belongs to,
std::vector<int> whichpool; // whichpool[k] : the pool which k belongs to, dim: nkstot_np
Expand Down
6 changes: 2 additions & 4 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
// qianrui modify 2020-10-18
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "md" || PARAM.inp.calculation == "relax")
{
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);
}

const int nspin0 = (PARAM.inp.nspin == 2) ? 2 : 1;
Expand All @@ -432,8 +432,7 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
0.0,
PARAM.inp.out_band[1],
this->pelec->ekb,
this->kv,
&(GlobalC::Pkpoints));
this->kv);
}
} // out_band

Expand All @@ -452,7 +451,6 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
PARAM.inp.dos_scale,
PARAM.inp.dos_sigma,
*(this->pelec->klist),
GlobalC::Pkpoints,
ucell,
this->pelec->eferm,
PARAM.inp.nbands,
Expand Down
7 changes: 3 additions & 4 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
GlobalV::MY_RANK,
ucell,
this->sf,
GlobalC::Pkpoints,
this->kv.para_k,
this->ppcell,
*this->pw_wfc);
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max);
Expand Down Expand Up @@ -844,7 +844,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
}

//! 2) Print occupation numbers into istate.info
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);

//! 3) Compute density of states (DOS)
if (PARAM.inp.out_dos)
Expand Down Expand Up @@ -883,8 +883,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
0.0,
PARAM.inp.out_band[1],
this->pelec->ekb,
this->kv,
&(GlobalC::Pkpoints));
this->kv);
}
}

Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ void ESolver_SDFT_PW<T, Device>::after_all_runners(UnitCell& ucell)
GlobalV::ofs_running << std::setprecision(16);
GlobalV::ofs_running << " !FINAL_ETOT_IS " << this->pelec->f_en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl;
GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl;
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);
}

template <>
Expand All @@ -277,7 +277,7 @@ void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::after_all_r
GlobalV::ofs_running << std::setprecision(16);
GlobalV::ofs_running << " !FINAL_ETOT_IS " << this->pelec->f_en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl;
GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl;
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);

if (this->method_sto == 2)
{
Expand Down
1 change: 0 additions & 1 deletion source/module_hamilt_pw/hamilt_pwdft/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ namespace GlobalC
#include "module_cell/unitcell.h"
namespace GlobalC
{
extern Parallel_Kpoints Pkpoints;
extern Restart restart; // Peize Lin add 2020.04.04
} // namespace GlobalC

Expand Down
6 changes: 1 addition & 5 deletions source/module_io/dos_nao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace ModuleIO
/// @param[in] dos_scale
/// @param[in] dos_sigma
/// @param[in] kv
/// @param[in] Pkpoints
/// @param[in] ucell
/// @param[in] eferm
/// @param[in] nbands
Expand All @@ -28,7 +27,6 @@ namespace ModuleIO
const double& dos_scale,
const double& dos_sigma,
const K_Vectors& kv,
const Parallel_Kpoints& Pkpoints,
const UnitCell& ucell,
const elecstate::efermi& eferm,
int nbands,
Expand All @@ -45,7 +43,7 @@ namespace ModuleIO
{
std::stringstream ss3;
ss3 << PARAM.globalv.global_out_dir << "Fermi_Surface_" << i << ".bxsf";
nscf_fermi_surface(ss3.str(), nbands, eferm.ef, kv, Pkpoints, ucell, ekb);
nscf_fermi_surface(ss3.str(), nbands, eferm.ef, kv, ucell, ekb);
}
}

Expand All @@ -69,7 +67,6 @@ template void out_dos_nao(
const double& dos_scale,
const double& dos_sigma,
const K_Vectors& kv,
const Parallel_Kpoints& Pkpoints,
const UnitCell& ucell,
const elecstate::efermi& eferm,
int nbands,
Expand All @@ -84,7 +81,6 @@ template void out_dos_nao(
const double& dos_scale,
const double& dos_sigma,
const K_Vectors& kv,
const Parallel_Kpoints& Pkpoints,
const UnitCell& ucell,
const elecstate::efermi& eferm,
int nbands,
Expand Down
1 change: 0 additions & 1 deletion source/module_io/dos_nao.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ namespace ModuleIO
const double& dos_scale,
const double& dos_sigma,
const K_Vectors& kv,
const Parallel_Kpoints& Pkpoints,
const UnitCell& ucell,
const elecstate::efermi& eferm,
int nbands,
Expand Down
13 changes: 6 additions & 7 deletions source/module_io/nscf_band.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ void ModuleIO::nscf_band(
const double &fermie,
const int &precision,
const ModuleBase::matrix& ekb,
const K_Vectors& kv,
const Parallel_Kpoints* Pkpoints)
const K_Vectors& kv)
{
ModuleBase::TITLE("ModuleIO","nscf_band");
ModuleBase::timer::tick("ModuleIO", "nscf_band");
// number of k points without spin; nspin = 1,2, nkstot = nkstot_np * nspin;
// nspin = 4, nkstot = nkstot_np
const int nkstot_np = Pkpoints->nkstot_np;
const int nks_np = Pkpoints->nks_np;
const int nkstot_np = kv.para_k.nkstot_np;
const int nks_np = kv.para_k.nks_np;

#ifdef __MPI
if(GlobalV::MY_RANK==0)
Expand All @@ -33,7 +32,7 @@ void ModuleIO::nscf_band(
klength.resize(nkstot_np);
klength[0] = 0.0;
std::vector<ModuleBase::Vector3<double>> kvec_c_global;
Pkpoints->gatherkvec(kv.kvec_c, kvec_c_global);
kv.para_k.gatherkvec(kv.kvec_c, kvec_c_global);
for(int ik=0; ik<nkstot_np; ik++)
{
if (ik>0)
Expand All @@ -43,10 +42,10 @@ void ModuleIO::nscf_band(
klength[ik] += (kv.kl_segids[ik] == kv.kl_segids[ik-1]) ? delta.norm() : 0.0;
}
/* first find if present kpoint in present pool */
if ( GlobalV::MY_POOL == Pkpoints->whichpool[ik] )
if ( GlobalV::MY_POOL == kv.para_k.whichpool[ik] )
{
/* then get the local kpoint index, which starts definitly from 0 */
const int ik_now = ik - Pkpoints->startk_pool[GlobalV::MY_POOL];
const int ik_now = ik - kv.para_k.startk_pool[GlobalV::MY_POOL];
/* if present kpoint corresponds the spin of the present one */
assert( kv.isk[ik_now+is*nks_np] == is );
if ( GlobalV::RANK_IN_POOL == 0)
Expand Down
4 changes: 1 addition & 3 deletions source/module_io/nscf_band.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ namespace ModuleIO
* @param precision precision of the output
* @param ekb eigenvalues of k points and bands
* @param kv klist
* @param Pkpoints parallel kpoints
*/
void nscf_band(const int& is,
const std::string& out_band_dir,
const int& nband,
const double& fermie,
const int& precision,
const ModuleBase::matrix& ekb,
const K_Vectors& kv,
const Parallel_Kpoints* Pkpoints);
const K_Vectors& kv);
}

#endif
5 changes: 2 additions & 3 deletions source/module_io/nscf_fermi_surf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ void ModuleIO::nscf_fermi_surface(const std::string &out_band_dir,
const int &nband,
const double &ef,
const K_Vectors& kv,
const Parallel_Kpoints& Pkpoints,
const UnitCell& ucell,
const ModuleBase::matrix &ekb)
{
Expand All @@ -29,7 +28,7 @@ void ModuleIO::nscf_fermi_surface(const std::string &out_band_dir,

for(int ik=0; ik<kv.get_nkstot(); ik++)
{
if ( GlobalV::MY_POOL == Pkpoints.whichpool[ik] )
if ( GlobalV::MY_POOL == kv.para_k.whichpool[ik] )
{
if( GlobalV::RANK_IN_POOL == 0)
{
Expand Down Expand Up @@ -58,7 +57,7 @@ void ModuleIO::nscf_fermi_surface(const std::string &out_band_dir,
ofs << " " << ucell.G.e31 << " " << ucell.G.e32 << " " << ucell.G.e33 << std::endl;
}

const int ik_now = ik - Pkpoints.startk_pool[GlobalV::MY_POOL];
const int ik_now = ik - kv.para_k.startk_pool[GlobalV::MY_POOL];
ofs << "ik= " << ik << std::endl;
ofs << kv.kvec_c[ik_now].x << " " << kv.kvec_c[ik_now].y << " " << kv.kvec_c[ik_now].z << std::endl;

Expand Down
1 change: 0 additions & 1 deletion source/module_io/nscf_fermi_surf.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ void nscf_fermi_surface(const std::string& out_band_dir,
const int& nband,
const double& ef,
const K_Vectors& kv,
const Parallel_Kpoints& Pkpoints,
const UnitCell& ucell,
const ModuleBase::matrix& ekb);
}
Expand Down
12 changes: 6 additions & 6 deletions source/module_io/numerical_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,9 +654,9 @@ void Numerical_Basis::output_k(std::ofstream& ofs, const K_Vectors& kv)
// temprary restrict kpar=1 for NSPIN=2 case for generating_orbitals
int pool = 0;
if (PARAM.inp.nspin != 2) {
pool = GlobalC::Pkpoints.whichpool[ik];
pool = kv.para_k.whichpool[ik];
}
const int iknow = ik - GlobalC::Pkpoints.startk_pool[GlobalV::MY_POOL];
const int iknow = ik - kv.para_k.startk_pool[GlobalV::MY_POOL];
if (GlobalV::RANK_IN_POOL == 0)
{
if (GlobalV::MY_POOL == 0)
Expand All @@ -671,7 +671,7 @@ void Numerical_Basis::output_k(std::ofstream& ofs, const K_Vectors& kv)
else
{

int startpro_pool = GlobalC::Pkpoints.get_startpro_pool(pool);
int startpro_pool = kv.para_k.get_startpro_pool(pool);
MPI_Status ierror;
MPI_Recv(&kx, 1, MPI_DOUBLE, startpro_pool, ik * 4, MPI_COMM_WORLD, &ierror);
MPI_Recv(&ky, 1, MPI_DOUBLE, startpro_pool, ik * 4 + 1, MPI_COMM_WORLD, &ierror);
Expand Down Expand Up @@ -755,7 +755,7 @@ void Numerical_Basis::output_overlap_Q(std::ofstream& ofs, const std::vector<Mod
{
ModuleBase::ComplexArray Qtmp(overlap_Q[ik].getBound1(), overlap_Q[ik].getBound2(), overlap_Q[ik].getBound3());
Qtmp.zero_out();
GlobalC::Pkpoints.pool_collection(Qtmp.ptr, overlap_Q_k, ik);
kv.para_k.pool_collection(Qtmp.ptr, overlap_Q_k, ik);
if (GlobalV::MY_RANK == 0)
{
// ofs << "\n ik=" << ik;
Expand Down Expand Up @@ -803,12 +803,12 @@ void Numerical_Basis::output_overlap_Sq(const std::string& name, std::ofstream&
{
for (int ik = 0; ik < nkstot; ik++)
{
if (GlobalV::MY_POOL == GlobalC::Pkpoints.whichpool[ik])
if (GlobalV::MY_POOL == kv.para_k.whichpool[ik])
{
if (GlobalV::RANK_IN_POOL == 0)
{
ofs.open(name.c_str(), std::ios::app);
const int ik_now = ik - GlobalC::Pkpoints.startk_pool[GlobalV::MY_POOL] + is * nkstot;
const int ik_now = ik - kv.para_k.startk_pool[GlobalV::MY_POOL] + is * nkstot;

const int size = overlap_Sq[ik_now].getSize();
for (int i = 0; i < size; i++)
Expand Down
Loading

0 comments on commit 4484536

Please sign in to comment.