Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes while documenting #2466

Merged
merged 11 commits into from
Jan 6, 2025
11 changes: 6 additions & 5 deletions blas/impl/KokkosBlas1_rot_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
namespace KokkosBlas {
namespace Impl {

template <class VectorView, class ScalarView>
template <class VectorView, class MagnitudeView, class ScalarView>
struct rot_functor {
using scalar_type = typename VectorView::non_const_value_type;

VectorView X, Y;
ScalarView c, s;
MagnitudeView c;
ScalarView s;

rot_functor(VectorView const& X_, VectorView const& Y_, ScalarView const& c_, ScalarView const& s_)
rot_functor(VectorView const& X_, VectorView const& Y_, MagnitudeView const& c_, ScalarView const& s_)
: X(X_), Y(Y_), c(c_), s(s_) {}

KOKKOS_INLINE_FUNCTION
Expand All @@ -41,8 +42,8 @@ struct rot_functor {
}
};

template <class ExecutionSpace, class VectorView, class ScalarView>
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
Kokkos::RangePolicy<ExecutionSpace> rot_policy(space, 0, X.extent(0));
rot_functor rot_func(X, Y, c, s);
Expand Down
37 changes: 20 additions & 17 deletions blas/impl/KokkosBlas1_rot_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace KokkosBlas {
namespace Impl {
// Specialization struct which defines whether a specialization exists
template <class ExecutionSpace, class VectorView, class ScalarView>
// Primary template of the ETI-availability trait for rot. `value` is false
// here; a specialization for a concrete <ExecutionSpace, VectorView,
// MagnitudeView, ScalarView> combination is expected to set it to true.
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
struct rot_eti_spec_avail {
enum : bool { value = false };
};
Expand All @@ -43,14 +43,15 @@ struct rot_eti_spec_avail {
// We may spread out definitions (see _INST macro below) across one or
// more .cpp files.
//
#define KOKKOSBLAS1_ROT_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct rot_eti_spec_avail< \
EXECSPACE, \
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
enum : bool { value = true }; \
// Specializes rot_eti_spec_avail with value = true for one
// (SCALAR, LAYOUT, EXECSPACE, MEMSPACE) combination. The three view arguments
// of the specialization are, in order: an unmanaged view of SCALAR* (the
// vectors), an unmanaged view of ArithTraits<SCALAR>::mag_type (the magnitude
// argument) and an unmanaged view of SCALAR (the scalar argument), all on
// Kokkos::Device<EXECSPACE, MEMSPACE> with layout LAYOUT.
#define KOKKOSBLAS1_ROT_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct rot_eti_spec_avail< \
EXECSPACE, \
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
enum : bool { value = true }; \
};

// Include the actual specialization declarations
Expand All @@ -61,19 +62,19 @@ namespace KokkosBlas {
namespace Impl {

// Unification layer
template <class ExecutionSpace, class VectorView, class ScalarView,
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, ScalarView>::value,
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, ScalarView>::value>
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView,
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value,
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value>
struct Rot {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s);
};

#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
//! Full specialization of Rot.
template <class ExecutionSpace, class VectorView, class ScalarView>
struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
struct Rot<ExecutionSpace, VectorView, MagnitudeView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
Kokkos::Profiling::pushRegion(KOKKOSKERNELS_IMPL_COMPILE_LIBRARY ? "KokkosBlas::rot[ETI]"
: "KokkosBlas::rot[noETI]");
Expand All @@ -86,7 +87,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
typeid(VectorView).name(), typeid(ScalarView).name());
}
#endif
Rot_Invoke<ExecutionSpace, VectorView, ScalarView>(space, X, Y, c, s);
Rot_Invoke<ExecutionSpace, VectorView, MagnitudeView, ScalarView>(space, X, Y, c, s);
Kokkos::Profiling::popRegion();
}
};
Expand All @@ -108,6 +109,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
false, true>;

//
Expand All @@ -121,6 +123,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
false, true>;

#include <KokkosBlas1_rot_tpl_spec_decl.hpp>
Expand Down
37 changes: 30 additions & 7 deletions blas/src/KokkosBlas1_rot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,45 +21,68 @@

namespace KokkosBlas {

template <class execution_space, class VectorView, class ScalarView>
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class execution_space, class VectorView, class MagnitudeView, class ScalarView>
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
static_assert(Kokkos::is_execution_space<execution_space>::value,
"rot: execution_space template parameter is not a Kokkos "
"execution space.");
static_assert(Kokkos::is_view_v<VectorView>, "KokkosBlas::rot: VectorView is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<MagnitudeView>, "KokkosBlas::rot: MagnitudeView is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<ScalarView>, "KokkosBlas::rot: ScalarView is not a Kokkos::View.");
static_assert(VectorView::rank == 1, "rot: VectorView template parameter needs to be a rank 1 view");
static_assert(MagnitudeView::rank == 0, "rot: MagnitudeView template parameter needs to be a rank 0 view");
static_assert(ScalarView::rank == 0, "rot: ScalarView template parameter needs to be a rank 0 view");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename VectorView::memory_space>::accessible,
"rot: VectorView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MagnitudeView::memory_space>::accessible,
"rot: MagnitudeView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename ScalarView::memory_space>::accessible,
"rot: VectorView template parameter memory space needs to be accessible "
"rot: ScalarView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(std::is_same<typename VectorView::non_const_value_type, typename VectorView::value_type>::value,
"rot: VectorView template parameter needs to store non-const values");

// Check compatibility of dimensions at run time.
if (X.extent(0) != Y.extent(0)) {
std::ostringstream os;
os << "KokkosBlas::rot: Dimensions of X and Y do not match: "
<< "X: " << X.extent(0) << ", Y: " << Y.extent(0);
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

using VectorView_Internal = Kokkos::View<typename VectorView::non_const_value_type*,
typename KokkosKernels::Impl::GetUnifiedLayout<VectorView>::array_layout,
Kokkos::Device<execution_space, typename VectorView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

// Unmanaged internal view type used to wrap the `c` argument. Its layout and
// memory space are derived from MagnitudeView itself rather than from
// ScalarView, so the wrapper stays consistent with the view the caller passed
// in even when `c` and `s` differ in layout or memory space.
using MagnitudeView_Internal = Kokkos::View<typename MagnitudeView::non_const_value_type,
typename KokkosKernels::Impl::GetUnifiedLayout<MagnitudeView>::array_layout,
Kokkos::Device<execution_space, typename MagnitudeView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

using ScalarView_Internal = Kokkos::View<typename ScalarView::non_const_value_type,
typename KokkosKernels::Impl::GetUnifiedLayout<ScalarView>::array_layout,
Kokkos::Device<execution_space, typename ScalarView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

VectorView_Internal X_(X), Y_(Y);
ScalarView_Internal c_(c), s_(s);
MagnitudeView_Internal c_(c);
ScalarView_Internal s_(s);

Kokkos::Profiling::pushRegion("KokkosBlas::rot");
Impl::Rot<execution_space, VectorView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_, s_);
Impl::Rot<execution_space, VectorView_Internal, MagnitudeView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_,
s_);
Kokkos::Profiling::popRegion();
}

template <class VectorView, class ScalarView>
void rot(VectorView const& X, VectorView const& Y, ScalarView const& c, ScalarView const& s) {
/// Overload of rot that takes no execution space argument: it
/// default-constructs an instance of VectorView::execution_space and forwards
/// every argument, unchanged, to the overload that takes an execution space
/// first.
///
/// \param X, Y  views of type VectorView, forwarded as-is.
/// \param c     view of type MagnitudeView, forwarded as-is.
/// \param s     view of type ScalarView, forwarded as-is.
template <class VectorView, class MagnitudeView, class ScalarView>
void rot(VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) {
  using default_exec_space = typename VectorView::execution_space;
  const default_exec_space space = default_exec_space();
  rot(space, X, Y, c, s);
}
Expand Down
2 changes: 2 additions & 0 deletions blas/src/KokkosBlas1_rotg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ void rotg(execution_space const& space, SViewType const& a, SViewType const& b,
"rotg: execution_space cannot access data in SViewType");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MViewType::memory_space>::accessible,
"rotg: execution_space cannot access data in MViewType");
static_assert(!Kokkos::ArithTraits<typename MViewType::value_type>::is_complex,
"rotg: MViewType cannot hold complex values.");

using SView_Internal = Kokkos::View<
typename SViewType::value_type, typename KokkosKernels::Impl::GetUnifiedLayout<SViewType>::array_layout,
Expand Down
2 changes: 0 additions & 2 deletions blas/src/KokkosBlas1_scal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ void scal(const execution_space& space, const RMV& R, const AV& a, const XMV& X)
"X is not a Kokkos::View.");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename XMV::memory_space>::accessible,
"KokkosBlas::scal: XMV must be accessible from execution_space");
static_assert(Kokkos::SpaceAccessibility<typename RMV::memory_space, typename XMV::memory_space>::assignable,
"KokkosBlas::scal: XMV must be assignable to RMV");
static_assert(std::is_same<typename RMV::value_type, typename RMV::non_const_value_type>::value,
"KokkosBlas::scal: R is const. "
"It must be nonconst, because it is an output argument "
Expand Down
21 changes: 12 additions & 9 deletions blas/src/KokkosBlas2_ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,22 @@ template <class ExecutionSpace, class XViewType, class YViewType, class AViewTyp
void ger(const ExecutionSpace& space, const char trans[], const typename AViewType::const_value_type& alpha,
const XViewType& x, const YViewType& y, const AViewType& A) {
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename AViewType::memory_space>::accessible,
"AViewType memory space must be accessible from ExecutionSpace");
"ger: AViewType memory space must be accessible from ExecutionSpace");
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename XViewType::memory_space>::accessible,
"XViewType memory space must be accessible from ExecutionSpace");
"ger: XViewType memory space must be accessible from ExecutionSpace");
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename YViewType::memory_space>::accessible,
"YViewType memory space must be accessible from ExecutionSpace");
"ger: YViewType memory space must be accessible from ExecutionSpace");

static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "XViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<YViewType>::value, "YViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<AViewType>::value, "ger: AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "ger: XViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<YViewType>::value, "ger: YViewType must be a Kokkos::View.");

static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
static_assert(static_cast<int>(XViewType::rank) == 1, "XViewType must have rank 1.");
static_assert(static_cast<int>(YViewType::rank) == 1, "YViewType must have rank 1.");
static_assert(static_cast<int>(AViewType::rank) == 2, "ger: AViewType must have rank 2.");
static_assert(static_cast<int>(XViewType::rank) == 1, "ger: XViewType must have rank 1.");
static_assert(static_cast<int>(YViewType::rank) == 1, "ger: YViewType must have rank 1.");

static_assert(std::is_same_v<typename AViewType::value_type, typename AViewType::non_const_value_type>,
"ger: AViewType must store non const values.");

// Check compatibility of dimensions at run time.
if ((A.extent(0) != x.extent(0)) || (A.extent(1) != y.extent(0))) {
Expand Down
12 changes: 8 additions & 4 deletions blas/src/KokkosBlas3_trmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ namespace KokkosBlas {
template <class execution_space, class AViewType, class BViewType>
void trmm(const execution_space& space, const char side[], const char uplo[], const char trans[], const char diag[],
typename BViewType::const_value_type& alpha, const AViewType& A, const BViewType& B) {
static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<BViewType>::value, "BViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
static_assert(static_cast<int>(BViewType::rank) == 2, "BViewType must have rank 2.");
static_assert(Kokkos::is_execution_space_v<execution_space>,
"trmm: execution_space must be a Kokkos::execution_space.");
static_assert(Kokkos::is_view_v<AViewType>,
"trmm: AViewType must be a "
"Kokkos::View.");
static_assert(Kokkos::is_view_v<BViewType>, "trmm: BViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2, "trmm: AViewType must have rank 2.");
static_assert(static_cast<int>(BViewType::rank) == 2, "trmm: BViewType must have rank 2.");

// Check validity of indicator argument
bool valid_side = (side[0] == 'L') || (side[0] == 'l') || (side[0] == 'R') || (side[0] == 'r');
Expand Down
Loading
Loading