diff --git a/fft/src/KokkosFFT_FFTW_Types.hpp b/fft/src/KokkosFFT_FFTW_Types.hpp index 9bf434d6..0481d2b1 100644 --- a/fft/src/KokkosFFT_FFTW_Types.hpp +++ b/fft/src/KokkosFFT_FFTW_Types.hpp @@ -60,6 +60,26 @@ struct fftw_transform_type, Kokkos::complex> { static constexpr FFTWTransformType type() { return m_type; }; }; +/// \brief A class that wraps fftw_init_threads and fftw_cleanup_threads +template +struct ScopedFFTWThreads { + ScopedFFTWThreads() { + if constexpr (std::is_same_v) { + fftwf_init_threads(); + } else { + fftw_init_threads(); + } + } + + ~ScopedFFTWThreads() noexcept { + if constexpr (std::is_same_v) { + fftwf_cleanup_threads(); + } else { + fftw_cleanup_threads(); + } + } +}; + /// \brief A class that wraps fftw_plan and fftwf_plan for RAII template struct ScopedFFTWPlan { @@ -69,15 +89,13 @@ struct ScopedFFTWPlan { std::conditional_t, fftwf_plan, fftw_plan>; plan_type m_plan; - const int m_local_id; public: template ScopedFFTWPlan(const ExecutionSpace &exec_space, int rank, const int *n, int howmany, InScalarType *in, const int *inembed, int istride, int idist, OutScalarType *out, const int *onembed, int ostride, - int odist, [[maybe_unused]] int sign, unsigned flags) - : m_local_id(global_id()) { + int odist, [[maybe_unused]] int sign, unsigned flags) { init_threads(exec_space); constexpr auto type = fftw_transform_type::type(); if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) { @@ -107,7 +125,6 @@ struct ScopedFFTWPlan { } ~ScopedFFTWPlan() noexcept { - cleanup_threads(); if constexpr (std::is_same_v) { fftwf_destroy_plan(m_plan); } else { @@ -124,40 +141,21 @@ struct ScopedFFTWPlan { plan_type plan() const noexcept { return m_plan; } private: - static int global_id() { - static int global_id = 0; - static std::mutex mtx; - std::lock_guard lock(mtx); - return global_id++; - } - void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) { #if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS) + static std::mutex mtx; + std::lock_guard lock(mtx); + static ScopedFFTWThreads fftw_threads; if constexpr (std::is_same_v) { int nthreads = exec_space.concurrency(); if constexpr (std::is_same_v) { - if (m_local_id == 0) fftwf_init_threads(); fftwf_plan_with_nthreads(nthreads); } else { - if (m_local_id == 0) fftw_init_threads(); fftw_plan_with_nthreads(nthreads); } } -#endif - } - - void cleanup_threads() { -#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS) - if constexpr (std::is_same_v) { - if constexpr (std::is_same_v) { - if (m_local_id == 0) fftwf_cleanup_threads(); - } else { - if (m_local_id == 0) fftw_cleanup_threads(); - } - } #endif } };