Skip to content

Commit

Permalink
fix: buffer size for rocfft execution info
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jan 8, 2025
1 parent b4b272a commit a23a3db
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,12 @@ struct ScopedRocfftPlanDescription {
};

/// \brief A class that wraps rocfft_execution_info for RAII
template <typename FloatingPointType>
struct ScopedRocfftExecutionInfo {
private:
using BufferViewType =
Kokkos::View<Kokkos::complex<FloatingPointType> *, Kokkos::HIP>;
rocfft_execution_info m_execution_info;

//! Internal work buffer
BufferViewType m_buffer;
void *m_workbuffer = nullptr;

public:
ScopedRocfftExecutionInfo() {
Expand All @@ -84,6 +81,10 @@ struct ScopedRocfftExecutionInfo {
"rocfft_execution_info_create failed");
}
~ScopedRocfftExecutionInfo() noexcept {
if (m_workbuffer != nullptr) {
hipError_t hip_status = hipFree(m_workbuffer);
if (hip_status != hipSuccess) Kokkos::abort("hipFree failed");
}
rocfft_status status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
Expand Down Expand Up @@ -111,9 +112,10 @@ struct ScopedRocfftExecutionInfo {

// Set work buffer
if (workbuffersize > 0) {
m_buffer = BufferViewType("workbuffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
m_execution_info, (void *)m_buffer.data(), workbuffersize);
hipError_t hip_status = hipMalloc(&m_workbuffer, workbuffersize);
KOKKOSFFT_THROW_IF(hip_status != hipSuccess, "hipMalloc failed");
status = rocfft_execution_info_set_work_buffer(
m_execution_info, m_workbuffer, workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
Expand All @@ -124,14 +126,12 @@ struct ScopedRocfftExecutionInfo {
template <typename T>
struct ScopedRocfftPlan {
private:
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
using ScopedRocfftExecutionInfoType =
ScopedRocfftExecutionInfo<floating_point_type>;
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
rocfft_precision m_precision = std::is_same_v<floating_point_type, float>
? rocfft_precision_single
: rocfft_precision_double;
rocfft_plan m_plan;
std::unique_ptr<ScopedRocfftExecutionInfoType> m_execution_info;
std::unique_ptr<ScopedRocfftExecutionInfo> m_execution_info;

public:
ScopedRocfftPlan(const FFTWTransformType transform_type,
Expand Down Expand Up @@ -209,7 +209,7 @@ struct ScopedRocfftPlan {
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_get_work_buffer_size failed");

m_execution_info = std::make_unique<ScopedRocfftExecutionInfoType>();
m_execution_info = std::make_unique<ScopedRocfftExecutionInfo>();
m_execution_info->setup(exec_space, workbuffersize);
}

Expand Down

0 comments on commit a23a3db

Please sign in to comment.