Skip to content

Commit

Permalink
Refactor: allocate david gpu memory in unified constructor&destructor (
Browse files Browse the repository at this point in the history
  • Loading branch information
Cstandardlib authored Jul 26, 2024
1 parent 14c4f42 commit 5ab4fd6
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
// lagrange_matrix(nband, nband); // for orthogonalization
resmem_complex_op()(this->ctx, this->lagrange_matrix, nband * nband);
setmem_complex_op()(this->ctx, this->lagrange_matrix, 0, nband * nband);

#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, this->d_precondition, dim);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, dim);
}
#endif
}

/**
Expand All @@ -130,6 +138,13 @@ DiagoDavid<T, Device>::~DiagoDavid()
delmem_complex_op()(this->ctx, this->vcc);
delmem_complex_op()(this->ctx, this->lagrange_matrix);
base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>()(this->cpu_ctx, this->eigenvalue);
// If the device is a GPU device, free the d_precondition array.
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->d_precondition);
}
#endif
}

template <typename T, typename Device>
Expand Down Expand Up @@ -1135,14 +1150,6 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
int ntry = 0;
this->notconv = 0;

#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, this->d_precondition, ldPsi);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, ldPsi);
}
#endif

int sum_dav_iter = 0;
do
{
Expand All @@ -1155,14 +1162,6 @@ int DiagoDavid<T, Device>::diag(const HPsiFunc& hpsi_func,
std::cout << "\n notconv = " << this->notconv;
std::cout << "\n DiagoDavid::diag', too many bands are not converged! \n";
}
// If the device is a GPU device, free the d_precondition array.
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->d_precondition);
}
#endif

return sum_dav_iter;
}

Expand Down

0 comments on commit 5ab4fd6

Please sign in to comment.