Skip to content

Commit

Permalink
Check view rank and fft rank consistency in all APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Jul 23, 2024
1 parent 762dea8 commit 9e3c321
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 10 deletions.
28 changes: 18 additions & 10 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM = 1>
auto get_shift(const ViewType& inout, axis_type<DIM> _axes, int direction = 1) {
static_assert(DIM > 0,
"get_shift: Rank of shift axes must be "
"larger than or equal to 1.");

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
for (std::size_t i = 0; i < DIM; i++) {
Expand Down Expand Up @@ -132,19 +128,13 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(ViewType::rank() >= DIM,
"fftshift_impl: Rank of View must be larger thane "
"or equal to the Rank of shift axes.");
auto shift = get_shift(inout, axes);
roll(exec_space, inout, shift, axes);
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(ViewType::rank() >= DIM,
"ifftshift_impl: Rank of View must be larger "
"thane or equal to the Rank of shift axes.");
auto shift = get_shift(inout, axes, -1);
roll(exec_space, inout, shift, axes);
}
Expand Down Expand Up @@ -229,6 +219,9 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(ViewType::rank() >= 1,
"fftshift: View rank must be larger than or equal to 1");

if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
Expand All @@ -253,6 +246,12 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"fftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"fftshift: View rank must be larger than or equal to the Rank "
"of FFT axes");
KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes);
}

Expand All @@ -269,6 +268,8 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(ViewType::rank() >= 1,
"ifftshift: View rank must be larger than or equal to 1");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
Expand All @@ -293,6 +294,13 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"ifftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"ifftshift: View rank must be larger than or equal to the Rank "
"of FFT axes");

KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes);
}
} // namespace KokkosFFT
Expand Down
8 changes: 8 additions & 0 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class Plan {
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");
static_assert(InViewType::rank() >= 1,
"Plan::Plan: View rank must be larger than or equal to 1");

if (KokkosFFT::Impl::is_real_v<in_value_type> &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down Expand Up @@ -220,6 +222,12 @@ class Plan {
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"Plan::Plan: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(InViewType::rank() >= DIM,
"Plan::Plan: View rank must be larger than or equal to the "
"Rank of FFT axes");

if (std::is_floating_point<in_value_type>::value &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down
44 changes: 44 additions & 0 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"fft: View rank must be larger than or equal to 1");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axis, n);
Expand All @@ -165,6 +167,8 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"ifft: View rank must be larger than or equal to 1");

KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axis, n);
Expand All @@ -191,6 +195,8 @@ void rfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"rfft: View rank must be larger than or equal to 1");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -224,6 +230,8 @@ void irfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"irfft: View rank must be larger than or equal to 1");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -255,6 +263,8 @@ void hfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"hfft: View rank must be larger than or equal to 1");

// [TO DO]
// allow real type as input, need to obtain complex view type from in view
Expand Down Expand Up @@ -295,6 +305,8 @@ void ihfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"ihfft: View rank must be larger than or equal to 1");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -332,6 +344,8 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"fft2: View rank must be larger than or equal to 2");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axes, s);
Expand Down Expand Up @@ -359,6 +373,8 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"ifft2: View rank must be larger than or equal to 2");

KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axes, s);
Expand Down Expand Up @@ -386,6 +402,9 @@ void rfft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"rfft2: View rank must be larger than or equal to 2");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

Expand Down Expand Up @@ -418,6 +437,8 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"irfft2: View rank must be larger than or equal to 2");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -453,6 +474,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"fftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"fftn: View rank must be larger than or equal to the Rank of FFT axes");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axes, s);
Expand Down Expand Up @@ -481,6 +507,12 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"ifftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"ifftn: View rank must be larger than or equal to the Rank of FFT axes");

KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axes, s);
Expand Down Expand Up @@ -509,6 +541,12 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"rfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"rfftn: View rank must be larger than or equal to the Rank of FFT axes");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -543,6 +581,12 @@ void irfftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"irfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"irfftn: View rank must be larger than or equal to the Rank of FFT axes");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down

0 comments on commit 9e3c321

Please sign in to comment.