Skip to content

Commit

Permalink
refactor serial tbsv implementation details and tests
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Jan 13, 2025
1 parent 4c33556 commit d28324c
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 175 deletions.
11 changes: 7 additions & 4 deletions batched/dense/impl/KokkosBatched_Pbtrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

Expand Down Expand Up @@ -50,8 +51,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalLower<Algo::Pbtrs::Unblocked>::inv
SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);

// Solve L**T *X = B (L**H *X = B when ValueType is complex), overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

return 0;
}
Expand All @@ -76,8 +78,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalUpper<Algo::Pbtrs::Unblocked>::inv
/**/ ValueType *KOKKOS_RESTRICT x,
const int xs0, const int kd) {
// Solve U**T *X = B (U**H *X = B when ValueType is complex), overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

// Solve U*X = B, overwriting B with X.
SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);
Expand Down
45 changes: 26 additions & 19 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@

/// \author Yuuichi Asahi ([email protected])

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

namespace KokkosBatched {

namespace Impl {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const XViewType &x, [[maybe_unused]] const int k) {
static_assert(Kokkos::is_view<AViewType>::value, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::tbsv: AViewType must have rank 2.");
static_assert(XViewType::rank == 1, "KokkosBatched::tbsv: XViewType must have rank 1.");

Expand Down Expand Up @@ -63,15 +64,17 @@ KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewTyp
return 0;
}

} // namespace Impl

//// Lower non-transpose ////
/// Serial tbsv specialization selected for Uplo::Lower + Trans::NoTranspose
/// with the unblocked algorithm.
///
/// invoke() checks its arguments, then hands raw pointers, extents and
/// strides of the views to the internal lower-triangular band solver.
///
/// NOTE(review): this text reads like a rendered diff -- each unqualified
/// call/return line appears to be the pre-change version of the
/// `Impl::`-qualified line that follows it; confirm before compiling.
template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
/// \param A  rank-2 view; its extent(1), data pointer and both strides are forwarded
/// \param x  rank-1 view; its data pointer and stride are forwarded (non-const pointer in the
///           internal solver -- presumably overwritten with the solution, confirm)
/// \param k  forwarded unchanged as the last argument (presumably the number of
///           sub-diagonals of the band -- confirm against the internal solver)
/// \return   the non-zero status of the input check, otherwise the internal solver's result
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// Any non-zero status from the input check is returned as-is, without solving.
if (info) return info;

// ArgDiag::use_unit_diag is passed as the internal solver's first (unit-diagonal) flag.
return SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -81,11 +84,12 @@ template <typename ArgDiag>
/// Serial tbsv specialization selected for Uplo::Lower + Trans::Transpose
/// with the unblocked algorithm.
///
/// invoke() checks its arguments and then calls the internal
/// lower-transpose solver with raw pointers and strides taken from the views.
///
/// NOTE(review): this text reads like a rendered diff -- the call that passes
/// `false` as a second boolean appears to be the pre-change form of the call
/// that passes `KokkosBlas::Impl::OpID()` first; confirm before compiling.
struct SerialTbsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
/// \return the non-zero status of the input check, otherwise the internal solver's result
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero status from the input check short-circuits the solve.
if (info) return info;

return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// OpID is the element operator handed to the solver here (presumably the
// identity, i.e. no conjugation -- it takes the place of `false` above).
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -94,11 +98,12 @@ template <typename ArgDiag>
/// Serial tbsv specialization selected for Uplo::Lower + Trans::ConjTranspose
/// with the unblocked algorithm.
///
/// invoke() uses the same internal lower-transpose solver as the plain
/// transpose case; only the conjugation selector differs.
///
/// NOTE(review): this text reads like a rendered diff -- the call that passes
/// `true` as a second boolean appears to be the pre-change form of the call
/// that passes `KokkosBlas::Impl::OpConj()` first; confirm before compiling.
struct SerialTbsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
/// \return the non-zero status of the input check, otherwise the internal solver's result
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero status from the input check short-circuits the solve.
if (info) return info;

return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// OpConj is the element operator handed to the solver here (presumably
// complex conjugation -- it takes the place of `true` above).
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -107,10 +112,10 @@ template <typename ArgDiag>
/// Serial tbsv specialization selected for Uplo::Upper + Trans::NoTranspose
/// with the unblocked algorithm.
///
/// invoke() checks its arguments, then forwards raw pointers, extents and
/// strides of the views to the internal upper-triangular band solver.
///
/// NOTE(review): this text reads like a rendered diff -- each unqualified
/// call/return line appears to be the pre-change version of the
/// `Impl::`-qualified line that follows it; confirm before compiling.
struct SerialTbsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
/// \return the non-zero status of the input check, otherwise the internal solver's result
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero status from the input check short-circuits the solve.
if (info) return info;

// A.extent(1) is what the internal solver receives as its matrix dimension argument.
return SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -120,11 +125,12 @@ template <typename ArgDiag>
/// Serial tbsv specialization selected for Uplo::Upper + Trans::Transpose
/// with the unblocked algorithm.
///
/// invoke() checks its arguments and then calls the internal
/// upper-transpose solver with raw pointers and strides taken from the views.
///
/// NOTE(review): this text reads like a rendered diff -- the call that passes
/// `false` as a second boolean appears to be the pre-change form of the call
/// that passes `KokkosBlas::Impl::OpID()` first; confirm before compiling.
struct SerialTbsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
/// \return the non-zero status of the input check, otherwise the internal solver's result
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero status from the input check short-circuits the solve.
if (info) return info;

return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// OpID is the element operator handed to the solver here (presumably the
// identity, i.e. no conjugation -- it takes the place of `false` above).
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -133,11 +139,12 @@ template <typename ArgDiag>
/// Serial tbsv specialization selected for Uplo::Upper + Trans::ConjTranspose
/// with the unblocked algorithm.
///
/// invoke() uses the same internal upper-transpose solver as the plain
/// transpose case; only the conjugation selector differs.
///
/// NOTE(review): this text reads like a rendered diff -- the call that passes
/// `true` as a second boolean appears to be the pre-change form of the call
/// that passes `KokkosBlas::Impl::OpConj()` first; confirm before compiling.
struct SerialTbsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
/// \return the non-zero status of the input check, otherwise the internal solver's result
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
// A non-zero status from the input check short-circuits the solve.
if (info) return info;

return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
// OpConj is the element operator handed to the solver here (presumably
// complex conjugation -- it takes the place of `true` above).
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand Down
58 changes: 18 additions & 40 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
#include "KokkosBatched_Util.hpp"

namespace KokkosBatched {

namespace Impl {
///
/// Serial Internal Impl
/// ====================

///
/// Lower, Non-Transpose
/// Lower
///

template <typename AlgoType>
Expand Down Expand Up @@ -70,49 +70,37 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invok

/// Declaration of the internal lower-transpose band solver, parameterized on
/// the algorithm tag; the static invoke() is defined per algorithm elsewhere.
///
/// NOTE(review): this text reads like a rendered diff -- the
/// `template <typename ValueType>` / `do_conj` signature appears to be the
/// pre-change form of the `template <typename Op, typename ValueType>` /
/// `Op op` signature that follows it; confirm before compiling.
template <typename AlgoType>
struct SerialTbsvInternalLowerTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
/// \param op             callable applied to matrix entries (replaces the do_conj flag)
/// \param use_unit_diag  unit-diagonal flag
/// \param an             matrix dimension argument
/// \param A, as0, as1    read-only matrix pointer and its two strides
/// \param x, xs0         writable vector pointer and its stride
/// \param k              band-width argument
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

/// Unblocked in-place solve with the transpose of a lower-triangular band
/// matrix: x is read and overwritten entry by entry, from the last index
/// (an - 1) down to 0.
///
/// For each j the code subtracts `A[(i - j) * as0 + j * as1] * x[i * xs0]`
/// for i in (j, min(an - 1, j + k)], then, unless use_unit_diag is set,
/// divides by the diagonal entry `A[0 + j * as1]`. Always returns 0.
///
/// NOTE(review): this text reads like a rendered diff -- the `do_conj`
/// parameter list and the `if (do_conj) { ... } else { ... }` body appear to
/// be the pre-change form; the `Op op` parameter list and the single loop
/// that wraps each matrix entry in `op(...)` appear to be the replacement.
/// Confirm before compiling.
template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
// Backward sweep: entries x[i] with i > j have already been finalized when x[j] is computed.
for (int j = an - 1; j >= 0; --j) {
auto temp = x[j * xs0];

// Flag-based variant: conjugate each matrix entry before use.
if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[0 + j * as1]);
// Flag-based variant: use the matrix entries unchanged.
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= A[(i - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[0 + j * as1];
// Operator-based variant: `op` is applied to every matrix entry, so one loop serves both cases.
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= op(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
// Row index 0 of the band storage holds the diagonal entry of column j.
if (!use_unit_diag) temp = temp / op(A[0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

///
/// Upper, Non-Transpose
/// Upper
///

template <typename AlgoType>
Expand Down Expand Up @@ -154,46 +142,36 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invok

/// Declaration of the internal upper-transpose band solver, parameterized on
/// the algorithm tag; the static invoke() is defined per algorithm elsewhere.
///
/// NOTE(review): this text reads like a rendered diff -- the
/// `template <typename ValueType>` / `do_conj` signature appears to be the
/// pre-change form of the `template <typename Op, typename ValueType>` /
/// `Op op` signature that follows it; confirm before compiling.
template <typename AlgoType>
struct SerialTbsvInternalUpperTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
/// \param op             callable applied to matrix entries (replaces the do_conj flag)
/// \param use_unit_diag  unit-diagonal flag
/// \param an             matrix dimension argument
/// \param A, as0, as1    read-only matrix pointer and its two strides
/// \param x, xs0         writable vector pointer and its stride
/// \param k              band-width argument
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

/// Unblocked in-place solve with the transpose of an upper-triangular band
/// matrix: x is read and overwritten entry by entry, from index 0 up to
/// an - 1.
///
/// For each j the code subtracts `A[(i + k - j) * as0 + j * as1] * x[i * xs0]`
/// for i in [max(0, j - k), j), then, unless use_unit_diag is set, divides by
/// the diagonal entry `A[k * as0 + j * as1]`. Always returns 0.
///
/// NOTE(review): this text reads like a rendered diff -- the `do_conj`
/// parameter list and the `if (do_conj) { ... } else { ... }` body appear to
/// be the pre-change form; the `Op op` parameter list and the single loop
/// that wraps each matrix entry in `op(...)` appear to be the replacement.
/// Confirm before compiling.
template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
// Forward sweep: entries x[i] with i < j have already been finalized when x[j] is computed.
for (int j = 0; j < an; j++) {
auto temp = x[j * xs0];
// Flag-based variant: conjugate each matrix entry before use.
if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[k * as0 + j * as1]);
// Flag-based variant: use the matrix entries unchanged.
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= A[(i + k - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[k * as0 + j * as1];
// Operator-based variant: `op` is applied to every matrix entry, so one loop serves both cases.
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= op(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
// Row index k of the band storage holds the diagonal entry of column j.
if (!use_unit_diag) temp = temp / op(A[k * as0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_TBSV_SERIAL_INTERNAL_HPP_
Loading

0 comments on commit d28324c

Please sign in to comment.