Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace KOKKOSFFT_EXPECTS with KOKKOSFFT_THROW_IF #131

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ auto get_shift(const ViewType& inout, axis_type<DIM> _axes, int direction = 1) {

// Assert if the elements are overlapped
constexpr int rank = ViewType::rank();
KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(axes),
"Axes overlap");
KOKKOSFFT_EXPECTS(
!KokkosFFT::Impl::is_out_of_range_value_included(axes, rank),
KOKKOSFFT_THROW_IF(KokkosFFT::Impl::has_duplicate_values(axes),
"Axes overlap");
KOKKOSFFT_THROW_IF(
KokkosFFT::Impl::is_out_of_range_value_included(axes, rank),
"Axes include an out-of-range index."
"Axes must be in the range of [-rank, rank-1].");

Expand Down
6 changes: 3 additions & 3 deletions common/src/KokkosFFT_asserts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@

#if defined(__cpp_lib_source_location) && __cpp_lib_source_location >= 201907L
#include <source_location>
#define KOKKOSFFT_EXPECTS(expression, msg) \
#define KOKKOSFFT_THROW_IF(expression, msg) \
KokkosFFT::Impl::check_precondition( \
(expression), msg, std::source_location::current().file_name(), \
std::source_location::current().line(), \
std::source_location::current().function_name(), \
std::source_location::current().column())
#else
#include <cstdlib>
#define KOKKOSFFT_EXPECTS(expression, msg) \
#define KOKKOSFFT_THROW_IF(expression, msg) \
KokkosFFT::Impl::check_precondition((expression), msg, __FILE__, __LINE__, \
__FUNCTION__)
#endif
Expand All @@ -33,7 +33,7 @@ inline void check_precondition(const bool expression,
const char* function_name,
const int column = -1) {
// Quick return if possible
if (expression) return;
if (!expression) return;

std::stringstream ss("file: ");
if (column == -1) {
Expand Down
12 changes: 6 additions & 6 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ auto get_extents(const InViewType& in, const OutViewType& out,
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

constexpr std::size_t rank = InViewType::rank;
[[maybe_unused]] int inner_most_axis =
Expand Down Expand Up @@ -65,17 +65,17 @@ auto get_extents(const InViewType& in, const OutViewType& out,

if (is_real_v<in_value_type>) {
// Then R2C
KOKKOSFFT_EXPECTS(
_out_extents.at(inner_most_axis) ==
KOKKOSFFT_THROW_IF(
_out_extents.at(inner_most_axis) !=
_in_extents.at(inner_most_axis) / 2 + 1,
"For R2C, the 'output extent' of transform must be equal to "
"'input extent'/2 + 1");
}

if (is_real_v<out_value_type>) {
// Then C2R
KOKKOSFFT_EXPECTS(
_in_extents.at(inner_most_axis) ==
KOKKOSFFT_THROW_IF(
_in_extents.at(inner_most_axis) !=
_out_extents.at(inner_most_axis) / 2 + 1,
"For C2R, the 'input extent' of transform must be equal to "
"'output extent' / 2 + 1");
Expand Down
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */,
static_assert(
KokkosFFT::Impl::have_same_rank_v<InViewType, OutViewType>,
"get_modified_shape: Input View and Output View must have the same rank");
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (shape == zeros) {
Expand Down
8 changes: 4 additions & 4 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_map_axes(const ViewType& view, axis_type<DIM> _axes) {
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(view, _axes),
"get_map_axes: input axes are not valid for the view");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(view, _axes),
"get_map_axes: input axes are not valid for the view");

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
Expand Down Expand Up @@ -400,8 +400,8 @@ void transpose(const ExecutionSpace& exec_space, InViewType& in,
"transpose: Rank of View must be equal to Rank of "
"transpose axes.");

KOKKOSFFT_EXPECTS(KokkosFFT::Impl::is_transpose_needed(map),
"transpose: transpose not necessary");
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::is_transpose_needed(map),
"transpose: transpose not necessary");

// in order not to call transpose_impl for 1D case
if constexpr (DIM > 1) {
Expand Down
66 changes: 33 additions & 33 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ auto convert_negative_axis(ViewType, int _axis = -1) {
"convert_negative_axis: ViewType must be a Kokkos::View.");
int rank = static_cast<int>(ViewType::rank());

KOKKOSFFT_EXPECTS(_axis >= -rank && _axis < rank,
"Axis must be in [-rank, rank-1]");
KOKKOSFFT_THROW_IF(_axis < -rank || _axis >= rank,
"Axis must be in [-rank, rank-1]");

int axis = _axis < 0 ? rank + _axis : _axis;
return axis;
Expand Down Expand Up @@ -130,7 +130,7 @@ std::size_t get_index(ContainerType& values, const ValueType& value) {
static_assert(std::is_same_v<value_type, ValueType>,
"get_index: Container value type must match ValueType");
auto it = std::find(values.begin(), values.end(), value);
KOKKOSFFT_EXPECTS(it != values.end(), "value is not included in values");
KOKKOSFFT_THROW_IF(it == values.end(), "value is not included in values");
return it - values.begin();
}

Expand Down Expand Up @@ -256,44 +256,44 @@ void create_view(ViewType& out, const Label& label,

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 1>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(extents[0]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(extents[0]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 2>& extents) {
KOKKOSFFT_EXPECTS(
ViewType::required_allocation_size(out.layout()) >=
KOKKOSFFT_THROW_IF(
ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(extents[0], extents[1]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 3>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 4>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3]),
"reshape_view: insufficient memory");

out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 5>& extents) {
KOKKOSFFT_EXPECTS(
ViewType::required_allocation_size(out.layout()) >=
KOKKOSFFT_THROW_IF(
ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(extents[0], extents[1], extents[2],
extents[3], extents[4]),
"reshape_view: insufficient memory");
Expand All @@ -303,33 +303,33 @@ void reshape_view(ViewType& out, const std::array<int, 5>& extents) {

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 6>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 7>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6]);
}

template <typename ViewType>
void reshape_view(ViewType& out, const std::array<int, 8>& extents) {
KOKKOSFFT_EXPECTS(ViewType::required_allocation_size(out.layout()) >=
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6], extents[7]),
"reshape_view: insufficient memory");
KOKKOSFFT_THROW_IF(ViewType::required_allocation_size(out.layout()) <
ViewType::required_allocation_size(
extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6], extents[7]),
"reshape_view: insufficient memory");
out = ViewType(out.data(), extents[0], extents[1], extents[2], extents[3],
extents[4], extents[5], extents[6], extents[7]);
}
Expand Down
16 changes: 8 additions & 8 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -45,7 +45,7 @@ auto create_plan(const ExecutionSpace& exec_space,
std::multiplies<>());

cufft_rt = cufftPlan1d(&(*plan), nx, type, howmany);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan1d failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan1d failed");

return fft_size;
}
Expand All @@ -69,7 +69,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -83,7 +83,7 @@ auto create_plan(const ExecutionSpace& exec_space,
std::multiplies<>());

cufft_rt = cufftPlan2d(&(*plan), nx, ny, type);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan2d failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan2d failed");

return fft_size;
}
Expand All @@ -107,7 +107,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -123,7 +123,7 @@ auto create_plan(const ExecutionSpace& exec_space,
std::multiplies<>());

cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlan3d failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan3d failed");

return fft_size;
}
Expand Down Expand Up @@ -167,7 +167,7 @@ auto create_plan(const ExecutionSpace& exec_space,

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftCreate(&(*plan));
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftCreate failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream((*plan), stream);
Expand All @@ -176,7 +176,7 @@ auto create_plan(const ExecutionSpace& exec_space,
in_extents.data(), istride, idist,
out_extents.data(), ostride, odist, type, howmany);

KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftPlanMany failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlanMany failed");

return fft_size;
}
Expand Down
12 changes: 6 additions & 6 deletions fft/src/KokkosFFT_Cuda_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,42 @@ template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftReal* idata, cufftComplex* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecR2C(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecR2C failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecR2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleReal* idata,
cufftDoubleComplex* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecD2Z failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecD2Z failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftReal* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecC2R(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2R failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2R failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleReal* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2D failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2D failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata,
cufftComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecC2C failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction);
KOKKOSFFT_EXPECTS(cufft_rt == CUFFT_SUCCESS, "cufftExecZ2Z failed");
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2Z failed");
}
} // namespace Impl
} // namespace KokkosFFT
Expand Down
Loading
Loading