Skip to content

Commit

Permalink
fix: rocm types
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 17, 2024
1 parent 622ac0e commit 1d33d7a
Showing 1 changed file with 34 additions and 24 deletions.
58 changes: 34 additions & 24 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,26 @@ enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z };
template <typename ExecutionSpace>
using TransformType = FFTWTransformType;

/// \brief A class that wraps rocfft_plan_description for RAII
struct ScopedRocfftPlanDescription {
private:
rocfft_plan_description m_description;

public:
ScopedRocfftPlanDescription() {
rocfft_status status = rocfft_plan_description_create(&m_description);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_description_create failed");
}
~ScopedRocfftPlanDescription() noexcept {
rocfft_status status = rocfft_plan_description_destroy(m_description);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_plan_description_destroy failed");
}

rocfft_plan_description &description() { return m_description; }
}

/// \brief A class that wraps rocfft for RAII
template <typename T>
struct ScopedRocfftPlan {
Expand Down Expand Up @@ -65,17 +85,7 @@ struct ScopedRocfftPlan {
const std::vector<int> &out_extents,
const std::vector<int> &fft_extents, int howmany,
Direction direction, bool is_inplace) {
std::unique_ptr<rocfft_plan_description,
std::function<void(rocfft_plan_description *)>> const
description(new rocfft_plan_description,
[](rocfft_plan_description *desc) {
if (desc) rocfft_plan_description_destroy(*desc);
desc = nullptr;
});

rocfft_status status = rocfft_plan_description_create(&(*description));
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_description_create failed");
ScopedRocfftPlanDescription scoped_description;

auto [in_array_type, out_array_type, fft_direction] =
get_in_out_array_type(transform_type, direction);
Expand All @@ -92,17 +102,17 @@ struct ScopedRocfftPlan {
convert_int_type_and_reverse<int, std::size_t>(fft_extents);

status = rocfft_plan_description_set_data_layout(
*description, // description handle
in_array_type, // input array type
out_array_type, // output array type
nullptr, // offsets to start of input data
nullptr, // offsets to start of output data
in_strides.size(), // input stride length
in_strides.data(), // input stride data
idist, // input batch distance
out_strides.size(), // output stride length
out_strides.data(), // output stride data
odist); // output batch distance
scoped_description.description(), // description handle
in_array_type, // input array type
out_array_type, // output array type
nullptr, // offsets to start of input data
nullptr, // offsets to start of output data
in_strides.size(), // input stride length
in_strides.data(), // input stride data
idist, // input batch distance
out_strides.size(), // output stride length
out_strides.data(), // output stride data
odist); // output batch distance

KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_description_set_data_layout failed");
Expand All @@ -115,8 +125,8 @@ struct ScopedRocfftPlan {
status = rocfft_plan_create(&m_plan, place, fft_direction, m_precision,
reversed_fft_extents.size(), // Dimension
reversed_fft_extents.data(), // Lengths
howmany, // Number of transforms
*description // Description
howmany, // Number of transforms
scoped_description.description() // Description
);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_create failed");
Expand Down

0 comments on commit 1d33d7a

Please sign in to comment.