Skip to content

Commit

Permalink
Add commit method to scoped hipfft plan
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 20, 2024
1 parent f181066 commit f7944c8
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 59 deletions.
16 changes: 10 additions & 6 deletions fft/src/KokkosFFT_HIP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ auto create_plan(const ExecutionSpace& exec_space,
const int nx = fft_extents.at(0);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
plan = std::make_unique<PlanType>(exec_space, nx, type, howmany);
plan = std::make_unique<PlanType>(nx, type, howmany);
plan->commit(exec_space);

return fft_size;
}
Expand Down Expand Up @@ -72,7 +73,8 @@ auto create_plan(const ExecutionSpace& exec_space,
const int nx = fft_extents.at(0), ny = fft_extents.at(1);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
plan = std::make_unique<PlanType>(exec_space, nx, ny, type);
plan = std::make_unique<PlanType>(nx, ny, type);
plan->commit(exec_space);

return fft_size;
}
Expand Down Expand Up @@ -106,7 +108,8 @@ auto create_plan(const ExecutionSpace& exec_space,
nz = fft_extents.at(2);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
plan = std::make_unique<PlanType>(exec_space, nx, ny, nz, type);
plan = std::make_unique<PlanType>(nx, ny, nz, type);
plan->commit(exec_space);

return fft_size;
}
Expand Down Expand Up @@ -151,9 +154,10 @@ auto create_plan(const ExecutionSpace& exec_space,

// For the moment, considering the contiguous layout only
int istride = 1, ostride = 1;
plan = std::make_unique<PlanType>(
exec_space, rank, fft_extents.data(), in_extents.data(), istride, idist,
out_extents.data(), ostride, odist, type, howmany);
plan = std::make_unique<PlanType>(rank, fft_extents.data(), in_extents.data(),
istride, idist, out_extents.data(), ostride,
odist, type, howmany);
plan->commit(exec_space);

return fft_size;
}
Expand Down
1 change: 1 addition & 0 deletions fft/src/KokkosFFT_HIP_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define KOKKOSFFT_HIP_TRANSFORM_HPP

#include <hipfft/hipfft.h>
#include "KokkosFFT_asserts.hpp"
#include "KokkosFFT_HIP_types.hpp"

namespace KokkosFFT {
Expand Down
80 changes: 27 additions & 53 deletions fft/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,59 +32,28 @@ struct ScopedHIPfftPlan {
hipfftHandle m_plan;

public:
ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int nx, hipfftType type,
int batch) {
try {
hipfftResult hipfft_rt = hipfftPlan1d(&m_plan, nx, type, batch);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan1d failed");
set_stream(exec_space);
} catch (const std::runtime_error &e) {
std::cerr << e.what() << std::endl;
cleanup();
throw;
}
ScopedHIPfftPlan(int nx, hipfftType type, int batch) {
hipfftResult hipfft_rt = hipfftPlan1d(&m_plan, nx, type, batch);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan1d failed");
}

ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int nx, int ny,
hipfftType type) {
try {
hipfftResult hipfft_rt = hipfftPlan2d(&m_plan, nx, ny, type);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan2d failed");
set_stream(exec_space);
} catch (const std::runtime_error &e) {
std::cerr << e.what() << std::endl;
cleanup();
throw;
}
ScopedHIPfftPlan(int nx, int ny, hipfftType type) {
hipfftResult hipfft_rt = hipfftPlan2d(&m_plan, nx, ny, type);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan2d failed");
}

ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int nx, int ny, int nz,
hipfftType type) {
try {
hipfftResult hipfft_rt = hipfftPlan3d(&m_plan, nx, ny, nz, type);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan3d failed");
set_stream(exec_space);
} catch (const std::runtime_error &e) {
std::cerr << e.what() << std::endl;
cleanup();
throw;
}
ScopedHIPfftPlan(int nx, int ny, int nz, hipfftType type) {
hipfftResult hipfft_rt = hipfftPlan3d(&m_plan, nx, ny, nz, type);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan3d failed");
}

ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int rank, int *n,
int *inembed, int istride, int idist, int *onembed,
int ostride, int odist, hipfftType type, int batch) {
try {
hipfftResult hipfft_rt =
hipfftPlanMany(&m_plan, rank, n, inembed, istride, idist, onembed,
ostride, odist, type, batch);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlanMany failed");
set_stream(exec_space);
} catch (const std::runtime_error &e) {
std::cerr << e.what() << std::endl;
cleanup();
throw;
}
ScopedHIPfftPlan(int rank, int *n, int *inembed, int istride, int idist,
int *onembed, int ostride, int odist, hipfftType type,
int batch) {
hipfftResult hipfft_rt =
hipfftPlanMany(&m_plan, rank, n, inembed, istride, idist, onembed,
ostride, odist, type, batch);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlanMany failed");
}

~ScopedHIPfftPlan() noexcept { cleanup(); }
Expand All @@ -96,18 +65,23 @@ struct ScopedHIPfftPlan {
ScopedHIPfftPlan(ScopedHIPfftPlan &&) = delete;

hipfftHandle plan() const noexcept { return m_plan; }
void commit(const Kokkos::HIP &exec_space) {
hipStream_t stream = exec_space.hip_stream();
try {
hipfftResult hipfft_rt = hipfftSetStream(m_plan, stream);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed");
} catch (const std::runtime_error &e) {
std::cerr << e.what() << std::endl;
cleanup();
throw;
}
}

private:
void cleanup() {
hipfftResult hipfft_rt = hipfftDestroy(m_plan);
if (hipfft_rt != HIPFFT_SUCCESS) Kokkos::abort("hipfftDestroy failed");
}

void set_stream(const Kokkos::HIP &exec_space) {
hipStream_t stream = exec_space.hip_stream();
hipfftResult hipfft_rt = hipfftSetStream(m_plan, stream);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed");
}
};

#if defined(ENABLE_HOST_AND_DEVICE)
Expand Down

0 comments on commit f7944c8

Please sign in to comment.