From 5ab4fd67d47e8beb699c1a7f82db0cac14f407fe Mon Sep 17 00:00:00 2001 From: Cstandardlib <49788094+Cstandardlib@users.noreply.github.com> Date: Fri, 26 Jul 2024 21:33:24 +0800 Subject: [PATCH] Refactor: allocate david gpu memory in unified constructor&destructor (#4806) --- source/module_hsolver/diago_david.cpp | 31 +++++++++++++-------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/source/module_hsolver/diago_david.cpp b/source/module_hsolver/diago_david.cpp index 20f87f2c76..e74e7f787d 100644 --- a/source/module_hsolver/diago_david.cpp +++ b/source/module_hsolver/diago_david.cpp @@ -110,6 +110,14 @@ DiagoDavid::DiagoDavid(const Real* precondition_in, // lagrange_matrix(nband, nband); // for orthogonalization resmem_complex_op()(this->ctx, this->lagrange_matrix, nband * nband); setmem_complex_op()(this->ctx, this->lagrange_matrix, 0, nband * nband); + +#if defined(__CUDA) || defined(__ROCM) + if (this->device == base_device::GpuDevice) + { + resmem_var_op()(this->ctx, this->d_precondition, dim); + syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, dim); + } +#endif } /** @@ -130,6 +138,13 @@ DiagoDavid::~DiagoDavid() delmem_complex_op()(this->ctx, this->vcc); delmem_complex_op()(this->ctx, this->lagrange_matrix); base_device::memory::delete_memory_op()(this->cpu_ctx, this->eigenvalue); + // If the device is a GPU device, free the d_precondition array. +#if defined(__CUDA) || defined(__ROCM) + if (this->device == base_device::GpuDevice) + { + delmem_var_op()(this->ctx, this->d_precondition); + } +#endif } template @@ -1135,14 +1150,6 @@ int DiagoDavid::diag(const HPsiFunc& hpsi_func, int ntry = 0; this->notconv = 0; -#if defined(__CUDA) || defined(__ROCM) - if (this->device == base_device::GpuDevice) - { - resmem_var_op()(this->ctx, this->d_precondition, ldPsi); - syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, ldPsi); - } -#endif - int sum_dav_iter = 0; do { @@ -1155,14 +1162,6 @@ int DiagoDavid::diag(const HPsiFunc& hpsi_func, std::cout << "\n notconv = " << this->notconv; std::cout << "\n DiagoDavid::diag', too many bands are not converged! \n"; } - // If the device is a GPU device, free the d_precondition array. -#if defined(__CUDA) || defined(__ROCM) - if (this->device == base_device::GpuDevice) - { - delmem_var_op()(this->ctx, this->d_precondition); - } -#endif - return sum_dav_iter; }