diff --git a/common/src/KokkosFFT_Helpers.hpp b/common/src/KokkosFFT_Helpers.hpp index 86e5f95c..d05d2cbc 100644 --- a/common/src/KokkosFFT_Helpers.hpp +++ b/common/src/KokkosFFT_Helpers.hpp @@ -23,12 +23,12 @@ auto get_shift(const ViewType& inout, axis_type _axes, int direction = 1) { // Assert if the elements are overlapped constexpr int rank = ViewType::rank(); - check_precondition(!KokkosFFT::Impl::has_duplicate_values(axes), - "axes are overlapped"); - check_precondition( + KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(axes), + "Axes overlap"); + KOKKOSFFT_EXPECTS( !KokkosFFT::Impl::is_out_of_range_value_included(axes, rank), - "axes include out of range index." - "axes should be in the range of [-rank, rank-1]."); + "Axes include an out-of-range index." + "Axes must be in the range of [-rank, rank-1]."); axis_type shift = {0}; for (int i = 0; i < static_cast(DIM); i++) { diff --git a/common/src/KokkosFFT_layouts.hpp b/common/src/KokkosFFT_layouts.hpp index 85441edc..065b898c 100644 --- a/common/src/KokkosFFT_layouts.hpp +++ b/common/src/KokkosFFT_layouts.hpp @@ -67,12 +67,11 @@ auto get_extents(const InViewType& in, const OutViewType& out, if (is_real_v) { // Then R2C if (is_complex_v) { - if (_out_extents.at(inner_most_axis) != - _in_extents.at(inner_most_axis) / 2 + 1) { - throw std::runtime_error( - "For R2C, the output extent of transform should be input extent / " - "2 + 1"); - } + KOKKOSFFT_EXPECTS( + _out_extents.at(inner_most_axis) == + _in_extents.at(inner_most_axis) / 2 + 1, + "For R2C, the 'output extent' of transform must be equal to " + "'input extent'/2 + 1"); } else { throw std::runtime_error( "If the input type is real, the output type should be complex"); @@ -82,12 +81,11 @@ auto get_extents(const InViewType& in, const OutViewType& out, if (is_real_v) { // Then C2R if (is_complex_v) { - if (_in_extents.at(inner_most_axis) != - _out_extents.at(inner_most_axis) / 2 + 1) { - throw std::runtime_error( - "For C2R, the input extent of transform should be output extent / " - "2 + 1"); - } + KOKKOSFFT_EXPECTS( + _in_extents.at(inner_most_axis) == + _out_extents.at(inner_most_axis) / 2 + 1, + "For C2R, the 'input extent' of transform must be equal to " + "'output extent' / 2 + 1"); } else { throw std::runtime_error( "If the output type is real, the input type should be complex"); diff --git a/common/src/KokkosFFT_padding.hpp b/common/src/KokkosFFT_padding.hpp index b672d1c2..3fc059be 100644 --- a/common/src/KokkosFFT_padding.hpp +++ b/common/src/KokkosFFT_padding.hpp @@ -51,14 +51,12 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */, } // Assert if the elements are overlapped - if (KokkosFFT::Impl::has_duplicate_values(positive_axes)) { - throw std::runtime_error("get_modified_shape: axes are overlapped."); - } - if (KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank)) { - throw std::runtime_error( - "get_modified_shape: axes include out of range index." - "axes should be in the range of [-rank, rank-1]."); - } + KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(positive_axes), + "Axes overlap"); + KOKKOSFFT_EXPECTS( + !KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank), + "Axes include an out-of-range index." + "Axes must be in the range of [-rank, rank-1]."); using full_shape_type = shape_type; full_shape_type modified_shape; diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index 00bd6764..48492759 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -50,26 +50,21 @@ inline void check_precondition(const bool expression, } throw std::runtime_error(ss.str()); } -inline void check_precondition(const bool expression, const std::string& msg, - const char* file_name, int line, - const char* function_name) { - std::stringstream ss("file: "); - ss << file_name << '(' << line << ") `" << function_name << "`: " << msg - << '\n'; - if (!expression) { - throw std::runtime_error(ss.str()); - } -} -#endif template auto convert_negative_axis(ViewType, int _axis = -1) { static_assert(Kokkos::is_view::value, "convert_negative_axis: ViewType is not a Kokkos::View."); int rank = static_cast(ViewType::rank()); +<<<<<<< HEAD if (_axis < -rank || _axis >= rank) { throw std::runtime_error("axis should be in [-rank, rank-1]"); } +======= + + KOKKOSFFT_EXPECTS(_axis >= -rank && _axis < rank, + "Axis must be in [-rank, rank-1]"); +>>>>>>> a786585 (improve assertion) int axis = _axis < 0 ? rank + _axis : _axis; return axis;