Skip to content

Commit

Permalink
fix: buffer size for rocfft execution info (kokkos#219)
Browse files Browse the repository at this point in the history
* fix: buffer size for rocfft execution info

* use kokkos_malloc instead of hipMalloc

* fix annotation for kokkos_malloc

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jan 8, 2025
1 parent b4b272a commit 7c3aff8
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,12 @@ struct ScopedRocfftPlanDescription {
};

/// \brief A class that wraps rocfft_execution_info for RAII
template <typename FloatingPointType>
struct ScopedRocfftExecutionInfo {
private:
using BufferViewType =
Kokkos::View<Kokkos::complex<FloatingPointType> *, Kokkos::HIP>;
rocfft_execution_info m_execution_info;

//! Internal work buffer
BufferViewType m_buffer;
void *m_workbuffer = nullptr;

public:
ScopedRocfftExecutionInfo() {
Expand All @@ -84,6 +81,9 @@ struct ScopedRocfftExecutionInfo {
"rocfft_execution_info_create failed");
}
~ScopedRocfftExecutionInfo() noexcept {
if (m_workbuffer != nullptr) {
Kokkos::kokkos_free<Kokkos::HIP>(m_workbuffer);
}
rocfft_status status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
Expand Down Expand Up @@ -111,9 +111,11 @@ struct ScopedRocfftExecutionInfo {

// Set work buffer
if (workbuffersize > 0) {
m_buffer = BufferViewType("workbuffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
m_execution_info, (void *)m_buffer.data(), workbuffersize);
m_workbuffer =
Kokkos::kokkos_malloc<Kokkos::HIP>("workbuffer", workbuffersize);

status = rocfft_execution_info_set_work_buffer(
m_execution_info, m_workbuffer, workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
Expand All @@ -124,14 +126,12 @@ struct ScopedRocfftExecutionInfo {
template <typename T>
struct ScopedRocfftPlan {
private:
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
using ScopedRocfftExecutionInfoType =
ScopedRocfftExecutionInfo<floating_point_type>;
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
rocfft_precision m_precision = std::is_same_v<floating_point_type, float>
? rocfft_precision_single
: rocfft_precision_double;
rocfft_plan m_plan;
std::unique_ptr<ScopedRocfftExecutionInfoType> m_execution_info;
std::unique_ptr<ScopedRocfftExecutionInfo> m_execution_info;

public:
ScopedRocfftPlan(const FFTWTransformType transform_type,
Expand Down Expand Up @@ -209,7 +209,7 @@ struct ScopedRocfftPlan {
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_get_work_buffer_size failed");

m_execution_info = std::make_unique<ScopedRocfftExecutionInfoType>();
m_execution_info = std::make_unique<ScopedRocfftExecutionInfo>();
m_execution_info->setup(exec_space, workbuffersize);
}

Expand Down

0 comments on commit 7c3aff8

Please sign in to comment.