From bec10a2b8a85e1dcaf13a9fd4168cf2e611e4bd7 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 9 Jan 2025 12:46:52 +0000 Subject: [PATCH 1/3] set npol to private --- source/module_elecstate/elecstate_pw.cpp | 2 +- source/module_hamilt_general/operator.cpp | 6 +++--- source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp | 4 ++-- .../module_deltaspin/cal_mw_from_lambda.cpp | 8 ++++---- source/module_hamilt_lcao/module_dftu/dftu_pw.cpp | 8 ++++---- source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp | 4 ++-- .../hamilt_pwdft/operator_pw/velocity_pw.cpp | 4 ++-- source/module_io/write_vxc_lip.hpp | 2 +- source/module_psi/psi.cpp | 4 ++-- source/module_psi/psi.h | 5 +++-- 10 files changed, 24 insertions(+), 23 deletions(-) diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index f55f2ec447..0b4fbe0368 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -271,7 +271,7 @@ void ElecStatePW::cal_becsum(const psi::Psi& psi) { const T one{1, 0}; const T zero{0, 0}; - const int npol = psi.npol; + const int npol = psi.get_npol(); const int npwx = psi.get_nbasis() / npol; const int nbands = psi.get_nbands() * npol; const int nkb = this->ppcell->nkb; diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index 008d5e30e3..3f9e43a99c 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -63,7 +63,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp delete this->hpsi; this->hpsi = new psi::Psi(hpsi_pointer, 1, - nbands / psi_input->npol, + nbands / psi_input->get_npol(), psi_input->get_nbasis(), psi_input->get_nbasis(), true); @@ -86,7 +86,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp default: op->act(nbands, psi_input->get_nbasis(), - psi_input->npol, + psi_input->get_npol(), tmpsi_in, this->hpsi->get_pointer(), psi_input->get_current_nbas(), @@ -105,7 +105,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp } ModuleBase::timer::tick("Operator", "hPsi"); - return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer); + return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->get_npol()), hpsi_pointer); } template diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp index 94c5c74db7..25b2e4e879 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp @@ -66,7 +66,7 @@ void spinconstrain::SpinConstrain>::cal_mi_pw() psi::Psi, base_device::DEVICE_CPU>* psi_t = static_cast, base_device::DEVICE_CPU>*>(this->psi); const int nbands = psi_t->get_nbands(); const int nks = psi_t->get_nk(); - const int npol = psi_t->npol; + const int npol = psi_t->get_npol(); for(int ik = 0; ik < nks; ik++) { psi_t->fix_k(ik); @@ -112,7 +112,7 @@ void spinconstrain::SpinConstrain>::cal_mi_pw() psi::Psi, base_device::DEVICE_GPU>* psi_t = static_cast, base_device::DEVICE_GPU>*>(this->psi); const int nbands = psi_t->get_nbands(); const int nks = psi_t->get_nk(); - const int npol = psi_t->npol; + const int npol = psi_t->get_npol(); for(int ik = 0; ik < nks; ik++) { psi_t->fix_k(ik); diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 87a2fa41cc..c4eaca5755 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -199,7 +199,7 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -252,7 +252,7 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int hamilt::Hamilt, base_device::DEVICE_GPU>* hamilt_t = static_cast, base_device::DEVICE_GPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -382,7 +382,7 @@ void spinconstrain::SpinConstrain>::update_psi_charge(const hamilt::Hamilt, base_device::DEVICE_CPU>* hamilt_t = static_cast, base_device::DEVICE_CPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); @@ -454,7 +454,7 @@ void spinconstrain::SpinConstrain>::update_psi_charge(const hamilt::Hamilt, base_device::DEVICE_GPU>* hamilt_t = static_cast, base_device::DEVICE_GPU>*>(this->p_hamilt); auto* onsite_p = projectors::OnsiteProjector::get_instance(); nbands = psi_t->get_nbands(); - npol = psi_t->npol; + npol = psi_t->get_npol(); nkb = onsite_p->get_tot_nproj(); nk = psi_t->get_nk(); nh_iat = &onsite_p->get_nh(0); diff --git a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp index cc0c3a6c30..4f3bd4abb8 100644 --- a/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp +++ b/source/module_hamilt_lcao/module_dftu/dftu_pw.cpp @@ -29,11 +29,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr psi_p->fix_k(ik); onsite_p->tabulate_atomic(ik); - onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer()); + onsite_p->overlap_proj_psi(nbands * psi_p->get_npol(), psi_p->get_pointer()); const std::complex* becp = onsite_p->get_h_becp(); // becp(nbands*npol , nkb) // mag = wg * \sum_{nh}becp * becp - int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol; + int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol(); int begin_ih = 0; for(int iat = 0; iat < cell.nat; iat++) { @@ -88,11 +88,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr psi_p->fix_k(ik); onsite_p->tabulate_atomic(ik); - onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer()); + onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer()); const std::complex* becp = onsite_p->get_h_becp(); // becp(nbands*npol , nkb) // mag = wg * \sum_{nh}becp * becp - int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol; + int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol(); int begin_ih = 0; for(int iat = 0; iat < cell.nat; iat++) { diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index 2bb69dc131..32a4902221 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -165,7 +165,7 @@ void projectors::OnsiteProjector::init(const std::string& orbital_dir RadialProjection::RadialProjector::_build_backward_map(it2iproj, lproj, irow2it_, irow2iproj_, irow2m_); RadialProjection::RadialProjector::_build_forward_map(it2ia, it2iproj, lproj, itiaiprojm2irow_); //rp_._build_sbt_tab(rgrid, projs, lproj, nq, dq); - rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.npol, tab, nhtol); + rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.get_npol(), tab, nhtol); // For being compatible with present cal_force and cal_stress framework // uncomment the following code block if you want to use the Onsite_Proj_tools if(this->tab_atomic_ == nullptr) @@ -541,7 +541,7 @@ void projectors::OnsiteProjector::cal_occupations(const psi::Psi #include @@ -134,7 +135,17 @@ class Psi const int& get_current_ngk() const; - const int& get_npol() const {return this->npol;} + const int& get_npol() const + { + if (PARAM.inp.nspin == 4) + { + return 2; + } + else + { + return 1; + } + } // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const; @@ -143,27 +154,22 @@ class Psi T* psi = nullptr; // avoid using C++ STL Device* ctx = {}; // an context identifier for obtaining the device variable - int npol = 1; + bool allocate_inside = true; ///< whether allocate psi inside Psi class + bool k_first = true; + + const int* ngk = nullptr; // dimensions int nk = 1; // number of k points int nbands = 1; // number of bands int nbasis = 1; // number of basis + // mutable values mutable int current_k = 0; // current k point mutable int current_b = 0; // current band index mutable int current_nbasis = 1; // current number of basis of current_k - - // current pointer for getting the psi - mutable T* psi_current = nullptr; - // psi_current = psi + psi_bias; - mutable int psi_bias = 0; - - const int* ngk = nullptr; - - bool k_first = true; - - bool allocate_inside = true; ///< whether allocate psi inside Psi class + mutable T* psi_current = nullptr; // current pointer for getting the psi + mutable int psi_bias = 0; // psi_current = psi + psi_bias; #ifdef __DSP using delete_memory_op = base_device::memory::delete_memory_op_mt; From 881e532d4549d7a55ad951e6dd12c42c4cbc5ff1 Mon Sep 17 00:00:00 2001 From: haozhihan Date: Thu, 9 Jan 2025 13:31:10 +0000 Subject: [PATCH 3/3] fix bug --- source/module_psi/psi.cpp | 13 +++++++++++++ source/module_psi/psi.h | 13 +------------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 51db69d34c..ba8a178032 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -477,6 +477,19 @@ int Psi::get_current_nbas() const return this->current_nbasis; } +template +const int& Psi::get_npol() const +{ + if (PARAM.inp.nspin == 4) + { + return 2; + } + else + { + return 1; + } +} + template const int& Psi::get_ngk(const int ik_in) const { diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index d2aa69a743..d41acae153 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -3,7 +3,6 @@ #include "module_base/module_device/memory_op.h" #include "module_base/module_device/types.h" -#include "module_parameter/parameter.h" #include #include @@ -135,17 +134,7 @@ class Psi const int& get_current_ngk() const; - const int& get_npol() const - { - if (PARAM.inp.nspin == 4) - { - return 2; - } - else - { - return 1; - } - } + const int& get_npol() const; // solve Range: return(pointer of begin, number of bands or k-points) std::tuple to_range(const Range& range) const;