From 4806cfbf0eeb67e3e233a195cc61b66985b2704d Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Wed, 4 Dec 2024 18:20:41 +0900 Subject: [PATCH] fix: fftw plan creation --- fft/src/KokkosFFT_FFTW_Types.hpp | 43 ++++---------------------------- fft/src/KokkosFFT_Host_plans.hpp | 34 ++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 42 deletions(-) diff --git a/fft/src/KokkosFFT_FFTW_Types.hpp b/fft/src/KokkosFFT_FFTW_Types.hpp index 16d2e8c7..16a0c0f6 100644 --- a/fft/src/KokkosFFT_FFTW_Types.hpp +++ b/fft/src/KokkosFFT_FFTW_Types.hpp @@ -70,7 +70,10 @@ struct ScopedFFTWPlanType { plan_type m_plan; bool m_is_created = false; - ScopedFFTWPlanType() {} + ScopedFFTWPlanType() = delete; + ScopedFFTWPlanType(const ExecutionSpace &exec_space) { + init_threads(exec_space); + } ~ScopedFFTWPlanType() { cleanup_threads(); if constexpr (std::is_same_v) { @@ -80,43 +83,7 @@ struct ScopedFFTWPlanType { } } - plan_type &plan() { return m_plan; } - - template - void create(const ExecutionSpace &exec_space, int rank, const int *n, - int howmany, InScalarType *in, const int *inembed, int istride, - int idist, OutScalarType *out, const int *onembed, int ostride, - int odist, [[maybe_unused]] int sign, unsigned flags) { - init_threads(exec_space); - - constexpr auto type = fftw_transform_type::type(); - - if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) { - m_plan = - fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, - out, onembed, ostride, odist, flags); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) { - m_plan = - fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, - out, onembed, ostride, odist, flags); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) { - m_plan = - fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, - out, onembed, ostride, odist, flags); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) { - m_plan = - fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, - out, onembed, ostride, odist, flags); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) { - m_plan = - fftwf_plan_many_dft(rank, n, howmany, in, inembed, istride, idist, - out, onembed, ostride, odist, sign, flags); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) { - m_plan = fftw_plan_many_dft(rank, n, howmany, in, inembed, istride, idist, - out, onembed, ostride, odist, sign, flags); - } - m_is_created = true; - } + const plan_type &plan() const { return m_plan; } private: template diff --git a/fft/src/KokkosFFT_Host_plans.hpp b/fft/src/KokkosFFT_Host_plans.hpp index dff9d15e..47dd78cc 100644 --- a/fft/src/KokkosFFT_Host_plans.hpp +++ b/fft/src/KokkosFFT_Host_plans.hpp @@ -58,10 +58,36 @@ auto create_plan(const ExecutionSpace& exec_space, [[maybe_unused]] auto sign = KokkosFFT::Impl::direction_type(direction); - plan = std::make_unique(); - plan->create(exec_space, rank, fft_extents.data(), howmany, idata, - in_extents.data(), istride, idist, odata, out_extents.data(), - ostride, odist, sign, FFTW_ESTIMATE); + plan = std::make_unique(exec_space); + constexpr auto type = + KokkosFFT::Impl::transform_type::type(); + if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) { + plan->m_plan = fftwf_plan_many_dft_r2c( + rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, + idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) { + plan->m_plan = fftw_plan_many_dft_r2c( + rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, + idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) { + plan->m_plan = fftwf_plan_many_dft_c2r( + rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, + idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) { + plan->m_plan = fftw_plan_many_dft_c2r( + rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, + idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) { + plan->m_plan = fftwf_plan_many_dft( + rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, + idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) { + plan->m_plan = fftw_plan_many_dft( + rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, + idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE); + } + plan->m_is_created; return fft_size; }