Skip to content

Commit

Permalink
fix: work buffer allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 3, 2024
1 parent 018f490 commit 457c77e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
5 changes: 2 additions & 3 deletions fft/src/KokkosFFT_Cuda_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ namespace Impl {
using FFTDirectionType = int;

/// \brief A class that wraps cufft for RAII
template <typename ExecutionSpace, typename T1, typename T2>
struct ScopedCufftPlanType {
cufftHandle m_plan;

Expand Down Expand Up @@ -129,7 +128,7 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftw_plan_type = ScopedFFTWPlanType<ExecutionSpace, T1, T2>;
using cufft_plan_type = ScopedCufftPlanType<ExecutionSpace, T1, T2>;
using cufft_plan_type = ScopedCufftPlanType;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
cufft_plan_type, fftw_plan_type>;
};
Expand Down Expand Up @@ -192,7 +191,7 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = ScopedCufftPlanType<ExecutionSpace, T1, T2>;
using type = ScopedCufftPlanType;
};

template <typename ExecutionSpace>
Expand Down
5 changes: 2 additions & 3 deletions fft/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ namespace Impl {
using FFTDirectionType = int;

/// \brief A class that wraps hipfft for RAII
template <typename ExecutionSpace, typename T1, typename T2>
struct ScopedHIPfftPlanType {
hipfftHandle m_plan;

Expand Down Expand Up @@ -129,7 +128,7 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftw_plan_type = ScopedFFTWPlanType<ExecutionSpace, T1, T2>;
using hipfft_plan_type = ScopedHIPfftPlanType<ExecutionSpace, T1, T2>;
using hipfft_plan_type = ScopedHIPfftPlanType;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::HIP>,
hipfft_plan_type, fftw_plan_type>;
};
Expand Down Expand Up @@ -192,7 +191,7 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = ScopedHIPfftPlanType<ExecutionSpace, T1, T2>;
using type = ScopedHIPfftPlanType;
};

template <typename ExecutionSpace>
Expand Down
4 changes: 2 additions & 2 deletions fft/src/KokkosFFT_ROCM_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ auto create_plan(const ExecutionSpace& exec_space,
"rocfft_plan_get_work_buffer_size failed");

if (workbuffersize > 0) {
plan->m_buffer = BufferViewType("work_buffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
plan->allocate_work_buffer(workbuffersize);
status = rocfft_execution_info_set_work_buffer(
plan->execution_info(), (void*)plan->m_buffer.data(), workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
Expand Down
11 changes: 7 additions & 4 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ template <typename ExecutionSpace>
using TransformType = FFTWTransformType;

/// \brief A class that wraps rocfft for RAII
template <typename ExecutionSpace, typename T1, typename T2>
template <typename ExecutionSpace, typename T>
struct ScopedRocfftPlanType {
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T1>;
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
rocfft_plan m_plan;
rocfft_execution_info m_execution_info;

Expand All @@ -55,6 +55,9 @@ struct ScopedRocfftPlanType {
if (m_is_plan_created) rocfft_plan_destroy(m_plan);
}

void allocate_work_buffer(std::size_t workbuffersize) {
m_buffer = BufferViewType("work buffer", workbuffersize);
}
rocfft_plan &plan() { return m_plan; }
rocfft_execution_info &execution_info() { return m_execution_info; }
};
Expand Down Expand Up @@ -114,7 +117,7 @@ struct FFTDataType {
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftw_plan_type = ScopedFFTWPlanType<ExecutionSpace, T1, T2>;
using rocfft_plan_type = ScopedRocfftPlanType<ExecutionSpace, T1, T2>;
using rocfft_plan_type = ScopedRocfftPlanType<ExecutionSpace, T1>;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::HIP>,
rocfft_plan_type, fftw_plan_type>;
};
Expand All @@ -141,7 +144,7 @@ struct FFTDataType {

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = ScopedRocfftPlanType<ExecutionSpace, T1, T2>;
using type = ScopedRocfftPlanType<ExecutionSpace, T1>;
};

template <typename ExecutionSpace>
Expand Down

0 comments on commit 457c77e

Please sign in to comment.