From 0c6b33bb953f43cc1a781cd4ea35f21494b7847d Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Sat, 21 Dec 2024 00:08:17 +0900 Subject: [PATCH] Add commit method to scoped rocfft plan --- fft/src/KokkosFFT_ROCM_plans.hpp | 6 +-- fft/src/KokkosFFT_ROCM_transform.hpp | 1 + fft/src/KokkosFFT_ROCM_types.hpp | 58 ++++++++++++++-------------- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/fft/src/KokkosFFT_ROCM_plans.hpp b/fft/src/KokkosFFT_ROCM_plans.hpp index e622a3c7..56c1d6c3 100644 --- a/fft/src/KokkosFFT_ROCM_plans.hpp +++ b/fft/src/KokkosFFT_ROCM_plans.hpp @@ -46,9 +46,9 @@ auto create_plan(const ExecutionSpace& exec_space, KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace); // Create a plan - plan = - std::make_unique(exec_space, type, in_extents, out_extents, - fft_extents, howmany, direction, is_inplace); + plan = std::make_unique(type, in_extents, out_extents, fft_extents, + howmany, direction, is_inplace); + plan->commit(exec_space); // Calculate the total size of the FFT int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, diff --git a/fft/src/KokkosFFT_ROCM_transform.hpp b/fft/src/KokkosFFT_ROCM_transform.hpp index fc4832a4..04d1ead8 100644 --- a/fft/src/KokkosFFT_ROCM_transform.hpp +++ b/fft/src/KokkosFFT_ROCM_transform.hpp @@ -7,6 +7,7 @@ #include #include +#include "KokkosFFT_asserts.hpp" #include "KokkosFFT_ROCM_types.hpp" namespace KokkosFFT { diff --git a/fft/src/KokkosFFT_ROCM_types.hpp b/fft/src/KokkosFFT_ROCM_types.hpp index e3907e86..41772933 100644 --- a/fft/src/KokkosFFT_ROCM_types.hpp +++ b/fft/src/KokkosFFT_ROCM_types.hpp @@ -77,8 +77,7 @@ struct ScopedRocfftPlan { BufferViewType m_buffer; public: - ScopedRocfftPlan(const Kokkos::HIP &exec_space, - const FFTWTransformType transform_type, + ScopedRocfftPlan(const FFTWTransformType transform_type, const std::vector &in_extents, const std::vector &out_extents, const std::vector &fft_extents, int howmany, @@ -115,22 +114,35 @@ struct ScopedRocfftPlan { KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_plan_description_set_data_layout failed"); - try { - // inplace or Out-of-place transform - const rocfft_result_placement place = - is_inplace ? rocfft_placement_inplace : rocfft_placement_notinplace; - - // Create a plan - status = - rocfft_plan_create(&m_plan, place, fft_direction, m_precision, - reversed_fft_extents.size(), // Dimension - reversed_fft_extents.data(), // Lengths - howmany, // Number of transforms - scoped_description.description() // Description - ); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_create failed"); + // inplace or Out-of-place transform + const rocfft_result_placement place = + is_inplace ? rocfft_placement_inplace : rocfft_placement_notinplace; + + // Create a plan + status = rocfft_plan_create(&m_plan, place, fft_direction, m_precision, + reversed_fft_extents.size(), // Dimension + reversed_fft_extents.data(), // Lengths + howmany, // Number of transforms + scoped_description.description() // Description + ); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_plan_create failed"); + } + ~ScopedRocfftPlan() noexcept { cleanup(); } + ScopedRocfftPlan() = delete; + ScopedRocfftPlan(const ScopedRocfftPlan &) = delete; + ScopedRocfftPlan &operator=(const ScopedRocfftPlan &) = delete; + ScopedRocfftPlan &operator=(ScopedRocfftPlan &&) = delete; + ScopedRocfftPlan(ScopedRocfftPlan &&) = delete; + + rocfft_plan plan() const noexcept { return m_plan; } + rocfft_execution_info execution_info() const noexcept { + return m_execution_info; + } + + void commit(const Kokkos::HIP &exec_space) { + try { // Prepare workbuffer and set execution information status = rocfft_execution_info_create(&m_execution_info); KOKKOSFFT_THROW_IF(status != rocfft_status_success, @@ -163,18 +175,6 @@ struct ScopedRocfftPlan { throw; } } - ~ScopedRocfftPlan() noexcept { cleanup(); } - - ScopedRocfftPlan() = delete; - ScopedRocfftPlan(const ScopedRocfftPlan &) = delete; - ScopedRocfftPlan &operator=(const ScopedRocfftPlan &) = delete; - ScopedRocfftPlan &operator=(ScopedRocfftPlan &&) = delete; - ScopedRocfftPlan(ScopedRocfftPlan &&) = delete; - - rocfft_plan plan() const noexcept { return m_plan; } - rocfft_execution_info execution_info() const noexcept { - return m_execution_info; - } private: void cleanup() {