Skip to content

Commit

Permalink
Improve the cleanup logic for rocfft plan
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 19, 2024
1 parent a350598 commit 27501ac
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 69 deletions.
25 changes: 10 additions & 15 deletions fft/src/KokkosFFT_ROCM_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

#include <complex>
#include <rocfft/rocfft.h>
#include "KokkosFFT_asserts.hpp"
#include "KokkosFFT_ROCM_types.hpp"

namespace KokkosFFT {
namespace Impl {
template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, float* idata,
void exec_plan(ScopedRocfftPlan<float>& scoped_plan, float* idata,
std::complex<float>* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -21,8 +20,7 @@ void exec_plan(ScopedPlanType& scoped_plan, float* idata,
"rocfft_execute for R2C failed");
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, double* idata,
void exec_plan(ScopedRocfftPlan<double>& scoped_plan, double* idata,
std::complex<double>* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -31,8 +29,7 @@ void exec_plan(ScopedPlanType& scoped_plan, double* idata,
"rocfft_execute for D2Z failed");
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
void exec_plan(ScopedRocfftPlan<float>& scoped_plan, std::complex<float>* idata,
float* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -41,18 +38,16 @@ void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
"rocfft_execute for C2R failed");
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<double>* idata,
double* odata, int /*direction*/) {
void exec_plan(ScopedRocfftPlan<double>& scoped_plan,
std::complex<double>* idata, double* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
scoped_plan.execution_info());
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execute for Z2D failed");
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
void exec_plan(ScopedRocfftPlan<float>& scoped_plan, std::complex<float>* idata,
std::complex<float>* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -61,9 +56,9 @@ void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
"rocfft_execute for C2C failed");
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<double>* idata,
std::complex<double>* odata, int /*direction*/) {
void exec_plan(ScopedRocfftPlan<double>& scoped_plan,
std::complex<double>* idata, std::complex<double>* odata,
int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
scoped_plan.execution_info());
Expand Down
108 changes: 54 additions & 54 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <numeric>
#include <algorithm>
#include <complex>
#include <iostream>
#include <rocfft/rocfft.h>
#include <Kokkos_Abort.hpp>
#include "KokkosFFT_common_types.hpp"
Expand Down Expand Up @@ -72,9 +73,6 @@ struct ScopedRocfftPlan {
rocfft_plan m_plan;
rocfft_execution_info m_execution_info;

bool m_is_plan_created = false;
bool m_is_info_created = false;

//! Internal work buffer
BufferViewType m_buffer;

Expand Down Expand Up @@ -117,63 +115,55 @@ struct ScopedRocfftPlan {
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_description_set_data_layout failed");

// inplace or Out-of-place transform
const rocfft_result_placement place =
is_inplace ? rocfft_placement_inplace : rocfft_placement_notinplace;

// Create a plan
status = rocfft_plan_create(&m_plan, place, fft_direction, m_precision,
reversed_fft_extents.size(), // Dimension
reversed_fft_extents.data(), // Lengths
howmany, // Number of transforms
scoped_description.description() // Description
);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_create failed");

m_is_plan_created = true;

// Prepare workbuffer and set execution information
status = rocfft_execution_info_create(&m_execution_info);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_create failed");

m_is_info_created = true;
try {
// inplace or Out-of-place transform
const rocfft_result_placement place =
is_inplace ? rocfft_placement_inplace : rocfft_placement_notinplace;

// Create a plan
status =
rocfft_plan_create(&m_plan, place, fft_direction, m_precision,
reversed_fft_extents.size(), // Dimension
reversed_fft_extents.data(), // Lengths
howmany, // Number of transforms
scoped_description.description() // Description
);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_create failed");

// set stream
// NOTE: The stream must be of type hipStream_t.
// It is an error to pass the address of a hipStream_t object.
hipStream_t stream = exec_space.hip_stream();
status = rocfft_execution_info_set_stream(m_execution_info, stream);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_stream failed");
// Prepare workbuffer and set execution information
status = rocfft_execution_info_create(&m_execution_info);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_create failed");

// Set work buffer
std::size_t workbuffersize = 0;
status = rocfft_plan_get_work_buffer_size(m_plan, &workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_get_work_buffer_size failed");
// set stream
// NOTE: The stream must be of type hipStream_t.
// It is an error to pass the address of a hipStream_t object.
hipStream_t stream = exec_space.hip_stream();
status = rocfft_execution_info_set_stream(m_execution_info, stream);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_stream failed");

if (workbuffersize > 0) {
m_buffer = BufferViewType("workbuffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
m_execution_info, (void *)m_buffer.data(), workbuffersize);
// Set work buffer
std::size_t workbuffersize = 0;
status = rocfft_plan_get_work_buffer_size(m_plan, &workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
}
~ScopedRocfftPlan() noexcept {
if (m_is_info_created) {
rocfft_status status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
}
if (m_is_plan_created) {
rocfft_status status = rocfft_plan_destroy(m_plan);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_plan_destroy failed");
"rocfft_plan_get_work_buffer_size failed");

if (workbuffersize > 0) {
m_buffer = BufferViewType("workbuffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
m_execution_info, (void *)m_buffer.data(), workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
} catch (const std::runtime_error &e) {
std::cerr << e.what() << std::endl;
cleanup();
throw;
}
}
~ScopedRocfftPlan() noexcept { cleanup(); }

ScopedRocfftPlan() = delete;
ScopedRocfftPlan(const ScopedRocfftPlan &) = delete;
Expand All @@ -185,6 +175,16 @@ struct ScopedRocfftPlan {
rocfft_execution_info &execution_info() { return m_execution_info; }

private:
void cleanup() {
rocfft_status status = rocfft_plan_destroy(m_plan);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_plan_destroy failed");

rocfft_status status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
}

// Helper to get input and output array type and direction from transform type
auto get_in_out_array_type(FFTWTransformType type, Direction direction) {
rocfft_array_type in_array_type, out_array_type;
Expand Down

0 comments on commit 27501ac

Please sign in to comment.