From 457c77e33e212d9b63eb0da0ff304fc1a30e86fe Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Wed, 4 Dec 2024 03:02:18 +0900 Subject: [PATCH] fix: work buffer allocation --- fft/src/KokkosFFT_Cuda_types.hpp | 5 ++--- fft/src/KokkosFFT_HIP_types.hpp | 5 ++--- fft/src/KokkosFFT_ROCM_plans.hpp | 4 ++-- fft/src/KokkosFFT_ROCM_types.hpp | 11 +++++++---- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/fft/src/KokkosFFT_Cuda_types.hpp b/fft/src/KokkosFFT_Cuda_types.hpp index ff0d7f48..6869301f 100644 --- a/fft/src/KokkosFFT_Cuda_types.hpp +++ b/fft/src/KokkosFFT_Cuda_types.hpp @@ -24,7 +24,6 @@ namespace Impl { using FFTDirectionType = int; /// \brief A class that wraps cufft for RAII -template struct ScopedCufftPlanType { cufftHandle m_plan; @@ -129,7 +128,7 @@ struct transform_type, template struct FFTPlanType { using fftw_plan_type = ScopedFFTWPlanType; - using cufft_plan_type = ScopedCufftPlanType; + using cufft_plan_type = ScopedCufftPlanType; using type = std::conditional_t, cufft_plan_type, fftw_plan_type>; }; @@ -192,7 +191,7 @@ struct transform_type, template struct FFTPlanType { - using type = ScopedCufftPlanType; + using type = ScopedCufftPlanType; }; template diff --git a/fft/src/KokkosFFT_HIP_types.hpp b/fft/src/KokkosFFT_HIP_types.hpp index 0005d385..2e0d1c01 100644 --- a/fft/src/KokkosFFT_HIP_types.hpp +++ b/fft/src/KokkosFFT_HIP_types.hpp @@ -24,7 +24,6 @@ namespace Impl { using FFTDirectionType = int; /// \brief A class that wraps hipfft for RAII -template struct ScopedHIPfftPlanType { hipfftHandle m_plan; @@ -129,7 +128,7 @@ struct transform_type, template struct FFTPlanType { using fftw_plan_type = ScopedFFTWPlanType; - using hipfft_plan_type = ScopedHIPfftPlanType; + using hipfft_plan_type = ScopedHIPfftPlanType; using type = std::conditional_t, hipfft_plan_type, fftw_plan_type>; }; @@ -192,7 +191,7 @@ struct transform_type, template struct FFTPlanType { - using type = ScopedHIPfftPlanType; + using type = ScopedHIPfftPlanType; }; template diff --git a/fft/src/KokkosFFT_ROCM_plans.hpp b/fft/src/KokkosFFT_ROCM_plans.hpp index 1cb67090..40459511 100644 --- a/fft/src/KokkosFFT_ROCM_plans.hpp +++ b/fft/src/KokkosFFT_ROCM_plans.hpp @@ -192,8 +192,8 @@ auto create_plan(const ExecutionSpace& exec_space, "rocfft_plan_get_work_buffer_size failed"); if (workbuffersize > 0) { - plan->m_buffer = BufferViewType("work_buffer", workbuffersize); - status = rocfft_execution_info_set_work_buffer( + plan->allocate_work_buffer(workbuffersize); + status = rocfft_execution_info_set_work_buffer( plan->execution_info(), (void*)plan->m_buffer.data(), workbuffersize); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execution_info_set_work_buffer failed"); diff --git a/fft/src/KokkosFFT_ROCM_types.hpp b/fft/src/KokkosFFT_ROCM_types.hpp index 5b2e023f..07d8b214 100644 --- a/fft/src/KokkosFFT_ROCM_types.hpp +++ b/fft/src/KokkosFFT_ROCM_types.hpp @@ -34,9 +34,9 @@ template using TransformType = FFTWTransformType; /// \brief A class that wraps rocfft for RAII -template +template struct ScopedRocfftPlanType { - using floating_point_type = KokkosFFT::Impl::base_floating_point_type; + using floating_point_type = KokkosFFT::Impl::base_floating_point_type; rocfft_plan m_plan; rocfft_execution_info m_execution_info; @@ -55,6 +55,9 @@ struct ScopedRocfftPlanType { if (m_is_plan_created) rocfft_plan_destroy(m_plan); } + void allocate_work_buffer(std::size_t workbuffersize) { + m_buffer = BufferViewType("work buffer", workbuffersize); + } rocfft_plan &plan() { return m_plan; } rocfft_execution_info &execution_info() { return m_execution_info; } }; @@ -114,7 +117,7 @@ struct FFTDataType { template struct FFTPlanType { using fftw_plan_type = ScopedFFTWPlanType; - using rocfft_plan_type = ScopedRocfftPlanType; + using rocfft_plan_type = ScopedRocfftPlanType; using type = std::conditional_t, rocfft_plan_type, fftw_plan_type>; }; @@ -141,7 +144,7 @@ struct FFTDataType { template struct FFTPlanType { - using type = ScopedRocfftPlanType; + using type = ScopedRocfftPlanType; }; template