Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply check functions to common functions #129

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@ auto get_extents(const InViewType& in, const OutViewType& out,
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

static_assert(InViewType::rank() >= DIM,
"KokkosFFT::get_map_axes: Rank of View must be larger thane or "
"equal to the Rank of FFT axes.");
static_assert(DIM > 0,
"KokkosFFT::get_map_axes: Rank of FFT axes must be larger than "
"or equal to 1.");
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

constexpr std::size_t rank = InViewType::rank;
[[maybe_unused]] int inner_most_axis =
Expand Down Expand Up @@ -64,32 +60,25 @@ auto get_extents(const InViewType& in, const OutViewType& out,
_fft_extents.push_back(fft_extent);
}

static_assert(!(is_real_v<in_value_type> && is_real_v<out_value_type>),
"get_extents: real to real transform is not supported");

if (is_real_v<in_value_type>) {
// Then R2C
if (is_complex_v<out_value_type>) {
KOKKOSFFT_EXPECTS(
_out_extents.at(inner_most_axis) ==
_in_extents.at(inner_most_axis) / 2 + 1,
"For R2C, the 'output extent' of transform must be equal to "
"'input extent'/2 + 1");
} else {
throw std::runtime_error(
"If the input type is real, the output type should be complex");
}
KOKKOSFFT_EXPECTS(
_out_extents.at(inner_most_axis) ==
_in_extents.at(inner_most_axis) / 2 + 1,
"For R2C, the 'output extent' of transform must be equal to "
"'input extent'/2 + 1");
}

if (is_real_v<out_value_type>) {
// Then C2R
if (is_complex_v<in_value_type>) {
KOKKOSFFT_EXPECTS(
_in_extents.at(inner_most_axis) ==
_out_extents.at(inner_most_axis) / 2 + 1,
"For C2R, the 'input extent' of transform must be equal to "
"'output extent' / 2 + 1");
} else {
throw std::runtime_error(
"If the output type is real, the input type should be complex");
}
KOKKOSFFT_EXPECTS(
_in_extents.at(inner_most_axis) ==
_out_extents.at(inner_most_axis) / 2 + 1,
"For C2R, the 'input extent' of transform must be equal to "
"'output extent' / 2 + 1");
}

if (std::is_same_v<array_layout_type, Kokkos::LayoutLeft>) {
Expand Down
5 changes: 5 additions & 0 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ template <typename ExecutionSpace, typename ViewType>
void normalize(const ExecutionSpace& exec_space, ViewType& inout,
Direction direction, Normalization normalization,
std::size_t fft_size) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"normalize: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
tpadioleau marked this conversation as resolved.
Show resolved Hide resolved
auto [coef, to_normalize] =
get_coefficients(inout, direction, normalization, fft_size);
if (to_normalize) normalize_impl(exec_space, inout, coef);
Expand Down
38 changes: 14 additions & 24 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,11 @@ namespace Impl {
template <typename InViewType, typename OutViewType, std::size_t DIM>
auto get_modified_shape(const InViewType in, const OutViewType /* out */,
shape_type<DIM> shape, axis_type<DIM> axes) {
static_assert(InViewType::rank() >= DIM,
"get_modified_shape: Rank of Input View must be larger "
"than or equal to the Rank of new shape");
static_assert(OutViewType::rank() >= DIM,
"get_modified_shape: Rank of Output View must be larger "
"than or equal to the Rank of new shape");
static_assert(DIM > 0,
"get_modified_shape: Rank of FFT axes must be "
"larger than or equal to 1");
constexpr int rank = static_cast<int>(InViewType::rank());
static_assert(
KokkosFFT::Impl::have_same_rank_v<InViewType, OutViewType>,
"get_modified_shape: Input View and Output View must have the same rank");
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (shape == zeros) {
Expand All @@ -50,14 +45,7 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */,
positive_axes.push_back(axis);
}

// Assert if the elements are overlapped
KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(positive_axes),
"Axes overlap");
KOKKOSFFT_EXPECTS(
!KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank),
"Axes include an out-of-range index."
"Axes must be in the range of [-rank, rank-1].");

constexpr int rank = static_cast<int>(InViewType::rank());
using full_shape_type = shape_type<rank>;
full_shape_type modified_shape;
for (int i = 0; i < rank; i++) {
Expand Down Expand Up @@ -346,12 +334,14 @@ template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void crop_or_pad(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, shape_type<DIM> s) {
static_assert(InViewType::rank() == DIM,
"crop_or_pad: Rank of View must be equal to Rank "
"of extended shape.");
static_assert(OutViewType::rank() == DIM,
"crop_or_pad: Rank of View must be equal to Rank "
"of extended shape.");
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"crop_or_pad: InViewType and OutViewType must have the same base "
"floating point "
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
crop_or_pad_impl(exec_space, in, out, s);
}
} // namespace Impl
Expand Down
38 changes: 14 additions & 24 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,8 @@ namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_map_axes(const ViewType& view, axis_type<DIM> _axes) {
static_assert(ViewType::rank() >= DIM,
"get_map_axes: Rank of View must be larger thane or "
"equal to the Rank of FFT axes.");
static_assert(DIM > 0,
"get_map_axes: Rank of FFT axes must be larger than "
"or equal to 1.");

constexpr int rank = static_cast<int>(ViewType::rank());
using array_layout_type = typename ViewType::array_layout;
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(view, _axes),
"get_map_axes: input axes are not valid for the view");

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
Expand All @@ -31,16 +24,14 @@ auto get_map_axes(const ViewType& view, axis_type<DIM> _axes) {
axes.push_back(axis);
}

// Assert if the elements are overlapped
assert(!KokkosFFT::Impl::has_duplicate_values(axes));

// how indices are map
// For 5D View and axes are (2,3), map would be (0, 1, 4, 2, 3)
constexpr int rank = static_cast<int>(ViewType::rank());
std::vector<int> map, map_inv;
map.reserve(rank);
map_inv.reserve(rank);

if (std::is_same_v<array_layout_type, Kokkos::LayoutRight>) {
if (std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>) {
// Stack axes not specified by axes (0, 1, 4)
for (int i = 0; i < rank; i++) {
if (!is_found(axes, i)) {
Expand Down Expand Up @@ -396,22 +387,21 @@ template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void transpose(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, axis_type<DIM> map) {
static_assert(Kokkos::is_view<InViewType>::value,
"transpose: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
"transpose: OutViewType is not a Kokkos::View.");

static_assert(InViewType::rank() == OutViewType::rank(),
"transpose: InViewType and OutViewType must have "
"the same rank.");
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"transpose: InViewType and OutViewType must have the same base floating "
"point "
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");

static_assert(InViewType::rank() == DIM,
"transpose: Rank of View must be equal to Rank of "
"transpose axes.");

if (!KokkosFFT::Impl::is_transpose_needed(map)) {
throw std::runtime_error("transpose: transpose not necessary");
}
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::is_transpose_needed(map),
"transpose: transpose not necessary");

// in order not to call transpose_impl for 1D case
if constexpr (DIM > 1) {
Expand Down
Loading
Loading