From 7a6946c2555f535d7ae80c25c51c6a4bc939456b Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 17 Dec 2024 22:21:29 +0900 Subject: [PATCH] update host backend based on revies --- fft/src/KokkosFFT_FFTW_Types.hpp | 50 ++++++++++++++++------------ fft/src/KokkosFFT_Host_transform.hpp | 1 + fft/src/KokkosFFT_Host_types.hpp | 2 +- fft/src/KokkosFFT_Plans.hpp | 2 +- 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/fft/src/KokkosFFT_FFTW_Types.hpp b/fft/src/KokkosFFT_FFTW_Types.hpp index d519aff6..0e0bd43e 100644 --- a/fft/src/KokkosFFT_FFTW_Types.hpp +++ b/fft/src/KokkosFFT_FFTW_Types.hpp @@ -23,14 +23,14 @@ namespace Impl { enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; // Define fft transform types -template +template struct fftw_transform_type { static_assert(std::is_same_v, "Real to real transform is unavailable"); }; -template -struct fftw_transform_type> { +template +struct fftw_transform_type> { static_assert(std::is_same_v, "T1 and T2 should have the same precision"); static constexpr FFTWTransformType m_type = std::is_same_v @@ -39,8 +39,8 @@ struct fftw_transform_type> { static constexpr FFTWTransformType type() { return m_type; }; }; -template -struct fftw_transform_type, T2> { +template +struct fftw_transform_type, T2> { static_assert(std::is_same_v, "T1 and T2 should have the same precision"); static constexpr FFTWTransformType m_type = std::is_same_v @@ -49,9 +49,8 @@ struct fftw_transform_type, T2> { static constexpr FFTWTransformType type() { return m_type; }; }; -template -struct fftw_transform_type, - Kokkos::complex> { +template +struct fftw_transform_type, Kokkos::complex> { static_assert(std::is_same_v, "T1 and T2 should have the same precision"); static constexpr FFTWTransformType m_type = std::is_same_v @@ -78,7 +77,7 @@ struct ScopedFFTWPlan { int idist, OutScalarType *out, const int *onembed, int ostride, int odist, [[maybe_unused]] int sign, unsigned flags) { init_threads(exec_space); - constexpr auto type = fftw_transform_type::type(); + constexpr auto type = fftw_transform_type::type(); if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) { m_plan = fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, @@ -105,6 +104,7 @@ struct ScopedFFTWPlan { } m_is_created = true; } + ~ScopedFFTWPlan() noexcept { cleanup_threads(); if constexpr (std::is_same_v) { @@ -120,29 +120,35 @@ struct ScopedFFTWPlan { ScopedFFTWPlan &operator=(ScopedFFTWPlan &&) = delete; ScopedFFTWPlan(ScopedFFTWPlan &&) = delete; - const plan_type &plan() const { return m_plan; } + plan_type plan() const noexcept { return m_plan; } private: void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) { #if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS) - int nthreads = exec_space.concurrency(); - - if constexpr (std::is_same_v) { - fftwf_init_threads(); - fftwf_plan_with_nthreads(nthreads); - } else { - fftw_init_threads(); - fftw_plan_with_nthreads(nthreads); + if constexpr (std::is_same_v) { + int nthreads = exec_space.concurrency(); + + if constexpr (std::is_same_v) { + fftwf_init_threads(); + fftwf_plan_with_nthreads(nthreads); + } else { + fftw_init_threads(); + fftw_plan_with_nthreads(nthreads); + } } #endif } void cleanup_threads() { #if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS) - if constexpr (std::is_same_v) { - fftwf_cleanup_threads(); - } else { - fftw_cleanup_threads(); + if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + fftwf_cleanup_threads(); + } else { + fftw_cleanup_threads(); + } } #endif } diff --git a/fft/src/KokkosFFT_Host_transform.hpp b/fft/src/KokkosFFT_Host_transform.hpp index 6aeabb1a..5113acde 100644 --- a/fft/src/KokkosFFT_Host_transform.hpp +++ b/fft/src/KokkosFFT_Host_transform.hpp @@ -9,6 +9,7 @@ namespace KokkosFFT { namespace Impl { + template void exec_plan(ScopedPlanType& scoped_plan, float* idata, fftwf_complex* odata, int /*direction*/) { diff --git a/fft/src/KokkosFFT_Host_types.hpp b/fft/src/KokkosFFT_Host_types.hpp index 6d438a0c..85e754b2 100644 --- a/fft/src/KokkosFFT_Host_types.hpp +++ b/fft/src/KokkosFFT_Host_types.hpp @@ -23,7 +23,7 @@ template using TransformType = FFTWTransformType; template -using transform_type = fftw_transform_type; +using transform_type = fftw_transform_type; template struct FFTPlanType { diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index 1ce815c5..7f6b98af 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -264,7 +264,7 @@ class Plan { direction, axes, s, m_is_inplace); } - ~Plan() {} + ~Plan() noexcept = default; Plan() = delete; Plan(const Plan&) = delete;