Skip to content

Commit

Permalink
Fixes while documenting (#2466)
Browse files Browse the repository at this point in the history
* BLAS - scal: removing check on assignable memory spaces

That check is stricter than required as we will values by reference
to perform copies and won't try to reassign pointers.

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* BLAS - rot: check at runtime that X and Y have same extent

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* BLAS - rot: improving static assertions

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* BLAS - rotg: check for non-complex types

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* BLAS - ger: check that matrix stores values as non-const

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* BLAS - trmm: check for valid execution space type.

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* BLAS: fix missing semi-colon at end of static_assert

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* Applying clang-format

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* More clang-format

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* Blas - rot: fixing interface of rot

The cosine coefficient is strictly real while the sine coefficient
can be real or complex leading to a bug in the current API. This
commit should fix that for the native and TPL implementation and
the associated unit-test is also fixed accordingly.

Signed-off-by: Luc Berger-Vergiat <[email protected]>

* BLAS - ROT: fixing types for Host TPL calls to ROT function

The types for the arguments c and s are actually different and need
to be appropriately propagated through the TPL layers of the library.

Signed-off-by: Luc Berger-Vergiat <[email protected]>

---------

Signed-off-by: Luc Berger-Vergiat <[email protected]>
  • Loading branch information
lucbv authored Jan 6, 2025
1 parent 85bbf1f commit 0adc88b
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 149 deletions.
11 changes: 6 additions & 5 deletions blas/impl/KokkosBlas1_rot_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
namespace KokkosBlas {
namespace Impl {

template <class VectorView, class ScalarView>
template <class VectorView, class MagnitudeView, class ScalarView>
struct rot_functor {
using scalar_type = typename VectorView::non_const_value_type;

VectorView X, Y;
ScalarView c, s;
MagnitudeView c;
ScalarView s;

rot_functor(VectorView const& X_, VectorView const& Y_, ScalarView const& c_, ScalarView const& s_)
rot_functor(VectorView const& X_, VectorView const& Y_, MagnitudeView const& c_, ScalarView const& s_)
: X(X_), Y(Y_), c(c_), s(s_) {}

KOKKOS_INLINE_FUNCTION
Expand All @@ -41,8 +42,8 @@ struct rot_functor {
}
};

template <class ExecutionSpace, class VectorView, class ScalarView>
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
Kokkos::RangePolicy<ExecutionSpace> rot_policy(space, 0, X.extent(0));
rot_functor rot_func(X, Y, c, s);
Expand Down
37 changes: 20 additions & 17 deletions blas/impl/KokkosBlas1_rot_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace KokkosBlas {
namespace Impl {
// Specialization struct which defines whether a specialization exists
template <class ExecutionSpace, class VectorView, class ScalarView>
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
struct rot_eti_spec_avail {
enum : bool { value = false };
};
Expand All @@ -43,14 +43,15 @@ struct rot_eti_spec_avail {
// We may spread out definitions (see _INST macro below) across one or
// more .cpp files.
//
#define KOKKOSBLAS1_ROT_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct rot_eti_spec_avail< \
EXECSPACE, \
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
enum : bool { value = true }; \
#define KOKKOSBLAS1_ROT_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct rot_eti_spec_avail< \
EXECSPACE, \
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
enum : bool { value = true }; \
};

// Include the actual specialization declarations
Expand All @@ -61,19 +62,19 @@ namespace KokkosBlas {
namespace Impl {

// Unification layer
template <class ExecutionSpace, class VectorView, class ScalarView,
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, ScalarView>::value,
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, ScalarView>::value>
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView,
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value,
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value>
struct Rot {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s);
};

#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
//! Full specialization of Rot.
template <class ExecutionSpace, class VectorView, class ScalarView>
struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
struct Rot<ExecutionSpace, VectorView, MagnitudeView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
Kokkos::Profiling::pushRegion(KOKKOSKERNELS_IMPL_COMPILE_LIBRARY ? "KokkosBlas::rot[ETI]"
: "KokkosBlas::rot[noETI]");
Expand All @@ -86,7 +87,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
typeid(VectorView).name(), typeid(ScalarView).name());
}
#endif
Rot_Invoke<ExecutionSpace, VectorView, ScalarView>(space, X, Y, c, s);
Rot_Invoke<ExecutionSpace, VectorView, MagnitudeView, ScalarView>(space, X, Y, c, s);
Kokkos::Profiling::popRegion();
}
};
Expand All @@ -108,6 +109,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
false, true>;

//
Expand All @@ -121,6 +123,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
false, true>;

#include <KokkosBlas1_rot_tpl_spec_decl.hpp>
Expand Down
37 changes: 30 additions & 7 deletions blas/src/KokkosBlas1_rot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,45 +21,68 @@

namespace KokkosBlas {

template <class execution_space, class VectorView, class ScalarView>
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
template <class execution_space, class VectorView, class MagnitudeView, class ScalarView>
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
ScalarView const& s) {
static_assert(Kokkos::is_execution_space<execution_space>::value,
"rot: execution_space template parameter is not a Kokkos "
"execution space.");
static_assert(Kokkos::is_view_v<VectorView>, "KokkosBlas::rot: VectorView is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<MagnitudeView>, "KokkosBlas::rot: MagnitudeView is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<ScalarView>, "KokkosBlas::rot: ScalarView is not a Kokkos::View.");
static_assert(VectorView::rank == 1, "rot: VectorView template parameter needs to be a rank 1 view");
static_assert(MagnitudeView::rank == 0, "rot: MagnitudeView template parameter needs to be a rank 0 view");
static_assert(ScalarView::rank == 0, "rot: ScalarView template parameter needs to be a rank 0 view");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename VectorView::memory_space>::accessible,
"rot: VectorView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MagnitudeView::memory_space>::accessible,
"rot: MagnitudeView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename ScalarView::memory_space>::accessible,
"rot: VectorView template parameter memory space needs to be accessible "
"rot: ScalarView template parameter memory space needs to be accessible "
"from "
"execution_space template parameter");
static_assert(std::is_same<typename VectorView::non_const_value_type, typename VectorView::value_type>::value,
"rot: VectorView template parameter needs to store non-const values");

// Check compatibility of dimensions at run time.
if (X.extent(0) != Y.extent(0)) {
std::ostringstream os;
os << "KokkosBlas::rot: Dimensions of X and Y do not match: "
<< "X: " << X.extent(0) << ", Y: " << Y.extent(0);
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

using VectorView_Internal = Kokkos::View<typename VectorView::non_const_value_type*,
typename KokkosKernels::Impl::GetUnifiedLayout<VectorView>::array_layout,
Kokkos::Device<execution_space, typename VectorView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

using MagnitudeView_Internal = Kokkos::View<typename MagnitudeView::non_const_value_type,
typename KokkosKernels::Impl::GetUnifiedLayout<ScalarView>::array_layout,
Kokkos::Device<execution_space, typename ScalarView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

using ScalarView_Internal = Kokkos::View<typename ScalarView::non_const_value_type,
typename KokkosKernels::Impl::GetUnifiedLayout<ScalarView>::array_layout,
Kokkos::Device<execution_space, typename ScalarView::memory_space>,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

VectorView_Internal X_(X), Y_(Y);
ScalarView_Internal c_(c), s_(s);
MagnitudeView_Internal c_(c);
ScalarView_Internal s_(s);

Kokkos::Profiling::pushRegion("KokkosBlas::rot");
Impl::Rot<execution_space, VectorView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_, s_);
Impl::Rot<execution_space, VectorView_Internal, MagnitudeView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_,
s_);
Kokkos::Profiling::popRegion();
}

template <class VectorView, class ScalarView>
void rot(VectorView const& X, VectorView const& Y, ScalarView const& c, ScalarView const& s) {
template <class VectorView, class MagnitudeView, class ScalarView>
void rot(VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) {
const typename VectorView::execution_space space = typename VectorView::execution_space();
rot(space, X, Y, c, s);
}
Expand Down
2 changes: 2 additions & 0 deletions blas/src/KokkosBlas1_rotg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ void rotg(execution_space const& space, SViewType const& a, SViewType const& b,
"rotg: execution_space cannot access data in SViewType");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MViewType::memory_space>::accessible,
"rotg: execution_space cannot access data in MViewType");
static_assert(!Kokkos::ArithTraits<typename MViewType::value_type>::is_complex,
"rotg: MViewType cannot hold complex values.");

using SView_Internal = Kokkos::View<
typename SViewType::value_type, typename KokkosKernels::Impl::GetUnifiedLayout<SViewType>::array_layout,
Expand Down
2 changes: 0 additions & 2 deletions blas/src/KokkosBlas1_scal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ void scal(const execution_space& space, const RMV& R, const AV& a, const XMV& X)
"X is not a Kokkos::View.");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename XMV::memory_space>::accessible,
"KokkosBlas::scal: XMV must be accessible from execution_space");
static_assert(Kokkos::SpaceAccessibility<typename RMV::memory_space, typename XMV::memory_space>::assignable,
"KokkosBlas::scal: XMV must be assignable to RMV");
static_assert(std::is_same<typename RMV::value_type, typename RMV::non_const_value_type>::value,
"KokkosBlas::scal: R is const. "
"It must be nonconst, because it is an output argument "
Expand Down
21 changes: 12 additions & 9 deletions blas/src/KokkosBlas2_ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,22 @@ template <class ExecutionSpace, class XViewType, class YViewType, class AViewTyp
void ger(const ExecutionSpace& space, const char trans[], const typename AViewType::const_value_type& alpha,
const XViewType& x, const YViewType& y, const AViewType& A) {
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename AViewType::memory_space>::accessible,
"AViewType memory space must be accessible from ExecutionSpace");
"ger: AViewType memory space must be accessible from ExecutionSpace");
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename XViewType::memory_space>::accessible,
"XViewType memory space must be accessible from ExecutionSpace");
"ger: XViewType memory space must be accessible from ExecutionSpace");
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename YViewType::memory_space>::accessible,
"YViewType memory space must be accessible from ExecutionSpace");
"ger: YViewType memory space must be accessible from ExecutionSpace");

static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "XViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<YViewType>::value, "YViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<AViewType>::value, "ger: AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "ger: XViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<YViewType>::value, "ger: YViewType must be a Kokkos::View.");

static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
static_assert(static_cast<int>(XViewType::rank) == 1, "XViewType must have rank 1.");
static_assert(static_cast<int>(YViewType::rank) == 1, "YViewType must have rank 1.");
static_assert(static_cast<int>(AViewType::rank) == 2, "ger: AViewType must have rank 2.");
static_assert(static_cast<int>(XViewType::rank) == 1, "ger: XViewType must have rank 1.");
static_assert(static_cast<int>(YViewType::rank) == 1, "ger: YViewType must have rank 1.");

static_assert(std::is_same_v<typename AViewType::value_type, typename AViewType::non_const_value_type>,
"ger: AViewType must store non const values.");

// Check compatibility of dimensions at run time.
if ((A.extent(0) != x.extent(0)) || (A.extent(1) != y.extent(0))) {
Expand Down
12 changes: 8 additions & 4 deletions blas/src/KokkosBlas3_trmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ namespace KokkosBlas {
template <class execution_space, class AViewType, class BViewType>
void trmm(const execution_space& space, const char side[], const char uplo[], const char trans[], const char diag[],
typename BViewType::const_value_type& alpha, const AViewType& A, const BViewType& B) {
static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<BViewType>::value, "BViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
static_assert(static_cast<int>(BViewType::rank) == 2, "BViewType must have rank 2.");
static_assert(Kokkos::is_execution_space_v<execution_space>,
"trmm: execution_space must be a Kokkos::execution_space.");
static_assert(Kokkos::is_view_v<AViewType>,
"trmm: AViewType must be a "
"Kokkos::View.");
static_assert(Kokkos::is_view_v<BViewType>, "trmm: BViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2, "trmm: AViewType must have rank 2.");
static_assert(static_cast<int>(BViewType::rank) == 2, "trmm: BViewType must have rank 2.");

// Check validity of indicator argument
bool valid_side = (side[0] == 'L') || (side[0] == 'l') || (side[0] == 'R') || (side[0] == 'r');
Expand Down
Loading

0 comments on commit 0adc88b

Please sign in to comment.