Skip to content

Commit

Permalink
call fftw_cleanup_threads only once
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jan 6, 2025
1 parent 88d4c43 commit 46d2faf
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions fft/src/KokkosFFT_FFTW_Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef KOKKOSFFT_FFTW_TYPES_HPP
#define KOKKOSFFT_FFTW_TYPES_HPP

#include <mutex>
#include <fftw3.h>
#include <Kokkos_Core.hpp>
#include "KokkosFFT_common_types.hpp"
Expand Down Expand Up @@ -68,13 +69,15 @@ struct ScopedFFTWPlan {
std::conditional_t<std::is_same_v<floating_point_type, float>, fftwf_plan,
fftw_plan>;
plan_type m_plan;
const int m_local_id;

public:
template <typename InScalarType, typename OutScalarType>
ScopedFFTWPlan(const ExecutionSpace &exec_space, int rank, const int *n,
int howmany, InScalarType *in, const int *inembed, int istride,
int idist, OutScalarType *out, const int *onembed, int ostride,
int odist, [[maybe_unused]] int sign, unsigned flags) {
int odist, [[maybe_unused]] int sign, unsigned flags)
: m_local_id(global_id()) {
init_threads(exec_space);
constexpr auto type = fftw_transform_type<T1, T2>::type();
if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
Expand Down Expand Up @@ -121,17 +124,24 @@ struct ScopedFFTWPlan {
plan_type plan() const noexcept { return m_plan; }

private:
static int global_id() {
static int global_id = 0;
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
return global_id++;
}

static void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
if constexpr (std::is_same_v<ExecutionSpace,
Kokkos::DefaultHostExecutionSpace>) {
int nthreads = exec_space.concurrency();

if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
fftwf_init_threads();
if (m_local_id == 0) fftwf_init_threads();
fftwf_plan_with_nthreads(nthreads);
} else {
fftw_init_threads();
if (m_local_id == 0) fftw_init_threads();
fftw_plan_with_nthreads(nthreads);
}
}
Expand All @@ -143,9 +153,9 @@ struct ScopedFFTWPlan {
if constexpr (std::is_same_v<ExecutionSpace,
Kokkos::DefaultHostExecutionSpace>) {
if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
fftwf_cleanup_threads();
if (m_local_id == 0) fftwf_cleanup_threads();
} else {
fftw_cleanup_threads();
if (m_local_id == 0) fftw_cleanup_threads();
}
}
#endif
Expand Down

0 comments on commit 46d2faf

Please sign in to comment.