Skip to content

Commit

Permalink
Merge branch 'develop' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
mohanchen authored Jun 9, 2024
2 parents 50a92df + b629044 commit cecce5c
Show file tree
Hide file tree
Showing 11 changed files with 931 additions and 931 deletions.
1,298 changes: 582 additions & 716 deletions source/module_cell/klist.cpp

Large diffs are not rendered by default.

372 changes: 316 additions & 56 deletions source/module_cell/klist.h

Large diffs are not rendered by default.

131 changes: 20 additions & 111 deletions source/module_cell/parallel_kpoints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,113 +185,17 @@ void Parallel_Kpoints::pool_collection(double &value, const double *wk, const in
}


void Parallel_Kpoints::pool_collection(double *valuea, double *valueb, const ModuleBase::realArray &a, const ModuleBase::realArray &b, const int &ik)
void Parallel_Kpoints::pool_collection(double *value_re, double *value_im, const ModuleBase::realArray &re, const ModuleBase::realArray &im, const int &ik)
{
const int dim2 = a.getBound2();
const int dim3 = a.getBound3();
const int dim4 = a.getBound4();
assert( a.getBound2() == b.getBound2() );
assert( a.getBound3() == b.getBound3() );
assert( a.getBound4() == b.getBound4() );
const int dim2 = re.getBound2();
const int dim3 = re.getBound3();
const int dim4 = re.getBound4();
assert( re.getBound2() == im.getBound2() );
assert( re.getBound3() == im.getBound3() );
assert( re.getBound4() == im.getBound4() );
const int dim = dim2 * dim3 * dim4;
#ifdef __MPI
const int ik_now = ik - this->startk_pool[this->my_pool];
const int begin = ik_now * dim2 * dim3 * dim4;
double* pa = &a.ptr[begin];
double* pb = &b.ptr[begin];

const int pool = this->whichpool[ik];

GlobalV::ofs_running << "\n ik=" << ik;

if (this->rank_in_pool==0)
{
if (this->my_pool==0)
{
if (pool==0)
{
for (int i=0; i<dim ; i++)
{
valuea[i] = *pa;
valueb[i] = *pb;
++pa;
++pb;
}
}
else
{
GlobalV::ofs_running << " receive data.";
MPI_Status ierror;
MPI_Recv(valuea, dim, MPI_DOUBLE, this->startpro_pool[pool], ik*2+0, MPI_COMM_WORLD,&ierror);
MPI_Recv(valueb, dim, MPI_DOUBLE, this->startpro_pool[pool], ik*2+1, MPI_COMM_WORLD,&ierror);
}
}
else
{
if (this->my_pool == pool)
{
GlobalV::ofs_running << " send data.";
MPI_Send(pa, dim, MPI_DOUBLE, 0, ik*2+0, MPI_COMM_WORLD);
MPI_Send(pb, dim, MPI_DOUBLE, 0, ik*2+1, MPI_COMM_WORLD);
}
}
}
else
{
GlobalV::ofs_running << "\n do nothing.";
}
MPI_Barrier(MPI_COMM_WORLD);

/*
if(this->whichpool[ik] == GlobalV::MY_POOL)
{
if(GlobalV::MY_POOL > 0 && GlobalV::RANK_IN_POOL == 0)
{
// data transfer ends.
MPI_Send(pa, dim, MPI_DOUBLE, 0, ik*2, MPI_COMM_WORLD);
MPI_Send(pb, dim, MPI_DOUBLE, 0, ik*2+1, MPI_COMM_WORLD);
}
else if(GlobalV::MY_POOL == 0 && GlobalV::MY_RANK == 0)
{
// std::cout << "\n ik = " << ik << std::endl;
// data transfer begin.
for(int i=0; i<dim; i++)
{
valuea[i] = *pa;
valueb[i] = *pb;
++pa;
++pb;
}
// data transfer ends.
}
}
else
{
if(GlobalV::MY_RANK==0)
{
MPI_Status* ierror;
const int iproc = this->startpro_pool[ this->whichpool[ik] ];
MPI_Recv(valuea, dim, MPI_DOUBLE, iproc, ik*2, MPI_COMM_WORLD,ierror);
MPI_Recv(valueb, dim, MPI_DOUBLE, iproc, ik*2+1, MPI_COMM_WORLD,ierror);
}
}
*/
#else
// data transfer ends.
const int begin = ik * dim2 * dim3 * dim4;
double* pa = &a.ptr[begin];
double* pb = &b.ptr[begin];
for (int i=0; i<dim; i++)
{
valuea[i] = *pa;
valueb[i] = *pb;
++pa;
++pb;
}
// data transfer ends.
#endif
pool_collection_aux(value_re, re, dim, ik);
pool_collection_aux(value_im, im, dim, ik);
return;
}

Expand All @@ -302,12 +206,17 @@ void Parallel_Kpoints::pool_collection(std::complex<double> *value, const Module
const int dim3 = w.getBound3();
const int dim4 = w.getBound4();
const int dim = dim2 * dim3 * dim4;
pool_collection_aux(value, w, dim, ik);
}

template <class T, class V> void Parallel_Kpoints::pool_collection_aux(T *value, const V &w, const int& dim, const int &ik)
{
#ifdef __MPI
const int ik_now = ik - this->startk_pool[this->my_pool];
const int begin = ik_now * dim2 * dim3 * dim4;
std::complex<double>* p = &w.ptr[begin];
const int begin = ik_now * dim;
T* p = &w.ptr[begin];
//temprary restrict kpar=1 for NSPIN=2 case for generating_orbitals
int pool = 0;
int pool = 0;
if(this->nspin != 2) pool = this->whichpool[ik];

GlobalV::ofs_running << "\n ik=" << ik;
Expand Down Expand Up @@ -348,13 +257,13 @@ void Parallel_Kpoints::pool_collection(std::complex<double> *value, const Module

#else
// data transfer ends.
const int begin = ik * dim2 * dim3 * dim4;
std::complex<double> * p = &w.ptr[begin];
const int begin = ik * dim;
T * p = &w.ptr[begin];
for (int i=0; i<dim; i++)
{
value[i] = *p;
++p;
}
// data transfer ends.
#endif
}
}
1 change: 1 addition & 0 deletions source/module_cell/parallel_kpoints.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Parallel_Kpoints
const ModuleBase::realArray& b,
const int& ik);
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik);
template <class T, class V> void pool_collection_aux(T *value, const V &w, const int& dim, const int &ik);
#ifdef __MPI
/**
* @brief gather kpoints from all processors
Expand Down
41 changes: 3 additions & 38 deletions source/module_cell/test/klist_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,9 @@ namespace GlobalC
* - set_kup_and_kdw()
* - SetKupKdown: set basic kpoints info: kvec_c, kvec_d, wk, isk, nks, nkstot
* according to different spin case
* - set_kup_and_kdw_after_vc()
* - SetKupKdownAfterVC: set basic kpoints info: kvec_c, kvec_d, wk, isk, nks, nkstot
* according to different spin case after variable-cell optimization
* - set_both_kvec()
* - SetBothKvec: set kvec_c (cartesian coor.) and kvec_d (direct coor.)
* - SetBothKvecFinalSCF: same as above, with GlobalV::FINAL_SCF=1
* - set_both_kvec_after_vc()
* - SetBothKvecAfterVC: set kvec_c (cartesian coor.) and kvec_d (direct coor.)
* after variable-cell relaxation
* - print_klists()
* - PrintKlists: print kpoints coordinates
* - PrintKlistsWarningQuit: for nkstot < nks error
Expand Down Expand Up @@ -549,37 +543,8 @@ TEST_F(KlistTest, SetKupKdown)
}
}

TEST_F(KlistTest, SetKupKdownAfterVC)
{
std::string k_file = "./support/KPT4";
kv->nspin = 1;
kv->read_kpoints(k_file);
kv->set_kup_and_kdw_after_vc();
for (int ik=0;ik<5;ik++)
{
EXPECT_EQ(kv->isk[ik],0);
}
kv->nspin = 4;
kv->read_kpoints(k_file);
kv->set_kup_and_kdw_after_vc();
for (int ik=0;ik<5;ik++)
{
EXPECT_EQ(kv->isk[ik],0);
EXPECT_EQ(kv->isk[ik+5],0);
EXPECT_EQ(kv->isk[ik+10],0);
EXPECT_EQ(kv->isk[ik+15],0);
}
kv->nspin = 2;
kv->read_kpoints(k_file);
kv->set_kup_and_kdw_after_vc();
for (int ik=0;ik<5;ik++)
{
EXPECT_EQ(kv->isk[ik],0);
EXPECT_EQ(kv->isk[ik+5],1);
}
}

TEST_F(KlistTest, SetBothKvecAfterVC)
TEST_F(KlistTest, SetAfterVC)
{
kv->nspin = 1;
kv->set_nkstot(1);
Expand All @@ -588,7 +553,7 @@ TEST_F(KlistTest, SetBothKvecAfterVC)
kv->kvec_c[0].x = 0;
kv->kvec_c[0].y = 0;
kv->kvec_c[0].z = 0;
kv->set_both_kvec_after_vc(GlobalC::ucell.G,GlobalC::ucell.latvec);
kv->set_after_vc(GlobalV::NSPIN, GlobalC::ucell.G,GlobalC::ucell.latvec);
EXPECT_TRUE(kv->kd_done);
EXPECT_TRUE(kv->kc_done);
EXPECT_DOUBLE_EQ(kv->kvec_d[0].x,0);
Expand All @@ -608,7 +573,7 @@ TEST_F(KlistTest, PrintKlists)
kv->kvec_c[0].x = 0;
kv->kvec_c[0].y = 0;
kv->kvec_c[0].z = 0;
kv->set_both_kvec_after_vc(GlobalC::ucell.G,GlobalC::ucell.latvec);
kv->set_after_vc(GlobalV::NSPIN, GlobalC::ucell.G,GlobalC::ucell.latvec);
EXPECT_TRUE(kv->kd_done);
kv->print_klists(GlobalV::ofs_running);
GlobalV::ofs_running.close();
Expand Down
6 changes: 3 additions & 3 deletions source/module_cell/test/klist_test_para.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ TEST_F(KlistParaTest,Set)
if(GlobalV::NPROC == 4) {GlobalV::KPAR = 2;}
Parallel_Global::init_pools();
ModuleSymmetry::Symmetry::symm_flag=1;
kv->set(symm,k_file,kv->nspin,GlobalC::ucell.G,GlobalC::ucell.latvec);
kv->set(symm,k_file,kv->nspin,GlobalC::ucell.G,GlobalC::ucell.latvec, GlobalV::ofs_running);
EXPECT_EQ(kv->get_nkstot(),35);
EXPECT_TRUE(kv->kc_done);
EXPECT_TRUE(kv->kd_done);
Expand Down Expand Up @@ -219,7 +219,7 @@ TEST_F(KlistParaTest,SetAfterVC)
if(GlobalV::NPROC == 4) {GlobalV::KPAR = 1;}
Parallel_Global::init_pools();
ModuleSymmetry::Symmetry::symm_flag=1;
kv->set(symm,k_file,kv->nspin,GlobalC::ucell.G,GlobalC::ucell.latvec);
kv->set(symm,k_file,kv->nspin,GlobalC::ucell.G,GlobalC::ucell.latvec, GlobalV::ofs_running);
EXPECT_EQ(kv->get_nkstot(),35);
EXPECT_TRUE(kv->kc_done);
EXPECT_TRUE(kv->kd_done);
Expand All @@ -232,7 +232,7 @@ TEST_F(KlistParaTest,SetAfterVC)
}
//call set_after_vc here
kv->kc_done = 0;
kv->set_after_vc(symm,k_file,kv->nspin,GlobalC::ucell.G,GlobalC::ucell.latvec);
kv->set_after_vc(kv->nspin,GlobalC::ucell.G,GlobalC::ucell.latvec);
EXPECT_TRUE(kv->kc_done);
EXPECT_TRUE(kv->kd_done);
//clear
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ void ESolver_FP::init_after_vc(Input& inp, UnitCell& cell)
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
}

kv.set_after_vc(cell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, cell.G, cell.latvec);
kv.set_after_vc(GlobalV::NSPIN, cell.G, cell.latvec);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");

return;
Expand Down
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ void ESolver_KS<T, Device>::before_all_runners(Input& inp, UnitCell& ucell)
}

//! 6) Setup the k points according to symmetry.
this->kv.set(ucell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, ucell.G, ucell.latvec);
this->kv.set(ucell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, ucell.G, ucell.latvec,GlobalV::ofs_running);

ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");

//! 7) print information
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
}

// 1.3) Setup k-points according to symmetry.
this->kv.set(ucell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, ucell.G, ucell.latvec);
this->kv.set(ucell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, ucell.G, ucell.latvec, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");

Print_Info::setup_parameters(ucell, this->kv);
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_of.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void ESolver_OF::before_all_runners(Input& inp, UnitCell& ucell)
}

// Setup the k points according to symmetry.
kv.set(ucell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, ucell.G, ucell.latvec);
kv.set(ucell.symm, GlobalV::global_kpoint_card, GlobalV::NSPIN, ucell.G, ucell.latvec, GlobalV::ofs_running);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running,"INIT K-POINTS");

// print information
Expand Down
4 changes: 1 addition & 3 deletions source/module_ri/exx_lip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,9 +685,7 @@ void Exx_Lip::read_q_pack(const ModuleSymmetry::Symmetry& symm,

q_pack->kv_ptr = new K_Vectors();
const std::string exx_kpoint_card = GlobalV::global_out_dir + exx_q_pack + GlobalV::global_kpoint_card;
q_pack->kv_ptr->set( symm, exx_kpoint_card, GlobalV::NSPIN, ucell_ptr->G, ucell_ptr->latvec );
// q_pack->kv_ptr->set( symm, exx_kpoint_card, GlobalV::NSPIN, ucell_ptr->G, ucell_ptr->latvec, &Pkpoints );

q_pack->kv_ptr->set( symm, exx_kpoint_card, GlobalV::NSPIN, ucell_ptr->G, ucell_ptr->latvec, GlobalV::ofs_running );

q_pack->wf_ptr = new wavefunc();
q_pack->wf_ptr->allocate(q_pack->kv_ptr->get_nkstot(),
Expand Down

0 comments on commit cecce5c

Please sign in to comment.