Skip to content

Commit

Permalink
Add ScopedExecutionInfo for rocm backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jan 2, 2025
1 parent b405320 commit 28d9891
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 53 deletions.
4 changes: 2 additions & 2 deletions fft/src/KokkosFFT_FFTW_Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct ScopedFFTWPlan {
plan_type plan() const noexcept { return m_plan; }

private:
void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) {
static void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
if constexpr (std::is_same_v<ExecutionSpace,
Kokkos::DefaultHostExecutionSpace>) {
Expand All @@ -138,7 +138,7 @@ struct ScopedFFTWPlan {
#endif
}

void cleanup_threads() {
static void cleanup_threads() {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
if constexpr (std::is_same_v<ExecutionSpace,
Kokkos::DefaultHostExecutionSpace>) {
Expand Down
115 changes: 64 additions & 51 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,67 @@ struct ScopedRocfftPlanDescription {
rocfft_plan_description description() const noexcept { return m_description; }
};

/// \brief A class that wraps rocfft_execution_info for RAII
template <typename FloatingPointType>
struct ScopedRocfftExecutionInfo {
private:
using BufferViewType =
Kokkos::View<Kokkos::complex<FloatingPointType> *, Kokkos::HIP>;
rocfft_execution_info m_execution_info;

//! Internal work buffer
BufferViewType m_buffer;

public:
ScopedRocfftExecutionInfo() {
// Prepare workbuffer and set execution information
rocfft_status status = rocfft_execution_info_create(&m_execution_info);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_create failed");
}
~ScopedRocfftExecutionInfo() noexcept {
rocfft_status status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
}

rocfft_execution_info execution_info() const noexcept {
return m_execution_info;
}

void setup(const Kokkos::HIP &exec_space, std::size_t workbuffersize) {
// set stream
// NOTE: The stream must be of type hipStream_t.
// It is an error to pass the address of a hipStream_t object.
hipStream_t stream = exec_space.hip_stream();
rocfft_status status =
rocfft_execution_info_set_stream(m_execution_info, stream);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_stream failed");

// Set work buffer
if (workbuffersize > 0) {
m_buffer = BufferViewType("workbuffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
m_execution_info, (void *)m_buffer.data(), workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
}
};

/// \brief A class that wraps rocfft for RAII
template <typename T>
struct ScopedRocfftPlan {
private:
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
using BufferViewType =
Kokkos::View<Kokkos::complex<floating_point_type> *, Kokkos::HIP>;

using ScopedRocfftExecutionInfoType =
ScopedRocfftExecutionInfo<floating_point_type>;
rocfft_precision m_precision = std::is_same_v<floating_point_type, float>
? rocfft_precision_single
: rocfft_precision_double;
rocfft_plan m_plan;
rocfft_execution_info m_execution_info;

//! Internal work buffer
BufferViewType m_buffer;
std::unique_ptr<ScopedRocfftExecutionInfoType> m_execution_info;

public:
ScopedRocfftPlan(const FFTWTransformType transform_type,
Expand Down Expand Up @@ -128,7 +173,11 @@ struct ScopedRocfftPlan {
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_create failed");
}
~ScopedRocfftPlan() noexcept { cleanup(); }
~ScopedRocfftPlan() noexcept {
rocfft_status status = rocfft_plan_destroy(m_plan);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_plan_destroy failed");
}

ScopedRocfftPlan() = delete;
ScopedRocfftPlan(const ScopedRocfftPlan &) = delete;
Expand All @@ -138,53 +187,17 @@ struct ScopedRocfftPlan {

rocfft_plan plan() const noexcept { return m_plan; }
rocfft_execution_info execution_info() const noexcept {
return m_execution_info;
return m_execution_info->execution_info();
}

void commit(const Kokkos::HIP &exec_space) {
try {
// Prepare workbuffer and set execution information
rocfft_status status = rocfft_execution_info_create(&m_execution_info);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_create failed");

// set stream
// NOTE: The stream must be of type hipStream_t.
// It is an error to pass the address of a hipStream_t object.
hipStream_t stream = exec_space.hip_stream();
status = rocfft_execution_info_set_stream(m_execution_info, stream);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_stream failed");

// Set work buffer
std::size_t workbuffersize = 0;
status = rocfft_plan_get_work_buffer_size(m_plan, &workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_get_work_buffer_size failed");

if (workbuffersize > 0) {
m_buffer = BufferViewType("workbuffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
m_execution_info, (void *)m_buffer.data(), workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
} catch (const std::runtime_error &e) {
std::cerr << e.what() << std::endl;
cleanup();
throw;
}
}

private:
void cleanup() {
rocfft_status status = rocfft_plan_destroy(m_plan);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_plan_destroy failed");
std::size_t workbuffersize = 0;
status = rocfft_plan_get_work_buffer_size(m_plan, &workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_get_work_buffer_size failed");

status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
m_execution_info = std::make_unique<ScopedRocfftExecutionInfo>();
m_execution_info->setup(exec_space, workbuffersize);
}

// Helper to get input and output array type and direction from transform type
Expand Down

0 comments on commit 28d9891

Please sign in to comment.