From 1d33d7a5b88263d5ddf362267e3e9cc5e5762cbf Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Wed, 18 Dec 2024 08:59:28 +0900 Subject: [PATCH] fix: rocm types --- fft/src/KokkosFFT_ROCM_types.hpp | 58 +++++++++++++++++++------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/fft/src/KokkosFFT_ROCM_types.hpp b/fft/src/KokkosFFT_ROCM_types.hpp index 8e842ab3..ccb23a7f 100644 --- a/fft/src/KokkosFFT_ROCM_types.hpp +++ b/fft/src/KokkosFFT_ROCM_types.hpp @@ -38,6 +38,26 @@ enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; template using TransformType = FFTWTransformType; +/// \brief A class that wraps rocfft_plan_description for RAII +struct ScopedRocfftPlanDescription { + private: + rocfft_plan_description m_description; + + public: + ScopedRocfftPlanDescription() { + rocfft_status status = rocfft_plan_description_create(&m_description); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_plan_description_create failed"); + } + ~ScopedRocfftPlanDescription() noexcept { + rocfft_status status = rocfft_plan_description_destroy(m_description); + if (status != rocfft_status_success) + Kokkos::abort("rocfft_plan_description_destroy failed"); + } + + rocfft_plan_description &description() { return m_description; } +} + /// \brief A class that wraps rocfft for RAII template struct ScopedRocfftPlan { @@ -65,17 +85,7 @@ struct ScopedRocfftPlan { const std::vector &out_extents, const std::vector &fft_extents, int howmany, Direction direction, bool is_inplace) { - std::unique_ptr> const - description(new rocfft_plan_description, - [](rocfft_plan_description *desc) { - if (desc) rocfft_plan_description_destroy(*desc); - desc = nullptr; - }); - - rocfft_status status = rocfft_plan_description_create(&(*description)); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_description_create failed"); + ScopedRocfftPlanDescription scoped_description; auto [in_array_type, out_array_type, fft_direction] = get_in_out_array_type(transform_type, direction); @@ -92,17 +102,17 @@ struct ScopedRocfftPlan { convert_int_type_and_reverse(fft_extents); status = rocfft_plan_description_set_data_layout( - *description, // description handle - in_array_type, // input array type - out_array_type, // output array type - nullptr, // offsets to start of input data - nullptr, // offsets to start of output data - in_strides.size(), // input stride length - in_strides.data(), // input stride data - idist, // input batch distance - out_strides.size(), // output stride length - out_strides.data(), // output stride data - odist); // output batch distance + scoped_description.description(), // description handle + in_array_type, // input array type + out_array_type, // output array type + nullptr, // offsets to start of input data + nullptr, // offsets to start of output data + in_strides.size(), // input stride length + in_strides.data(), // input stride data + idist, // input batch distance + out_strides.size(), // output stride length + out_strides.data(), // output stride data + odist); // output batch distance KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_plan_description_set_data_layout failed"); @@ -115,8 +125,8 @@ struct ScopedRocfftPlan { status = rocfft_plan_create(&m_plan, place, fft_direction, m_precision, reversed_fft_extents.size(), // Dimension reversed_fft_extents.data(), // Lengths - howmany, // Number of transforms - *description // Description + howmany, // Number of transforms + scoped_description.description() // Description ); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_plan_create failed");