Skip to content

Commit

Permalink
Blas - rot: fixing interface of rot
Browse files Browse the repository at this point in the history
The cosine coefficient is strictly real, while the sine coefficient
can be real or complex, which leads to a bug in the current API. This
commit fixes that for the native and TPL implementations, and
the associated unit test is updated accordingly.

Signed-off-by: Luc Berger-Vergiat <[email protected]>
  • Loading branch information
lucbv committed Dec 18, 2024
1 parent 0569e9d commit 6667b94
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 35 deletions.
11 changes: 6 additions & 5 deletions blas/impl/KokkosBlas1_rot_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
namespace KokkosBlas {
namespace Impl {

template <class VectorView, class ScalarView>
template <class VectorView, class MagnitudeView, class ScalarView>
struct rot_functor {
using scalar_type = typename VectorView::non_const_value_type;

VectorView X, Y;
ScalarView c, s;
MagnitudeView c;
ScalarView s;

rot_functor(VectorView const& X_, VectorView const& Y_, ScalarView const& c_, ScalarView const& s_)
rot_functor(VectorView const& X_, VectorView const& Y_, MagnitudeView const& c_, ScalarView const& s_)
: X(X_), Y(Y_), c(c_), s(s_) {}

KOKKOS_INLINE_FUNCTION
Expand All @@ -41,8 +42,8 @@ struct rot_functor {
}
};

template <class ExecutionSpace, class VectorView, class ScalarView>
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
Kokkos::RangePolicy<ExecutionSpace> rot_policy(space, 0, X.extent(0));
rot_functor rot_func(X, Y, c, s);
Expand Down
23 changes: 13 additions & 10 deletions blas/impl/KokkosBlas1_rot_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace KokkosBlas {
namespace Impl {
// Specialization struct which defines whether a specialization exists
template <class ExecutionSpace, class VectorView, class ScalarView>
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
struct rot_eti_spec_avail {
enum : bool { value = false };
};
Expand All @@ -49,7 +49,8 @@ struct rot_eti_spec_avail {
EXECSPACE, \
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>> {\
enum : bool { value = true }; \
};

Expand All @@ -61,19 +62,19 @@ namespace KokkosBlas {
namespace Impl {

// Unification layer
template <class ExecutionSpace, class VectorView, class ScalarView,
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, ScalarView>::value,
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, ScalarView>::value>
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView,
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value,
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value>
struct Rot {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s);
};

#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
//! Full specialization of Rot.
template <class ExecutionSpace, class VectorView, class ScalarView>
struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
struct Rot<ExecutionSpace, VectorView, MagnitudeView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
Kokkos::Profiling::pushRegion(KOKKOSKERNELS_IMPL_COMPILE_LIBRARY ? "KokkosBlas::rot[ETI]"
: "KokkosBlas::rot[noETI]");
Expand All @@ -86,7 +87,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
typeid(VectorView).name(), typeid(ScalarView).name());
}
#endif
Rot_Invoke<ExecutionSpace, VectorView, ScalarView>(space, X, Y, c, s);
Rot_Invoke<ExecutionSpace, VectorView, MagnitudeView, ScalarView>(space, X, Y, c, s);
Kokkos::Profiling::popRegion();
}
};
Expand All @@ -108,6 +109,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
false, true>;

//
Expand All @@ -121,6 +123,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
false, true>;

#include <KokkosBlas1_rot_tpl_spec_decl.hpp>
Expand Down
26 changes: 19 additions & 7 deletions blas/src/KokkosBlas1_rot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,28 @@

namespace KokkosBlas {

template <class execution_space, class VectorView, class ScalarView>
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class execution_space, class VectorView, class MagnitudeView, class ScalarView>
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
static_assert(Kokkos::is_execution_space<execution_space>::value,
"rot: execution_space template parameter is not a Kokkos "
"execution space.");
static_assert(Kokkos::is_view_v<VectorView>, "KokkosBlas::rot: VectorView is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<MagnitudeView>, "KokkosBlas::rot: MagnitudeView is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<ScalarView>, "KokkosBlas::rot: ScalarView is not a Kokkos::View.");
static_assert(VectorView::rank == 1, "rot: VectorView template parameter needs to be a rank 1 view");
static_assert(MagnitudeView::rank == 0, "rot: MagnitudeView template parameter needs to be a rank 0 view");
static_assert(ScalarView::rank == 0, "rot: ScalarView template parameter needs to be a rank 0 view");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename VectorView::memory_space>::accessible,
"rot: VectorView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MagnitudeView::memory_space>::accessible,
"rot: MagnitudeView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename ScalarView::memory_space>::accessible,
"rot: VectorView template parameter memory space needs to be accessible "
"rot: ScalarView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(std::is_same<typename VectorView::non_const_value_type, typename VectorView::value_type>::value,
Expand All @@ -55,21 +61,27 @@ void rot(execution_space const& space, VectorView const& X, VectorView const& Y,
Kokkos::Device<execution_space, typename VectorView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

// Unmanaged internal view type wrapping the cosine coefficient c.
// The layout and memory space are taken from MagnitudeView itself rather than
// from ScalarView: c and s are independent template parameters, so they are
// not guaranteed to share a layout or a memory space.
using MagnitudeView_Internal = Kokkos::View<typename MagnitudeView::non_const_value_type,
typename KokkosKernels::Impl::GetUnifiedLayout<MagnitudeView>::array_layout,
Kokkos::Device<execution_space, typename MagnitudeView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

using ScalarView_Internal = Kokkos::View<typename ScalarView::non_const_value_type,
typename KokkosKernels::Impl::GetUnifiedLayout<ScalarView>::array_layout,
Kokkos::Device<execution_space, typename ScalarView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

VectorView_Internal X_(X), Y_(Y);
ScalarView_Internal c_(c), s_(s);
MagnitudeView_Internal c_(c);
ScalarView_Internal s_(s);

Kokkos::Profiling::pushRegion("KokkosBlas::rot");
Impl::Rot<execution_space, VectorView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_, s_);
Impl::Rot<execution_space, VectorView_Internal, MagnitudeView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_, s_);
Kokkos::Profiling::popRegion();
}

template <class VectorView, class ScalarView>
void rot(VectorView const& X, VectorView const& Y, ScalarView const& c, ScalarView const& s) {
template <class VectorView, class MagnitudeView, class ScalarView>
void rot(VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) {
const typename VectorView::execution_space space = typename VectorView::execution_space();
rot(space, X, Y, c, s);
}
Expand Down
7 changes: 6 additions & 1 deletion blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
namespace KokkosBlas {
namespace Impl {
// Specialization struct which defines whether a specialization exists
template <class ExecutionSpace, class VectorView, class ScalarView>
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
struct rot_tpl_spec_avail {
enum : bool { value = false };
};
Expand All @@ -37,6 +37,9 @@ namespace Impl {
struct rot_tpl_spec_avail<EXECSPACE, \
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, Kokkos::HostSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, \
Kokkos::Device<EXECSPACE, Kokkos::HostSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, Kokkos::HostSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
enum : bool { value = true }; \
Expand Down Expand Up @@ -64,6 +67,8 @@ KOKKOSBLAS1_ROT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,
struct rot_tpl_spec_avail< \
EXECSPACE, \
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
enum : bool { value = true }; \
};
Expand Down
Loading

0 comments on commit 6667b94

Please sign in to comment.