Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Impl namespace #13

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion common/src/KokkosFFT_Cuda_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cufft.h>

namespace KokkosFFT {
namespace Impl {
#define KOKKOS_FFT_FORWARD CUFFT_FORWARD
#define KOKKOS_FFT_BACKWARD CUFFT_INVERSE
#define KOKKOS_FFT_R2C CUFFT_R2C
Expand Down Expand Up @@ -54,6 +55,7 @@ namespace KokkosFFT {
static constexpr TransformType m_type = std::is_same_v<T1, float> ? KOKKOS_FFT_C2C : KOKKOS_FFT_Z2Z;
static constexpr TransformType type() { return m_type; };
};
};
} // namespace Impl
}; // namespace KokkosFFT

#endif
4 changes: 3 additions & 1 deletion common/src/KokkosFFT_HIP_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <hipfft/hipfft.h>

namespace KokkosFFT {
namespace Impl {
#define KOKKOS_FFT_FORWARD HIPFFT_FORWARD
#define KOKKOS_FFT_BACKWARD HIPFFT_BACKWARD
#define KOKKOS_FFT_R2C HIPFFT_R2C
Expand Down Expand Up @@ -54,6 +55,7 @@ namespace KokkosFFT {
static constexpr TransformType m_type = std::is_same_v<T1, float> ? KOKKOS_FFT_C2C : KOKKOS_FFT_Z2Z;
static constexpr TransformType type() { return m_type; };
};
};
} // namespace Impl
}; // namespace KokkosFFT

#endif
6 changes: 4 additions & 2 deletions common/src/KokkosFFT_OpenMP_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
namespace Impl {
enum class TransformType {
R2C,
D2Z,
Expand Down Expand Up @@ -32,7 +33,7 @@ namespace KokkosFFT {

template <typename T>
struct FFTPlanType {
using type = std::conditional_t<std::is_same_v<real_type_t<T>, float>, fftwf_plan, fftw_plan>;
using type = std::conditional_t<std::is_same_v<KokkosFFT::Impl::real_type_t<T>, float>, fftwf_plan, fftw_plan>;
};

using FFTResultType = int;
Expand Down Expand Up @@ -63,6 +64,7 @@ namespace KokkosFFT {
static constexpr TransformType m_type = std::is_same_v<T1, float> ? KOKKOS_FFT_C2C : KOKKOS_FFT_Z2Z;
static constexpr TransformType type() { return m_type; };
};
};
} // namespace Impl
}; // namespace KokkosFFT

#endif
32 changes: 21 additions & 11 deletions common/src/KokkosFFT_default_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,37 @@
#include "KokkosFFT_utils.hpp"

// Check the size of complex type
static_assert(sizeof(KokkosFFT::FFTDataType::complex64) == sizeof(Kokkos::complex<float>));
static_assert(alignof(KokkosFFT::FFTDataType::complex64) <= alignof(Kokkos::complex<float>));
static_assert(sizeof(KokkosFFT::Impl::FFTDataType::complex64) == sizeof(Kokkos::complex<float>));
static_assert(alignof(KokkosFFT::Impl::FFTDataType::complex64) <= alignof(Kokkos::complex<float>));

static_assert(sizeof(KokkosFFT::FFTDataType::complex128) == sizeof(Kokkos::complex<double>));
static_assert(alignof(KokkosFFT::FFTDataType::complex128) <= alignof(Kokkos::complex<double>));
static_assert(sizeof(KokkosFFT::Impl::FFTDataType::complex128) == sizeof(Kokkos::complex<double>));
static_assert(alignof(KokkosFFT::Impl::FFTDataType::complex128) <= alignof(Kokkos::complex<double>));

namespace KokkosFFT {
// Define type to specify transform axis
template <std::size_t DIM>
using axis_type = std::array<int, DIM>;

enum class Normalization {
FORWARD,
BACKWARD,
ORTHO
};
} // namespace KokkosFFT

namespace KokkosFFT {
namespace Impl {
// Define fft data types
template <typename T>
struct fft_data_type {
using type = std::conditional_t<std::is_same_v<T, float>, KokkosFFT::FFTDataType::float32, KokkosFFT::FFTDataType::float64>;
using type = std::conditional_t<std::is_same_v<T, float>, KokkosFFT::Impl::FFTDataType::float32, KokkosFFT::Impl::FFTDataType::float64>;
};

template <typename T>
struct fft_data_type<Kokkos::complex<T>> {
using type = std::conditional_t<std::is_same_v<T, float>, KokkosFFT::FFTDataType::complex64, KokkosFFT::FFTDataType::complex128>;
using type = std::conditional_t<std::is_same_v<T, float>, KokkosFFT::Impl::FFTDataType::complex64, KokkosFFT::Impl::FFTDataType::complex128>;
};

// Define type to specify transform axis
template <std::size_t DIM>
using axis_type = std::array<int, DIM>;
}
} // namespace Impl
} // namespace KokkosFFT

#endif
8 changes: 5 additions & 3 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "KokkosFFT_transpose.hpp"

namespace KokkosFFT {
namespace Impl {
/* Input and output extents exposed to the fft library
i.e extents are converted into Layout Right
*/
Expand All @@ -21,7 +22,7 @@ namespace KokkosFFT {
using array_layout_type = typename InViewType::array_layout;

// index map after transpose over axis
auto [map, map_inv] = get_map_axes(in, _axes);
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, _axes);

constexpr std::size_t rank = InViewType::rank;
int inner_most_axis = std::is_same_v<array_layout_type, typename Kokkos::LayoutLeft> ? 0 : rank - 1;
Expand Down Expand Up @@ -77,7 +78,7 @@ namespace KokkosFFT {
using array_layout_type = typename InViewType::array_layout;

// index map after transpose over axis
auto [map, map_inv] = get_map_axes(in, _axes);
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, _axes);

static_assert(InViewType::rank() >= DIM,
"KokkosFFT::get_map_axes: Rank of View must be larger thane or equal to the Rank of FFT axes.");
Expand Down Expand Up @@ -168,6 +169,7 @@ namespace KokkosFFT {
auto get_extents_batched(InViewType& in, OutViewType& out, int _axis) {
return get_extents_batched(in, out, axis_type<1>{_axis});
}
};
} // namespace Impl
}; // namespace KokkosFFT

#endif
12 changes: 4 additions & 8 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
enum class Normalization {
FORWARD,
BACKWARD,
ORTHO
};

namespace Impl {
template <typename ExecutionSpace, typename ViewType, typename T>
void _normalize(const ExecutionSpace& exec_space, ViewType& inout, const T coef) {
std::size_t size = inout.size();
Expand All @@ -25,7 +20,7 @@ namespace KokkosFFT {

template <typename ViewType>
auto _coefficients(const ViewType& inout, FFTDirectionType direction, Normalization normalization, std::size_t fft_size) {
using value_type = real_type_t<typename ViewType::non_const_value_type>;
using value_type = KokkosFFT::Impl::real_type_t<typename ViewType::non_const_value_type>;
value_type coef = 1;
bool to_normalize = false;

Expand Down Expand Up @@ -58,6 +53,7 @@ namespace KokkosFFT {
auto [coef, to_normalize] = _coefficients(inout, direction, normalization, fft_size);
if(to_normalize) _normalize(exec_space, inout, coef);
}
};
} // namespace Impl
}; // namespace KokkosFFT

#endif
10 changes: 6 additions & 4 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_map_axes(const ViewType& view, axis_type<DIM> _axes) {
static_assert(ViewType::rank() >= DIM,
Expand All @@ -20,12 +21,12 @@ namespace KokkosFFT {
// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
for(std::size_t i=0; i<DIM; i++) {
int axis = convert_negative_axis(view, _axes.at(i));
int axis = KokkosFFT::Impl::convert_negative_axis(view, _axes.at(i));
axes.push_back(axis);
}

// Assert if the elements are overlapped
assert( ! has_duplicate_values(axes) );
assert( ! KokkosFFT::Impl::has_duplicate_values(axes) );

// how indices are map
// For 5D View and axes are (2,3), map would be (0, 1, 4, 2, 3)
Expand Down Expand Up @@ -211,12 +212,13 @@ namespace KokkosFFT {
static_assert(InViewType::rank() == OutViewType::rank(),
"KokkosFFT::transpose: InViewType and OutViewType must have the same rank.");

if(!is_transpose_needed(_map)) {
if(!KokkosFFT::Impl::is_transpose_needed(_map)) {
throw std::runtime_error("KokkosFFT::transpose: transpose not necessary");
}

_transpose(exec_space, in, out, _map);
}
};
} // namespace Impl
} // namespace KokkosFFT

#endif
6 changes: 3 additions & 3 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <numeric>

namespace KokkosFFT {
namespace Impl {
template <typename T>
struct real_type {
using type = T;
Expand Down Expand Up @@ -90,8 +91,7 @@ namespace KokkosFFT {
[=](const T sequence) -> T {return start + sequence;});
return sequence;
}


};
} // namespace Impl
}; // namespace KokkosFFT

#endif
Loading