Skip to content

Commit

Permalink
Merge pull request #6 from CExA-project/implement-fftn
Browse files Browse the repository at this point in the history
Implement fftn
  • Loading branch information
yasahi-hpc authored Dec 6, 2023
2 parents bd843c1 + a9b0aa4 commit 1e34961
Show file tree
Hide file tree
Showing 4 changed files with 1,107 additions and 313 deletions.
15 changes: 15 additions & 0 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,21 @@ namespace KokkosFFT {

return index;
}

template <typename T, std::size_t... I>
std::array<T, sizeof...(I)> make_sequence_array(std::index_sequence<I...>) {
return std::array<T, sizeof...(I)>{ {I...} };
}

template <int N, typename T>
std::array<T, N> index_sequence(T const& start) {
auto sequence = make_sequence_array<T>(std::make_index_sequence<N>());
std::transform(sequence.begin(), sequence.end(), sequence.begin(),
[=](const T sequence) -> T {return start + sequence;});
return sequence;
}


};

#endif
1 change: 1 addition & 0 deletions fft/src/KokkosFFT_OpenMP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ namespace KokkosFFT {
istride,
idist,
odata,
out_extents.data(),
ostride,
odist,
FFTW_ESTIMATE
Expand Down
308 changes: 308 additions & 0 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,314 @@ namespace KokkosFFT {
_irfft2(plan, in, out, norm, axes);
}
}
}

namespace KokkosFFT {
template <typename PlanType, typename InViewType, typename OutViewType>
void _fftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::fftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto* idata = reinterpret_cast<typename fft_data_type<in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename fft_data_type<out_value_type>::type*>(out.data());

_exec(plan.plan(), idata, odata, KOKKOS_FFT_FORWARD);
normalize(out, KOKKOS_FFT_FORWARD, norm, plan.fft_size());
}

template <typename PlanType, typename InViewType, typename OutViewType>
void _ifftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::ifftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto* idata = reinterpret_cast<typename fft_data_type<in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename fft_data_type<out_value_type>::type*>(out.data());

_exec(plan.plan(), idata, odata, KOKKOS_FFT_BACKWARD);
normalize(out, KOKKOS_FFT_BACKWARD, norm, plan.fft_size());
}

template <typename PlanType, typename InViewType, typename OutViewType>
void _rfftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::rfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
"KokkosFFT::rfftn: InViewType must be real");
static_assert(is_complex<out_value_type>::value,
"KokkosFFT::rfftn: OutViewType must be complex");

_fftn(plan, in, out, norm);
}

template <typename PlanType, typename InViewType, typename OutViewType>
void _irfftn(PlanType& plan, const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::irfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(is_complex<in_value_type>::value,
"KokkosFFT::irfftn: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
"KokkosFFT::irfftn: OutViewType must be real");

_ifftn(plan, in, out, norm);
}

template <typename InViewType, typename OutViewType>
void fftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::fftn: OutViewType is not a Kokkos::View.");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_fftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_fftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void fftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::fftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::fftn: OutViewType is not a Kokkos::View.");

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_fftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_fftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType>
void ifftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::ifftn: OutViewType is not a Kokkos::View.");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_ifftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_ifftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void ifftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::ifftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::ifftn: OutViewType is not a Kokkos::View.");

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_ifftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_ifftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType>
void rfftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::rfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
"KokkosFFT::rfftn: InViewType must be real");
static_assert(is_complex<out_value_type>::value,
"KokkosFFT::rfftn: OutViewType must be complex");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_rfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_rfftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void rfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::rfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::rfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
"KokkosFFT::rfftn: InViewType must be real");
static_assert(is_complex<out_value_type>::value,
"KokkosFFT::rfftn: OutViewType must be complex");

Plan plan(in, out, KOKKOS_FFT_FORWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_rfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_rfftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType>
void irfftn(const InViewType& in, OutViewType& out, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::irfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(is_complex<in_value_type>::value,
"KokkosFFT::irfftn: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
"KokkosFFT::irfftn: OutViewType must be real");

// Create a default sequence of axes {-rank, -(rank-1), ..., -1}
constexpr std::size_t rank = InViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = index_sequence<rank>(start);

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_irfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_irfftn(plan, in, out, norm);
}
}

template <typename InViewType, typename OutViewType, std::size_t DIM=1>
void irfftn(const InViewType& in, OutViewType& out, axis_type<DIM> axes, FFT_Normalization norm=FFT_Normalization::BACKWARD) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::irfftn: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"KokkosFFT::irfftn: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(is_complex<in_value_type>::value,
"KokkosFFT::irfftn: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
"KokkosFFT::irfftn: OutViewType must be real");

Plan plan(in, out, KOKKOS_FFT_BACKWARD, axes);
if(plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;

KokkosFFT::transpose(in, in_T, plan.map());
KokkosFFT::transpose(out, out_T, plan.map());

_irfftn(plan, in_T, out_T, norm);

KokkosFFT::transpose(out_T, out, plan.map_inv());
} else {
_irfftn(plan, in, out, norm);
}
}
};

#endif
Loading

0 comments on commit 1e34961

Please sign in to comment.