-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor serial tbsv implementation details and tests
Signed-off-by: Yuuichi Asahi <[email protected]>
- Loading branch information
Yuuichi Asahi
committed
Jan 13, 2025
1 parent
4c33556
commit d28324c
Showing
6 changed files
with
260 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,16 +19,17 @@ | |
|
||
/// \author Yuuichi Asahi ([email protected]) | ||
|
||
#include "KokkosBlas_util.hpp" | ||
#include "KokkosBatched_Util.hpp" | ||
#include "KokkosBatched_Tbsv_Serial_Internal.hpp" | ||
|
||
namespace KokkosBatched { | ||
|
||
namespace Impl { | ||
template <typename AViewType, typename XViewType> | ||
KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewType &A, | ||
[[maybe_unused]] const XViewType &x, [[maybe_unused]] const int k) { | ||
static_assert(Kokkos::is_view<AViewType>::value, "KokkosBatched::tbsv: AViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view<XViewType>::value, "KokkosBatched::tbsv: XViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::tbsv: AViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::tbsv: XViewType is not a Kokkos::View."); | ||
static_assert(AViewType::rank == 2, "KokkosBatched::tbsv: AViewType must have rank 2."); | ||
static_assert(XViewType::rank == 1, "KokkosBatched::tbsv: XViewType must have rank 1."); | ||
|
||
|
@@ -63,15 +64,17 @@ KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewTyp | |
return 0; | ||
} | ||
|
||
} // namespace Impl | ||
|
||
//// Lower non-transpose //// | ||
// SerialTbsv specialization selected by Uplo::Lower, Trans::NoTranspose and Algo::Tbsv::Unblocked.
// ArgDiag remains a template parameter; only its use_unit_diag member is read here.
// NOTE(review): this text is a rendered diff without +/- markers, so a few statements appear twice --
// presumably the removed line followed by its replacement (the Impl::-qualified form). Confirm
// against the actual header before treating this as compilable code.
template <typename ArgDiag> | ||
struct SerialTbsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> { | ||
template <typename AViewType, typename XViewType> | ||
// Runs checkTbsvInput on (A, x, k) and returns its result as-is when it is non-zero;
// otherwise returns the value produced by SerialTbsvInternalLower<...>::invoke.
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) { | ||
auto info = checkTbsvInput(A, x, k); | ||
auto info = Impl::checkTbsvInput(A, x, k); | ||
if (info) return info; | ||
|
||
return SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke( | ||
return Impl::SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke( | ||
// Arguments forwarded: unit-diagonal flag, A.extent(1), A's data pointer and both strides,
// x's data pointer and stride, and k (presumably the band width -- confirm).
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k); | ||
} | ||
}; | ||
|
@@ -81,11 +84,12 @@ template <typename ArgDiag> | |
// SerialTbsv specialization selected by Uplo::Lower, Trans::Transpose and Algo::Tbsv::Unblocked.
// NOTE(review): this text is a rendered diff without +/- markers, so the validation call and the
// kernel call each appear twice -- presumably the removed form followed by its replacement.
struct SerialTbsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> { | ||
template <typename AViewType, typename XViewType> | ||
// Runs checkTbsvInput on (A, x, k) and returns its result as-is when it is non-zero;
// otherwise returns the value produced by SerialTbsvInternalLowerTranspose<...>::invoke.
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) { | ||
auto info = checkTbsvInput(A, x, k); | ||
auto info = Impl::checkTbsvInput(A, x, k); | ||
if (info) return info; | ||
|
||
// First call form: a literal `false` is passed right after the unit-diagonal flag
// (presumably a "conjugate" switch -- confirm against the internal header).
return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k); | ||
// Second call form: a KokkosBlas::Impl::OpID() object is passed first and the bool is gone;
// the remaining arguments (diag flag, A.extent(1), pointers, strides, k) are unchanged.
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), | ||
x.stride_0(), k); | ||
} | ||
}; | ||
|
||
|
@@ -94,11 +98,12 @@ template <typename ArgDiag> | |
// SerialTbsv specialization selected by Uplo::Lower, Trans::ConjTranspose and Algo::Tbsv::Unblocked.
// NOTE(review): this text is a rendered diff without +/- markers, so the validation call and the
// kernel call each appear twice -- presumably the removed form followed by its replacement.
struct SerialTbsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> { | ||
template <typename AViewType, typename XViewType> | ||
// Runs checkTbsvInput on (A, x, k) and returns its result as-is when it is non-zero;
// otherwise returns the value produced by SerialTbsvInternalLowerTranspose<...>::invoke,
// the same internal kernel name the Trans::Transpose case uses.
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) { | ||
auto info = checkTbsvInput(A, x, k); | ||
auto info = Impl::checkTbsvInput(A, x, k); | ||
if (info) return info; | ||
|
||
// First call form: a literal `true` is passed right after the unit-diagonal flag
// (presumably requests conjugation -- confirm against the internal header).
return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k); | ||
// Second call form: a KokkosBlas::Impl::OpConj() object is passed first and the bool is gone;
// the remaining arguments (diag flag, A.extent(1), pointers, strides, k) are unchanged.
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), | ||
x.stride_0(), k); | ||
} | ||
}; | ||
|
||
|
@@ -107,10 +112,10 @@ template <typename ArgDiag> | |
// SerialTbsv specialization selected by Uplo::Upper, Trans::NoTranspose and Algo::Tbsv::Unblocked.
// NOTE(review): this text is a rendered diff without +/- markers, so the validation call and the
// `return ...::invoke(` line each appear twice -- presumably the removed form followed by its
// Impl::-qualified replacement.
struct SerialTbsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> { | ||
template <typename AViewType, typename XViewType> | ||
// Runs checkTbsvInput on (A, x, k) and returns its result as-is when it is non-zero;
// otherwise returns the value produced by SerialTbsvInternalUpper<...>::invoke.
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) { | ||
auto info = checkTbsvInput(A, x, k); | ||
auto info = Impl::checkTbsvInput(A, x, k); | ||
if (info) return info; | ||
|
||
return SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke( | ||
return Impl::SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke( | ||
// Arguments forwarded: unit-diagonal flag, A.extent(1), A's data pointer and both strides,
// x's data pointer and stride, and k (presumably the band width -- confirm).
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k); | ||
} | ||
}; | ||
|
@@ -120,11 +125,12 @@ template <typename ArgDiag> | |
// SerialTbsv specialization selected by Uplo::Upper, Trans::Transpose and Algo::Tbsv::Unblocked.
// NOTE(review): this text is a rendered diff without +/- markers, so the validation call and the
// kernel call each appear twice -- presumably the removed form followed by its replacement.
struct SerialTbsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> { | ||
template <typename AViewType, typename XViewType> | ||
// Runs checkTbsvInput on (A, x, k) and returns its result as-is when it is non-zero;
// otherwise returns the value produced by SerialTbsvInternalUpperTranspose<...>::invoke.
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) { | ||
auto info = checkTbsvInput(A, x, k); | ||
auto info = Impl::checkTbsvInput(A, x, k); | ||
if (info) return info; | ||
|
||
// First call form: a literal `false` is passed right after the unit-diagonal flag
// (presumably a "conjugate" switch -- confirm against the internal header).
return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k); | ||
// Second call form: a KokkosBlas::Impl::OpID() object is passed first and the bool is gone;
// the remaining arguments (diag flag, A.extent(1), pointers, strides, k) are unchanged.
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), | ||
x.stride_0(), k); | ||
} | ||
}; | ||
|
||
|
@@ -133,11 +139,12 @@ template <typename ArgDiag> | |
// SerialTbsv specialization selected by Uplo::Upper, Trans::ConjTranspose and Algo::Tbsv::Unblocked.
// NOTE(review): this text is a rendered diff without +/- markers, so the validation call and the
// kernel call each appear twice -- presumably the removed form followed by its replacement.
struct SerialTbsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> { | ||
template <typename AViewType, typename XViewType> | ||
// Runs checkTbsvInput on (A, x, k) and returns its result as-is when it is non-zero;
// otherwise returns the value produced by SerialTbsvInternalUpperTranspose<...>::invoke,
// the same internal kernel name the Trans::Transpose case uses.
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) { | ||
auto info = checkTbsvInput(A, x, k); | ||
auto info = Impl::checkTbsvInput(A, x, k); | ||
if (info) return info; | ||
|
||
// First call form: a literal `true` is passed right after the unit-diagonal flag
// (presumably requests conjugation -- confirm against the internal header).
return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k); | ||
// Second call form: a KokkosBlas::Impl::OpConj() object is passed first and the bool is gone;
// the remaining arguments (diag flag, A.extent(1), pointers, strides, k) are unchanged.
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke( | ||
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), | ||
x.stride_0(), k); | ||
} | ||
}; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.