From e7ed6aa6621e8b3c011cd636ff58e93645b19dd0 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 3 Dec 2024 23:33:56 +0900 Subject: [PATCH] Wrapper for rocfft handle --- fft/src/KokkosFFT_ROCM_plans.hpp | 50 ++++++++++------------- fft/src/KokkosFFT_ROCM_transform.hpp | 55 +++++++++++++------------ fft/src/KokkosFFT_ROCM_types.hpp | 60 +++++++++++++++++----------- 3 files changed, 88 insertions(+), 77 deletions(-) diff --git a/fft/src/KokkosFFT_ROCM_plans.hpp b/fft/src/KokkosFFT_ROCM_plans.hpp index e1b115e9..47781097 100644 --- a/fft/src/KokkosFFT_ROCM_plans.hpp +++ b/fft/src/KokkosFFT_ROCM_plans.hpp @@ -84,16 +84,15 @@ auto compute_strides(const std::vector& extents) // batched transform, over ND Views template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, const OutViewType& out, BufferViewType& buffer, - InfoType& execution_info, Direction direction, - axis_type axes, shape_type s, - bool is_inplace) { + Direction direction, axis_type axes, + shape_type s, bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -129,8 +128,13 @@ auto create_plan(const ExecutionSpace& exec_space, convert_int_type_and_reverse(fft_extents); // Create the description - rocfft_plan_description description; - rocfft_status status = rocfft_plan_description_create(&description); + std::unique_ptr> const + description(new rocfft_plan_description, + [](rocfft_plan_description* desc) { + rocfft_plan_description_destroy(*desc); + }); + rocfft_status status = rocfft_plan_description_create(&(*description)); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_plan_description_create failed"); @@ -139,7 +143,7 @@ auto create_plan(const ExecutionSpace& exec_space, rocfft_precision precision = get_in_out_array_type(); status = rocfft_plan_description_set_data_layout( - description, // description handle + *description, // description handle in_array_type, // input array type out_array_type, // output array type nullptr, // offsets to start of input data @@ -159,56 +163,46 @@ auto create_plan(const ExecutionSpace& exec_space, // Create a plan plan = std::make_unique(); - status = rocfft_plan_create(&(*plan), place, fft_direction, precision, + status = rocfft_plan_create(&(plan->plan()), place, fft_direction, precision, reversed_fft_extents.size(), // Dimension reversed_fft_extents.data(), // Lengths - howmany, // Number of transforms - description // Description + howmany, // Number of transforms + *description // Description ); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_plan_create failed"); + plan->m_is_plan_created = true; // Prepare workbuffer and set execution information - status = rocfft_execution_info_create(&execution_info); + status = rocfft_execution_info_create(&(plan->execution_info())); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execution_info_create failed"); + plan->m_is_info_created = true; // set stream // NOTE: The stream must be of type hipStream_t. // It is an error to pass the address of a hipStream_t object. hipStream_t stream = exec_space.hip_stream(); - status = rocfft_execution_info_set_stream(execution_info, stream); + status = rocfft_execution_info_set_stream(plan->execution_info(), stream); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execution_info_set_stream failed"); std::size_t workbuffersize = 0; - status = rocfft_plan_get_work_buffer_size(*plan, &workbuffersize); + status = rocfft_plan_get_work_buffer_size(plan->plan(), &workbuffersize); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_plan_get_work_buffer_size failed"); if (workbuffersize > 0) { - buffer = BufferViewType("work_buffer", workbuffersize); - status = rocfft_execution_info_set_work_buffer( - execution_info, (void*)buffer.data(), workbuffersize); + plan->m_buffer = BufferViewType("work_buffer", workbuffersize); + status = rocfft_execution_info_set_work_buffer( + plan->execution_info(), (void*)plan->m_buffer.data(), workbuffersize); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execution_info_set_work_buffer failed"); } - status = rocfft_plan_description_destroy(description); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_description_destroy failed"); - return fft_size; } -template , - std::nullptr_t> = nullptr> -void destroy_plan_and_info(std::unique_ptr& plan, - InfoType& execution_info) { - rocfft_execution_info_destroy(execution_info); - rocfft_plan_destroy(*plan); -} } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_ROCM_transform.hpp b/fft/src/KokkosFFT_ROCM_transform.hpp index 2c6d50b8..e4d45abd 100644 --- a/fft/src/KokkosFFT_ROCM_transform.hpp +++ b/fft/src/KokkosFFT_ROCM_transform.hpp @@ -11,60 +11,65 @@ namespace KokkosFFT { namespace Impl { -inline void exec_plan(rocfft_plan& plan, float* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +template +void exec_plan(ScopedPlanType& scoped_plan, float* idata, + std::complex* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for R2C failed"); } -inline void exec_plan(rocfft_plan& plan, double* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +template +void exec_plan(ScopedPlanType& scoped_plan, double* idata, + std::complex* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for D2Z failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - float* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +template +void exec_plan(ScopedPlanType& scoped_plan, std::complex* idata, + float* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for C2R failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - double* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +template +void exec_plan(ScopedPlanType& scoped_plan, std::complex* idata, + double* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for Z2D failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +template +void exec_plan(ScopedPlanType& scoped_plan, std::complex* idata, + std::complex* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for C2C failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +template +void exec_plan(ScopedPlanType& scoped_plan, std::complex* idata, + std::complex* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for Z2Z failed"); } - } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_ROCM_types.hpp b/fft/src/KokkosFFT_ROCM_types.hpp index 60af7e57..5b2e023f 100644 --- a/fft/src/KokkosFFT_ROCM_types.hpp +++ b/fft/src/KokkosFFT_ROCM_types.hpp @@ -8,6 +8,9 @@ #include #include #include "KokkosFFT_common_types.hpp" +#if defined(ENABLE_HOST_AND_DEVICE) +#include "KokkosFFT_FFTW_Types.hpp" +#endif // Check the size of complex type static_assert(sizeof(std::complex) == sizeof(Kokkos::complex)); @@ -17,27 +20,45 @@ static_assert(sizeof(std::complex) == sizeof(Kokkos::complex)); static_assert(alignof(std::complex) <= alignof(Kokkos::complex)); -#ifdef ENABLE_HOST_AND_DEVICE -#include -#include "KokkosFFT_utils.hpp" -static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex)); - -static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex)); -#endif - namespace KokkosFFT { namespace Impl { using FFTDirectionType = int; constexpr FFTDirectionType ROCFFT_FORWARD = 1; constexpr FFTDirectionType ROCFFT_BACKWARD = -1; +#if !defined(ENABLE_HOST_AND_DEVICE) enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; +#endif template using TransformType = FFTWTransformType; +/// \brief A class that wraps rocfft for RAII +template +struct ScopedRocfftPlanType { + using floating_point_type = KokkosFFT::Impl::base_floating_point_type; + rocfft_plan m_plan; + rocfft_execution_info m_execution_info; + + using BufferViewType = + Kokkos::View *, ExecutionSpace>; + + bool m_is_info_created = false; + bool m_is_plan_created = false; + + //! Internal work buffer + BufferViewType m_buffer; + + ScopedRocfftPlanType() {} + ~ScopedRocfftPlanType() { + if (m_is_info_created) rocfft_execution_info_destroy(m_execution_info); + if (m_is_plan_created) rocfft_plan_destroy(m_plan); + } + + rocfft_plan &plan() { return m_plan; } + rocfft_execution_info &execution_info() { return m_execution_info; } +}; + // Define fft transform types template struct transform_type { @@ -76,7 +97,7 @@ struct transform_type, static constexpr FFTWTransformType type() { return m_type; }; }; -#ifdef ENABLE_HOST_AND_DEVICE +#if defined(ENABLE_HOST_AND_DEVICE) template struct FFTDataType { @@ -92,18 +113,12 @@ struct FFTDataType { template struct FFTPlanType { - using fftwHandle = std::conditional_t< - std::is_same_v, float>, - fftwf_plan, fftw_plan>; + using fftw_plan_type = ScopedFFTWPlanType; + using rocfft_plan_type = ScopedRocfftPlanType; using type = std::conditional_t, - rocfft_plan, fftwHandle>; + rocfft_plan_type, fftw_plan_type>; }; -template -using FFTInfoType = - std::conditional_t, - rocfft_execution_info, int>; - template auto direction_type(Direction direction) { static constexpr FFTDirectionType FORWARD = @@ -126,12 +141,9 @@ struct FFTDataType { template struct FFTPlanType { - using type = rocfft_plan; + using type = ScopedRocfftPlanType; }; -template -using FFTInfoType = rocfft_execution_info; - template auto direction_type(Direction direction) { return direction == Direction::forward ? ROCFFT_FORWARD : ROCFFT_BACKWARD;