Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow dyn-rank-view in serial trsv #2464

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 98 additions & 34 deletions batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,6 +19,7 @@
/// \author Kyungjoo Kim ([email protected])
/// \author Yuuichi Asahi ([email protected])

#include <Kokkos_DynRankView.hpp>
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsv_Serial_Internal.hpp"

Expand All @@ -27,20 +28,39 @@ namespace Impl {
template <typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const bViewType &b) {
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::trsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<bViewType>, "KokkosBatched::trsv: bViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::trsv: AViewType must have rank 2.");
static_assert(bViewType::rank == 1, "KokkosBatched::trsv: bViewType must have rank 1.");
static_assert(Kokkos::is_view_v<AViewType> || Kokkos::is_dyn_rank_view_v<AViewType>,
"KokkosBatched::trsv: AViewType must be either a Kokkos::View or a Kokkos::DynRankView.");
static_assert(Kokkos::is_view_v<bViewType> || Kokkos::is_dyn_rank_view_v<bViewType>,
"KokkosBatched::trsv: bViewType must be either a Kokkos::View or a Kokkos::DynRankView.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
const int lda = A.extent(0), n = A.extent(1);
if (lda < Kokkos::max(1, n)) {
if (A.rank() != 2) {
Kokkos::printf(
"KokkosBatched::trsv: leading dimension of A must not be smaller than "
"max(1, n): "
"lda = %d, n = %d\n",
lda, n);
"KokkosBatched::trsv: A must be a rank 2 View."
"A.rank() = %d\n",
A.rank());
return 1;
}

if (b.rank() != 1) {
Kokkos::printf(
"KokkosBatched::trsv: b must be a rank 1 View."
"b.rank() = %d\n",
b.rank());
return 1;
}

// FIXME: the leading-dimension check is suppressed for now
// because of a compatibility issue with Trilinos.
// const int lda = A.extent(0), n = A.extent(1);
// if (lda < Kokkos::max(1, n)) {
// Kokkos::printf(
// "KokkosBatched::trsv: leading dimension of A must not be smaller than "
// "max(1, n): "
// "lda = %d, n = %d\n",
// lda, n);
// return 1;
// }

#endif
return 0;
}
Expand All @@ -53,6 +73,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -88,12 +114,13 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -102,12 +129,13 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -118,6 +146,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -153,12 +187,13 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -167,12 +202,13 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -183,6 +219,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -218,12 +260,12 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -232,12 +274,12 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -248,6 +290,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -283,12 +331,13 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -297,12 +346,13 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -313,6 +363,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -348,12 +404,13 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -362,12 +419,13 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -378,6 +436,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -413,12 +477,12 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -427,12 +491,12 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand Down
Loading
Loading