Skip to content

Commit

Permalink
update hip backend based on reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 17, 2024
1 parent 4a6d32b commit 36144b2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
30 changes: 15 additions & 15 deletions fft/src/KokkosFFT_HIP_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,45 @@

namespace KokkosFFT {
namespace Impl {
template <typename ScopedPlanType>
inline void exec_plan(ScopedPlanType& scoped_plan, hipfftReal* idata,

struct ScopedHIPfftPlan;

inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, hipfftReal* idata,
hipfftComplex* odata, int /*direction*/) {
hipfftResult hipfft_rt = hipfftExecR2C(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecR2C failed");
}

template <typename ScopedPlanType>
inline void exec_plan(ScopedPlanType& scoped_plan, hipfftDoubleReal* idata,
hipfftDoubleComplex* odata, int /*direction*/) {
inline void exec_plan(const ScopedHIPfftPlan& scoped_plan,
hipfftDoubleReal* idata, hipfftDoubleComplex* odata,
int /*direction*/) {
hipfftResult hipfft_rt = hipfftExecD2Z(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecD2Z failed");
}

template <typename ScopedPlanType>
inline void exec_plan(ScopedPlanType& scoped_plan, hipfftComplex* idata,
inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, hipfftComplex* idata,
hipfftReal* odata, int /*direction*/) {
hipfftResult hipfft_rt = hipfftExecC2R(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecC2R failed");
}

template <typename ScopedPlanType>
inline void exec_plan(ScopedPlanType& scoped_plan, hipfftDoubleComplex* idata,
hipfftDoubleReal* odata, int /*direction*/) {
inline void exec_plan(const ScopedHIPfftPlan& scoped_plan,
hipfftDoubleComplex* idata, hipfftDoubleReal* odata,
int /*direction*/) {
hipfftResult hipfft_rt = hipfftExecZ2D(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecZ2D failed");
}

template <typename ScopedPlanType>
inline void exec_plan(ScopedPlanType& scoped_plan, hipfftComplex* idata,
inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, hipfftComplex* idata,
hipfftComplex* odata, int direction) {
hipfftResult hipfft_rt =
hipfftExecC2C(scoped_plan.plan(), idata, odata, direction);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecC2C failed");
}

template <typename ScopedPlanType>
inline void exec_plan(ScopedPlanType& scoped_plan, hipfftDoubleComplex* idata,
hipfftDoubleComplex* odata, int direction) {
inline void exec_plan(const ScopedHIPfftPlan& scoped_plan,
hipfftDoubleComplex* idata, hipfftDoubleComplex* odata,
int direction) {
hipfftResult hipfft_rt =
hipfftExecZ2Z(scoped_plan.plan(), idata, odata, direction);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecZ2Z failed");
Expand Down
17 changes: 8 additions & 9 deletions fft/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,33 @@ namespace Impl {
using FFTDirectionType = int;

/// \brief A class that wraps hipfft for RAII
template <typename ExecutionSpace>
struct ScopedHIPfftPlan {
private:
hipfftHandle m_plan;

public:
ScopedHIPfftPlan(const ExecutionSpace &exec_space, int nx, hipfftType type,
ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int nx, hipfftType type,
int batch) {
hipfftResult hipfft_rt = hipfftPlan1d(&m_plan, nx, type, batch);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan1d failed");
set_stream(exec_space);
}

ScopedHIPfftPlan(const ExecutionSpace &exec_space, int nx, int ny,
ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int nx, int ny,
hipfftType type) {
hipfftResult hipfft_rt = hipfftPlan2d(&m_plan, nx, ny, type);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan2d failed");
set_stream(exec_space);
}

ScopedHIPfftPlan(const ExecutionSpace &exec_space, int nx, int ny, int nz,
ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int nx, int ny, int nz,
hipfftType type) {
hipfftResult hipfft_rt = hipfftPlan3d(&m_plan, nx, ny, nz, type);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan3d failed");
set_stream(exec_space);
}

ScopedHIPfftPlan(const ExecutionSpace &exec_space, int rank, int *n,
ScopedHIPfftPlan(const Kokkos::HIP &exec_space, int rank, int *n,
int *inembed, int istride, int idist, int *onembed,
int ostride, int odist, hipfftType type, int batch) {
hipfftResult hipfft_rt =
Expand All @@ -74,10 +73,10 @@ struct ScopedHIPfftPlan {
ScopedHIPfftPlan &operator=(ScopedHIPfftPlan &&) = delete;
ScopedHIPfftPlan(ScopedHIPfftPlan &&) = delete;

hipfftHandle &plan() { return m_plan; }
hipfftHandle plan() const noexcept { return m_plan; }

private:
void set_stream(const ExecutionSpace &exec_space) {
void set_stream(const Kokkos::HIP &exec_space) {
hipStream_t stream = exec_space.hip_stream();
hipfftResult hipfft_rt = hipfftSetStream(m_plan, stream);
KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed");
Expand Down Expand Up @@ -179,7 +178,7 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,
template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftw_plan_type = ScopedFFTWPlan<ExecutionSpace, T1, T2>;
using hipfft_plan_type = ScopedHIPfftPlan<ExecutionSpace>;
using hipfft_plan_type = ScopedHIPfftPlan;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::HIP>,
hipfft_plan_type, fftw_plan_type>;
};
Expand Down Expand Up @@ -242,7 +241,7 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = ScopedHIPfftPlan<ExecutionSpace>;
using type = ScopedHIPfftPlan;
};

template <typename ExecutionSpace>
Expand Down

0 comments on commit 36144b2

Please sign in to comment.