Skip to content

Commit

Permalink
Allow dyn-rank-view in serial trsv (#2464)
Browse files Browse the repository at this point in the history
* Allow dyn-rank-view in serial trsv

Signed-off-by: Yuuichi Asahi <[email protected]>

* suppress shape checks

Signed-off-by: Yuuichi Asahi <[email protected]>

---------

Signed-off-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Dec 19, 2024
1 parent d96c6ee commit 07de262
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 64 deletions.
132 changes: 98 additions & 34 deletions batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
/// \author Kyungjoo Kim ([email protected])
/// \author Yuuichi Asahi ([email protected])

#include <Kokkos_DynRankView.hpp>
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsv_Serial_Internal.hpp"

Expand All @@ -27,20 +28,39 @@ namespace Impl {
template <typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const bViewType &b) {
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::trsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<bViewType>, "KokkosBatched::trsv: bViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::trsv: AViewType must have rank 2.");
static_assert(bViewType::rank == 1, "KokkosBatched::trsv: bViewType must have rank 1.");
static_assert(Kokkos::is_view_v<AViewType> || Kokkos::is_dyn_rank_view_v<AViewType>,
"KokkosBatched::trsv: AViewType must be either a Kokkos::View or a Kokkos::DynRankView.");
static_assert(Kokkos::is_view_v<bViewType> || Kokkos::is_dyn_rank_view_v<bViewType>,
"KokkosBatched::trsv: bViewType must be either a Kokkos::View or a Kokkos::DynRankView.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
const int lda = A.extent(0), n = A.extent(1);
if (lda < Kokkos::max(1, n)) {
if (A.rank() != 2) {
Kokkos::printf(
"KokkosBatched::trsv: leading dimension of A must not be smaller than "
"max(1, n): "
"lda = %d, n = %d\n",
lda, n);
"KokkosBatched::trsv: A must be a rank 2 View."
"A.rank() = %d\n",
A.rank());
return 1;
}

if (b.rank() != 1) {
Kokkos::printf(
"KokkosBatched::trsv: b must be a rank 1 View."
"b.rank() = %d\n",
b.rank());
return 1;
}

// FIXME : check leading dimension is suppressed for now
// because of the compatibility issue with Trilinos
// const int lda = A.extent(0), n = A.extent(1);
// if (lda < Kokkos::max(1, n)) {
// Kokkos::printf(
// "KokkosBatched::trsv: leading dimension of A must not be smaller than "
// "max(1, n): "
// "lda = %d, n = %d\n",
// lda, n);
// return 1;
// }

#endif
return 0;
}
Expand All @@ -53,6 +73,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -88,12 +114,13 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -102,12 +129,13 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -118,6 +146,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -153,12 +187,13 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -167,12 +202,13 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -183,6 +219,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -218,12 +260,12 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -232,12 +274,12 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -248,6 +290,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -283,12 +331,13 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -297,12 +346,13 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -313,6 +363,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -348,12 +404,13 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -362,12 +419,13 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -378,6 +436,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -413,12 +477,12 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -427,12 +491,12 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand Down
Loading

0 comments on commit 07de262

Please sign in to comment.