Skip to content

Commit

Permalink
update rocm backend based on reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 17, 2024
1 parent 36144b2 commit b87d799
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 168 deletions.
153 changes: 3 additions & 150 deletions fft/src/KokkosFFT_ROCM_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,72 +15,6 @@

namespace KokkosFFT {
namespace Impl {
// Helper to get input and output array type and direction from transform type
template <typename TransformType>
auto get_in_out_array_type(TransformType type, Direction direction) {
rocfft_array_type in_array_type, out_array_type;
rocfft_transform_type fft_direction;

if (type == FFTWTransformType::C2C || type == FFTWTransformType::Z2Z) {
in_array_type = rocfft_array_type_complex_interleaved;
out_array_type = rocfft_array_type_complex_interleaved;
fft_direction = direction == Direction::forward
? rocfft_transform_type_complex_forward
: rocfft_transform_type_complex_inverse;
} else if (type == FFTWTransformType::R2C || type == FFTWTransformType::D2Z) {
in_array_type = rocfft_array_type_real;
out_array_type = rocfft_array_type_hermitian_interleaved;
fft_direction = rocfft_transform_type_real_forward;
} else if (type == FFTWTransformType::C2R || type == FFTWTransformType::Z2D) {
in_array_type = rocfft_array_type_hermitian_interleaved;
out_array_type = rocfft_array_type_real;
fft_direction = rocfft_transform_type_real_inverse;
}

return std::tuple<rocfft_array_type, rocfft_array_type,
rocfft_transform_type>(
{in_array_type, out_array_type, fft_direction});
};

template <typename ValueType>
rocfft_precision get_in_out_array_type() {
return std::is_same_v<KokkosFFT::Impl::base_floating_point_type<ValueType>,
float>
? rocfft_precision_single
: rocfft_precision_double;
}

// Helper to convert the integer type of vectors
template <typename InType, typename OutType>
auto convert_int_type_and_reverse(std::vector<InType>& in)
-> std::vector<OutType> {
std::vector<OutType> out(in.size());
std::transform(
in.begin(), in.end(), out.begin(),
[](const InType v) -> OutType { return static_cast<OutType>(v); });

std::reverse(out.begin(), out.end());
return out;
}

// Helper to compute strides from extents
// (n0, n1, n2) -> (1, n0, n0*n1)
// (n0, n1) -> (1, n0)
// (n0) -> (1)
template <typename InType, typename OutType>
auto compute_strides(const std::vector<InType>& extents)
-> std::vector<OutType> {
std::vector<OutType> out = {1};
auto reversed_extents = extents;
std::reverse(reversed_extents.begin(), reversed_extents.end());

for (std::size_t i = 1; i < reversed_extents.size(); i++) {
out.push_back(static_cast<OutType>(reversed_extents.at(i - 1)) *
out.at(i - 1));
}

return out;
}

// batched transform, over ND Views
template <typename ExecutionSpace, typename PlanType, typename InViewType,
Expand Down Expand Up @@ -112,92 +46,11 @@ auto create_plan(const ExecutionSpace& exec_space,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
std::multiplies<>());
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

// For the moment, considering the contiguous layout only
// Create plan
auto in_strides = compute_strides<int, std::size_t>(in_extents);
auto out_strides = compute_strides<int, std::size_t>(out_extents);
auto reversed_fft_extents =
convert_int_type_and_reverse<int, std::size_t>(fft_extents);

// Create the description
std::unique_ptr<rocfft_plan_description,
std::function<void(rocfft_plan_description*)>> const
description(new rocfft_plan_description,
[](rocfft_plan_description* desc) {
rocfft_plan_description_destroy(*desc);
});
rocfft_status status = rocfft_plan_description_create(&(*description));
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_description_create failed");

auto [in_array_type, out_array_type, fft_direction] =
get_in_out_array_type(type, direction);
rocfft_precision precision = get_in_out_array_type<in_value_type>();

status = rocfft_plan_description_set_data_layout(
*description, // description handle
in_array_type, // input array type
out_array_type, // output array type
nullptr, // offsets to start of input data
nullptr, // offsets to start of output data
in_strides.size(), // input stride length
in_strides.data(), // input stride data
idist, // input batch distance
out_strides.size(), // output stride length
out_strides.data(), // output stride data
odist); // output batch distance
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_description_set_data_layout failed");

// Out-of-place transform
const rocfft_result_placement place =
is_inplace ? rocfft_placement_inplace : rocfft_placement_notinplace;

// Create a plan
plan = std::make_unique<PlanType>();
status = rocfft_plan_create(&(plan->plan()), place, fft_direction, precision,
reversed_fft_extents.size(), // Dimension
reversed_fft_extents.data(), // Lengths
howmany, // Number of transforms
*description // Description
);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_create failed");
plan->set_is_plan_created();

// Prepare workbuffer and set execution information
status = rocfft_execution_info_create(&(plan->execution_info()));
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_create failed");
plan->set_is_info_created();

// set stream
// NOTE: The stream must be of type hipStream_t.
// It is an error to pass the address of a hipStream_t object.
hipStream_t stream = exec_space.hip_stream();
status = rocfft_execution_info_set_stream(plan->execution_info(), stream);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_stream failed");

std::size_t workbuffersize = 0;
status = rocfft_plan_get_work_buffer_size(plan->plan(), &workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_get_work_buffer_size failed");

if (workbuffersize > 0) {
plan->allocate_work_buffer(workbuffersize);
status = rocfft_execution_info_set_work_buffer(
plan->execution_info(), (void*)plan->buffer_data(), workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
plan =
std::make_unique<PlanType>(exec_space, type, in_extents, out_extents,
fft_extents, howmany, direction, is_inplace);

return fft_size;
}
Expand Down
12 changes: 6 additions & 6 deletions fft/src/KokkosFFT_ROCM_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace KokkosFFT {
namespace Impl {
template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, float* idata,
void exec_plan(const ScopedPlanType& scoped_plan, float* idata,
std::complex<float>* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -22,7 +22,7 @@ void exec_plan(ScopedPlanType& scoped_plan, float* idata,
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, double* idata,
void exec_plan(const ScopedPlanType& scoped_plan, double* idata,
std::complex<double>* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -32,7 +32,7 @@ void exec_plan(ScopedPlanType& scoped_plan, double* idata,
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
void exec_plan(const ScopedPlanType& scoped_plan, std::complex<float>* idata,
float* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -42,7 +42,7 @@ void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<double>* idata,
void exec_plan(const ScopedPlanType& scoped_plan, std::complex<double>* idata,
double* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -52,7 +52,7 @@ void exec_plan(ScopedPlanType& scoped_plan, std::complex<double>* idata,
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
void exec_plan(const ScopedPlanType& scoped_plan, std::complex<float>* idata,
std::complex<float>* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand All @@ -62,7 +62,7 @@ void exec_plan(ScopedPlanType& scoped_plan, std::complex<float>* idata,
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, std::complex<double>* idata,
void exec_plan(const ScopedPlanType& scoped_plan, std::complex<double>* idata,
std::complex<double>* odata, int /*direction*/) {
rocfft_status status =
rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata,
Expand Down
Loading

0 comments on commit b87d799

Please sign in to comment.