Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

not ready #5840

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
{
const T one{1, 0};
const T zero{0, 0};
const int npol = psi.npol;
const int npol = psi.get_npol();
const int npwx = psi.get_nbasis() / npol;
const int nbands = psi.get_nbands() * npol;
const int nkb = this->ppcell->nkb;
Expand Down
6 changes: 3 additions & 3 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
delete this->hpsi;
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
1,
nbands / psi_input->npol,
nbands / psi_input->get_npol(),
psi_input->get_nbasis(),
psi_input->get_nbasis(),
true);
Expand All @@ -86,7 +86,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
default:
op->act(nbands,
psi_input->get_nbasis(),
psi_input->npol,
psi_input->get_npol(),
tmpsi_in,
this->hpsi->get_pointer(),
psi_input->get_current_nbas(),
Expand All @@ -105,7 +105,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
}
ModuleBase::timer::tick("Operator", "hPsi");

return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->get_npol()), hpsi_pointer);
}

template <typename T, typename Device>
Expand Down
4 changes: 2 additions & 2 deletions source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mi_pw()
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi_t = static_cast<psi::Psi<std::complex<double>, base_device::DEVICE_CPU>*>(this->psi);
const int nbands = psi_t->get_nbands();
const int nks = psi_t->get_nk();
const int npol = psi_t->npol;
const int npol = psi_t->get_npol();
for(int ik = 0; ik < nks; ik++)
{
psi_t->fix_k(ik);
Expand Down Expand Up @@ -112,7 +112,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mi_pw()
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi_t = static_cast<psi::Psi<std::complex<double>, base_device::DEVICE_GPU>*>(this->psi);
const int nbands = psi_t->get_nbands();
const int nks = psi_t->get_nk();
const int npol = psi_t->npol;
const int npol = psi_t->get_npol();
for(int ik = 0; ik < nks; ik++)
{
psi_t->fix_k(ik);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mw_from_lambda(int
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>*>(this->p_hamilt);
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_CPU>::get_instance();
nbands = psi_t->get_nbands();
npol = psi_t->npol;
npol = psi_t->get_npol();
nkb = onsite_p->get_tot_nproj();
nk = psi_t->get_nk();
nh_iat = &onsite_p->get_nh(0);
Expand Down Expand Up @@ -252,7 +252,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mw_from_lambda(int
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>*>(this->p_hamilt);
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_GPU>::get_instance();
nbands = psi_t->get_nbands();
npol = psi_t->npol;
npol = psi_t->get_npol();
nkb = onsite_p->get_tot_nproj();
nk = psi_t->get_nk();
nh_iat = &onsite_p->get_nh(0);
Expand Down Expand Up @@ -382,7 +382,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::update_psi_charge(const
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>*>(this->p_hamilt);
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_CPU>::get_instance();
nbands = psi_t->get_nbands();
npol = psi_t->npol;
npol = psi_t->get_npol();
nkb = onsite_p->get_tot_nproj();
nk = psi_t->get_nk();
nh_iat = &onsite_p->get_nh(0);
Expand Down Expand Up @@ -454,7 +454,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::update_psi_charge(const
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>* hamilt_t = static_cast<hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>*>(this->p_hamilt);
auto* onsite_p = projectors::OnsiteProjector<double, base_device::DEVICE_GPU>::get_instance();
nbands = psi_t->get_nbands();
npol = psi_t->npol;
npol = psi_t->get_npol();
nkb = onsite_p->get_tot_nproj();
nk = psi_t->get_nk();
nh_iat = &onsite_p->get_nh(0);
Expand Down
8 changes: 4 additions & 4 deletions source/module_hamilt_lcao/module_dftu/dftu_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr
psi_p->fix_k(ik);
onsite_p->tabulate_atomic(ik);

onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer());
onsite_p->overlap_proj_psi(nbands * psi_p->get_npol(), psi_p->get_pointer());
const std::complex<double>* becp = onsite_p->get_h_becp();
// becp(nbands*npol , nkb)
// mag = wg * \sum_{nh}becp * becp
int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol;
int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol();
int begin_ih = 0;
for(int iat = 0; iat < cell.nat; iat++)
{
Expand Down Expand Up @@ -88,11 +88,11 @@ void DFTU::cal_occ_pw(const int iter, const void* psi_in, const ModuleBase::matr
psi_p->fix_k(ik);
onsite_p->tabulate_atomic(ik);

onsite_p->overlap_proj_psi(nbands*psi_p->npol, psi_p->get_pointer());
onsite_p->overlap_proj_psi(nbands*psi_p->get_npol(), psi_p->get_pointer());
const std::complex<double>* becp = onsite_p->get_h_becp();
// becp(nbands*npol , nkb)
// mag = wg * \sum_{nh}becp * becp
int nkb = onsite_p->get_size_becp() / nbands / psi_p->npol;
int nkb = onsite_p->get_size_becp() / nbands / psi_p->get_npol();
int begin_ih = 0;
for(int iat = 0; iat < cell.nat; iat++)
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ void projectors::OnsiteProjector<T, Device>::init(const std::string& orbital_dir
RadialProjection::RadialProjector::_build_backward_map(it2iproj, lproj, irow2it_, irow2iproj_, irow2m_);
RadialProjection::RadialProjector::_build_forward_map(it2ia, it2iproj, lproj, itiaiprojm2irow_);
//rp_._build_sbt_tab(rgrid, projs, lproj, nq, dq);
rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.npol, tab, nhtol);
rp_._build_sbt_tab(nproj, rgrid, projs, lproj, nq, dq, ucell_in->omega, psi.get_npol(), tab, nhtol);
// For being compatible with present cal_force and cal_stress framework
// uncomment the following code block if you want to use the Onsite_Proj_tools
if(this->tab_atomic_ == nullptr)
Expand Down Expand Up @@ -541,7 +541,7 @@ void projectors::OnsiteProjector<T, Device>::cal_occupations(const psi::Psi<std:
}
// std::cout << __FILE__ << ":" << __LINE__ << " nbands = " << nbands << std::endl;
this->overlap_proj_psi(
nbands * psi_in->npol,
nbands * psi_in->get_npol(),
psi_in->get_pointer());
const std::complex<double>* becp_p = this->get_h_becp();
// becp(nbands*npol , nkb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ void Velocity::act

const int npw = psi_in->get_current_nbas();

const int max_npw = psi_in->get_nbasis() / psi_in->npol;
const int npol = psi_in->npol;
const int max_npw = psi_in->get_nbasis() / psi_in->get_npol();
const int npol = psi_in->get_npol();
const std::complex<double>* tmpsi_in = psi0;
std::complex<double>* tmhpsi = vpsi;
// -------------
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/write_vxc_lip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ namespace ModuleIO
// psi::Psi<T> hpsi_single_band(&hpsi_localxc(ik, ib, 0), 1, 1, hpsi_localxc.get_current_nbas());
// vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ngk(ik));
// }
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik));
vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.get_npol(), &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik));
delete vxcs_op_pw;
std::vector<T> vxc_local_k_mo = psi_Hpsi(&psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_nbasis(), psi_pw.get_nbands());
Parallel_Reduce::reduce_pool(vxc_local_k_mo.data(), nbands * nbands);
Expand Down
29 changes: 17 additions & 12 deletions source/module_psi/psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ Range::Range(const bool k_first_in, const size_t index_1_in, const size_t range_
// Constructor 0: basic
template <typename T, typename Device>
Psi<T, Device>::Psi()
{
this->npol = PARAM.globalv.npol;
}
{}

template <typename T, typename Device>
Psi<T, Device>::~Psi()
Expand All @@ -53,7 +51,6 @@ Psi<T, Device>::Psi(const int nk_in, const int nbd_in, const int nbs_in, const i
assert(nbs_in > 0);

this->k_first = k_first_in;
this->npol = PARAM.globalv.npol;
this->allocate_inside = true;

this->ngk = ngk_in; // modify later
Expand Down Expand Up @@ -91,7 +88,6 @@ Psi<T, Device>::Psi(const int nk_in,
assert(nbs_in > 0);

this->k_first = k_first_in;
this->npol = PARAM.globalv.npol;
this->allocate_inside = true;

this->ngk = ngk_in.data(); // modify later
Expand Down Expand Up @@ -129,7 +125,6 @@ Psi<T, Device>::Psi(T* psi_pointer,
// assert(nk_in == 1); // NOTE because lr/utils/lr_uril.hpp func & get_psi_spin func

this->k_first = k_first_in;
this->npol = PARAM.globalv.npol;
this->allocate_inside = false;

this->ngk = nullptr;
Expand Down Expand Up @@ -161,7 +156,6 @@ Psi<T, Device>::Psi(const int nk_in,
assert(nk_in == 1);

this->k_first = k_first_in;
this->npol = PARAM.globalv.npol;
this->allocate_inside = true;

this->ngk = nullptr;
Expand Down Expand Up @@ -191,7 +185,6 @@ template <typename T, typename Device>
Psi<T, Device>::Psi(const Psi& psi_in)
{
this->ngk = psi_in.ngk;
this->npol = psi_in.npol;
this->nk = psi_in.get_nk();
this->nbands = psi_in.get_nbands();
this->nbasis = psi_in.get_nbasis();
Expand All @@ -218,7 +211,6 @@ template <typename T_in, typename Device_in>
Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
{
this->ngk = psi_in.get_ngk_pointer();
this->npol = psi_in.npol;
this->nk = psi_in.get_nk();
this->nbands = psi_in.get_nbands();
this->nbasis = psi_in.get_nbasis();
Expand Down Expand Up @@ -331,7 +323,7 @@ const int& Psi<T, Device>::get_psi_bias() const
template <typename T, typename Device>
const int& Psi<T, Device>::get_current_ngk() const
{
if (this->npol == 1)
if (PARAM.inp.nspin != 4)
{
return this->current_nbasis;
}
Expand Down Expand Up @@ -485,6 +477,19 @@ int Psi<T, Device>::get_current_nbas() const
return this->current_nbasis;
}

template <typename T, typename Device>
const int& Psi<T, Device>::get_npol() const
{
if (PARAM.inp.nspin == 4)
{
return 2;
}
else
{
return 1;
}
}

template <typename T, typename Device>
const int& Psi<T, Device>::get_ngk(const int ik_in) const
{
Expand Down Expand Up @@ -519,13 +524,13 @@ std::tuple<const T*, int> Psi<T, Device>::to_range(const Range& range) const
else if (i1 < 0) // [r1, r2] is the range of index1 with length m
{
const T* p = &this->psi[r1 * (k_first ? this->nbands : this->nk) * this->nbasis];
int m = (r2 - r1 + 1) * this->npol;
int m = (r2 - r1 + 1) * this->get_npol();
return std::tuple<const T*, int>(p, m);
}
else // [r1, r2] is the range of index2 with length m
{
const T* p = &this->psi[(i1 * (k_first ? this->nbands : this->nk) + r1) * this->nbasis];
int m = (r2 - r1 + 1) * this->npol;
int m = (r2 - r1 + 1) * this->get_npol();
return std::tuple<const T*, int>(p, m);
}
}
Expand Down
22 changes: 9 additions & 13 deletions source/module_psi/psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,35 +134,31 @@ class Psi

const int& get_current_ngk() const;

const int& get_npol() const;

// solve Range: return(pointer of begin, number of bands or k-points)
std::tuple<const T*, int> to_range(const Range& range) const;

int npol = 1;

private:
T* psi = nullptr; // avoid using C++ STL

Device* ctx = {}; // an context identifier for obtaining the device variable
bool allocate_inside = true; ///< whether allocate psi inside Psi class
bool k_first = true;

const int* ngk = nullptr;

// dimensions
int nk = 1; // number of k points
int nbands = 1; // number of bands
int nbasis = 1; // number of basis

// mutable values
mutable int current_k = 0; // current k point
mutable int current_b = 0; // current band index
mutable int current_nbasis = 1; // current number of basis of current_k

// current pointer for getting the psi
mutable T* psi_current = nullptr;
// psi_current = psi + psi_bias;
mutable int psi_bias = 0;

const int* ngk = nullptr;

bool k_first = true;

bool allocate_inside = true; ///< whether allocate psi inside Psi class
mutable T* psi_current = nullptr; // current pointer for getting the psi
mutable int psi_bias = 0; // psi_current = psi + psi_bias;

#ifdef __DSP
using delete_memory_op = base_device::memory::delete_memory_op_mt<T, Device>;
Expand Down
Loading