Skip to content

Commit

Permalink
Implement fftn families and related tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 6, 2023
1 parent 370e8aa commit a9b0aa4
Show file tree
Hide file tree
Showing 2 changed files with 1,091 additions and 313 deletions.
308 changes: 308 additions & 0 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,314 @@ namespace KokkosFFT {
_irfft2(plan, in, out, norm, axes);
}
}
}

namespace KokkosFFT {
template <typename PlanType, typename InViewType, typename OutViewType>
void _fftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::fftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto* idata = reinterpret_cast<typename fft_data_type<in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename fft_data_type<out_value_type>::type*>(out.data());

_exec(plan.plan(), idata, odata, KOKKOS_FFT_FORWARD);
normalize(out, KOKKOS_FFT_FORWARD, norm, plan.fft_size());
}

template <typename PlanType, typename InViewType, typename OutViewType>
void _ifftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::ifftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto* idata = reinterpret_cast<typename fft_data_type<in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename fft_data_type<out_value_type>::type*>(out.data());

_exec(plan.plan(), idata, odata, KOKKOS_FFT_BACKWARD);
normalize(out, KOKKOS_FFT_BACKWARD, norm, plan.fft_size());
}

template <typename PlanType, typename InViewType, typename OutViewType>
void _rfftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::rfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
"KokkosFFT::rfftn: InViewType must be real");
static_assert(is_complex<out_value_type>::value,
"KokkosFFT::rfftn: OutViewType must be complex");

_fftn(plan, in, out, norm);
}

template <typename PlanType, typename InViewType, typename OutViewType>
void _irfftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::irfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(is_complex<in_value_type>::value,
"KokkosFFT::irfftn: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
"KokkosFFT::irfftn: OutViewType must be real");

_ifftn(plan, in, out, norm);
}

template <typename InViewType, typename OutViewType>
void fftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::fftn: OutViewType is not a Kokkos::View.");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_fftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_fftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void fftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::fftn: OutViewType is not a Kokkos::View.");

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_fftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_fftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType>
void ifftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::ifftn: OutViewType is not a Kokkos::View.");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_ifftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_ifftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void ifftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::ifftn: OutViewType is not a Kokkos::View.");

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_ifftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_ifftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType>
void rfftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::rfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
"KokkosFFT::rfftn: InViewType must be real");
static_assert(is_complex<out_value_type>::value,
"KokkosFFT::rfftn: OutViewType must be complex");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_rfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_rfftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void rfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::rfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
"KokkosFFT::rfftn: InViewType must be real");
static_assert(is_complex<out_value_type>::value,
"KokkosFFT::rfftn: OutViewType must be complex");

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_rfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_rfftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType>
void irfftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::irfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(is_complex<in_value_type>::value,
"KokkosFFT::irfftn: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
"KokkosFFT::irfftn: OutViewType must be real");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_irfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_irfftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void irfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::irfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(is_complex<in_value_type>::value,
"KokkosFFT::irfftn: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
"KokkosFFT::irfftn: OutViewType must be real");

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_irfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_irfftn(plan, in, out, norm);
}
}
};

#endif
Loading

0 comments on commit a9b0aa4

Please sign in to comment.