Skip to content

Commit

Permalink
fix constructor of fftw wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 4, 2024
1 parent d950e00 commit 8e70b0c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 33 deletions.
35 changes: 33 additions & 2 deletions fft/src/KokkosFFT_FFTW_Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,42 @@ struct ScopedFFTWPlanType {

public:
ScopedFFTWPlanType() = delete;
ScopedFFTWPlanType(const ExecutionSpace &exec_space) {
template <typename InScalarType, typename OutScalarType>
ScopedFFTWPlanType(const ExecutionSpace &exec_space, int rank, const int *n,
int howmany, InScalarType *in, const int *inembed,
int istride, int idist, OutScalarType *out,
const int *onembed, int ostride, int odist,
[[maybe_unused]] int sign, unsigned flags) {
init_threads(exec_space);
constexpr auto type = fftw_transform_type<ExecutionSpace, T1, T2>::type();
if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
m_plan =
fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
m_plan =
fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
m_plan =
fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
m_plan =
fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
m_plan =
fftwf_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, sign, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
m_plan = fftw_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, sign, flags);
}
m_is_created = true;
}
~ScopedFFTWPlanType() {
cleanup_threads();
if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
if (m_is_created) fftwf_destroy_plan(m_plan);
} else {
Expand All @@ -85,7 +117,6 @@ struct ScopedFFTWPlanType {
}

const plan_type &plan() const { return m_plan; }
void set_is_created() { m_is_created = true; }

private:
void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) {
Expand Down
34 changes: 4 additions & 30 deletions fft/src/KokkosFFT_Host_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,10 @@ auto create_plan(const ExecutionSpace& exec_space,
[[maybe_unused]] auto sign =
KokkosFFT::Impl::direction_type<ExecutionSpace>(direction);

plan = std::make_unique<PlanType>(exec_space);
constexpr auto type =
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
plan->m_plan = fftwf_plan_many_dft_r2c(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
plan->m_plan = fftw_plan_many_dft_r2c(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
plan->m_plan = fftwf_plan_many_dft_c2r(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
plan->m_plan = fftw_plan_many_dft_c2r(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
plan->m_plan = fftwf_plan_many_dft(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
plan->m_plan = fftw_plan_many_dft(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE);
}
plan->set_is_created();
plan = std::make_unique<PlanType>(exec_space, rank, fft_extents.data(),
howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride,
odist, sign, FFTW_ESTIMATE);

return fft_size;
}
Expand Down
2 changes: 1 addition & 1 deletion fft/src/KokkosFFT_Host_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void exec_plan(ScopedPlanType& scoped_plan, fftwf_complex* idata,
}

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType scoped_plan, fftw_complex* idata,
void exec_plan(ScopedPlanType& scoped_plan, fftw_complex* idata,
fftw_complex* odata, int /*direction*/) {
fftw_execute_dft(scoped_plan.plan(), idata, odata);
}
Expand Down

0 comments on commit 8e70b0c

Please sign in to comment.