diff --git a/fft/src/KokkosFFT_ROCM_types.hpp b/fft/src/KokkosFFT_ROCM_types.hpp index ed8b06a8..05856d68 100644 --- a/fft/src/KokkosFFT_ROCM_types.hpp +++ b/fft/src/KokkosFFT_ROCM_types.hpp @@ -66,15 +66,12 @@ struct ScopedRocfftPlanDescription { }; /// \brief A class that wraps rocfft_execution_info for RAII -template struct ScopedRocfftExecutionInfo { private: - using BufferViewType = - Kokkos::View *, Kokkos::HIP>; rocfft_execution_info m_execution_info; //! Internal work buffer - BufferViewType m_buffer; + void *m_workbuffer = nullptr; public: ScopedRocfftExecutionInfo() { @@ -84,6 +81,10 @@ struct ScopedRocfftExecutionInfo { "rocfft_execution_info_create failed"); } ~ScopedRocfftExecutionInfo() noexcept { + if (m_workbuffer != nullptr) { + hipError_t hip_status = hipFree(m_workbuffer); + if (hip_status != hipSuccess) Kokkos::abort("hipFree failed"); + } rocfft_status status = rocfft_execution_info_destroy(m_execution_info); if (status != rocfft_status_success) Kokkos::abort("rocfft_execution_info_destroy failed"); @@ -111,9 +112,10 @@ struct ScopedRocfftExecutionInfo { // Set work buffer if (workbuffersize > 0) { - m_buffer = BufferViewType("workbuffer", workbuffersize); - status = rocfft_execution_info_set_work_buffer( - m_execution_info, (void *)m_buffer.data(), workbuffersize); + hipError_t hip_status = hipMalloc(&m_workbuffer, workbuffersize); + KOKKOSFFT_THROW_IF(hip_status != hipSuccess, "hipMalloc failed"); + status = rocfft_execution_info_set_work_buffer( + m_execution_info, m_workbuffer, workbuffersize); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execution_info_set_work_buffer failed"); } @@ -124,14 +126,12 @@ struct ScopedRocfftExecutionInfo { template struct ScopedRocfftPlan { private: - using floating_point_type = KokkosFFT::Impl::base_floating_point_type; - using ScopedRocfftExecutionInfoType = - ScopedRocfftExecutionInfo; + using floating_point_type = KokkosFFT::Impl::base_floating_point_type; rocfft_precision m_precision = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; rocfft_plan m_plan; - std::unique_ptr m_execution_info; + std::unique_ptr m_execution_info; public: ScopedRocfftPlan(const FFTWTransformType transform_type, @@ -209,7 +209,7 @@ struct ScopedRocfftPlan { KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_plan_get_work_buffer_size failed"); - m_execution_info = std::make_unique(); + m_execution_info = std::make_unique(); m_execution_info->setup(exec_space, workbuffersize); }