-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve batched serial trsm implementation and testing (#2432)
* Use Trsv instead of Trsm if X is a rank 1 matrix Signed-off-by: Yuuichi Asahi <[email protected]> * Add missing specialization of Trsm serial implementation Signed-off-by: Yuuichi Asahi <[email protected]> * Add missing tests for serial Trsm Signed-off-by: Yuuichi Asahi <[email protected]> * fix: format Signed-off-by: Yuuichi Asahi <[email protected]> * fix: initialization order based on codeQL Signed-off-by: Yuuichi Asahi <[email protected]> * Allow trsm serial to work on rank 1 matrix Signed-off-by: Yuuichi Asahi <[email protected]> * fix: MKL interface of Serial Trsm Signed-off-by: Yuuichi Asahi <[email protected]> * simplify do_conj logic in trsm serial internal Signed-off-by: Yuuichi Asahi <[email protected]> --------- Signed-off-by: Yuuichi Asahi <[email protected]> Co-authored-by: Yuuichi Asahi <[email protected]>
- Loading branch information
1 parent
18eaa5b
commit 8def092
Showing
8 changed files
with
1,494 additions
and
300 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
#define KOKKOSBATCHED_TRSM_SERIAL_INTERNAL_HPP | ||
|
||
/// \author Kyungjoo Kim ([email protected]) | ||
/// \author Yuuichi Asahi ([email protected]) | ||
|
||
#include "KokkosBatched_Util.hpp" | ||
|
||
|
@@ -26,6 +27,7 @@ | |
#include "KokkosBatched_InnerTrsm_Serial_Impl.hpp" | ||
|
||
namespace KokkosBatched { | ||
namespace Impl { | ||
|
||
/// | ||
/// Serial Internal Impl | ||
|
@@ -34,34 +36,35 @@ namespace KokkosBatched { | |
template <typename AlgoType> | ||
struct SerialTrsmInternalLeftLower { | ||
template <typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const int n, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, | ||
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, const int n, | ||
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, | ||
const int as1, | ||
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1); | ||
}; | ||
|
||
template <> | ||
template <typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke( | ||
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, | ||
const int as0, const int as1, | ||
const bool use_unit_diag, const bool do_conj, const int m, const int n, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, | ||
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { | ||
const ScalarType one(1.0), zero(0.0); | ||
|
||
if (alpha == zero) | ||
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); | ||
else { | ||
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); | ||
if (m <= 0 || n <= 0) return 0; | ||
|
||
for (int p = 0; p < m; ++p) { | ||
const int iend = m - p - 1, jend = n; | ||
|
||
const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : NULL; | ||
const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : nullptr; | ||
|
||
ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : NULL; | ||
ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : nullptr; | ||
|
||
if (!use_unit_diag) { | ||
const ValueType alpha11 = A[p * as0 + p * as1]; | ||
const ValueType alpha11 = | ||
(do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]); | ||
|
||
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) | ||
#pragma unroll | ||
|
@@ -74,17 +77,20 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::i | |
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) | ||
#pragma unroll | ||
#endif | ||
for (int j = 0; j < jend; ++j) B2[i * bs0 + j * bs1] -= a21[i * as0] * b1t[j * bs1]; | ||
for (int j = 0; j < jend; ++j) | ||
B2[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1] | ||
: a21[i * as0] * b1t[j * bs1]); | ||
} | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
template <> | ||
template <typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke( | ||
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, | ||
const int as0, const int as1, | ||
const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, | ||
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { | ||
constexpr int mbAlgo = Algo::Trsm::Blocked::mb(); | ||
|
||
|
@@ -94,7 +100,6 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::inv | |
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); | ||
else { | ||
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); | ||
if (m <= 0 || n <= 0) return 0; | ||
|
||
InnerTrsmLeftLowerUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, bs1); | ||
InnerTrsmLeftLowerNonUnitDiag<mbAlgo> trsm_n(as0, as1, bs0, bs1); | ||
|
@@ -137,24 +142,24 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::inv | |
template <typename AlgoType> | ||
struct SerialTrsmInternalLeftUpper { | ||
template <typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const int n, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, | ||
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, const int n, | ||
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, | ||
const int as1, | ||
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1); | ||
}; | ||
|
||
template <> | ||
template <typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke( | ||
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, | ||
const int as0, const int as1, | ||
const bool use_unit_diag, const bool do_conj, const int m, const int n, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, | ||
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { | ||
const ScalarType one(1.0), zero(0.0); | ||
|
||
if (alpha == zero) | ||
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); | ||
else { | ||
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); | ||
if (m <= 0 || n <= 0) return 0; | ||
|
||
ValueType *KOKKOS_RESTRICT B0 = B; | ||
for (int p = (m - 1); p >= 0; --p) { | ||
|
@@ -164,7 +169,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i | |
ValueType *KOKKOS_RESTRICT b1t = B + p * bs0; | ||
|
||
if (!use_unit_diag) { | ||
const ValueType alpha11 = A[p * as0 + p * as1]; | ||
const ValueType alpha11 = | ||
(do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]); | ||
|
||
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) | ||
#pragma unroll | ||
|
@@ -179,7 +185,9 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i | |
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) | ||
#pragma unroll | ||
#endif | ||
for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= a01[i * as0] * b1t[j * bs1]; | ||
for (int j = 0; j < jend; ++j) | ||
B0[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a01[i * as0]) * b1t[j * bs1] | ||
: a01[i * as0] * b1t[j * bs1]); | ||
} | ||
} | ||
} | ||
|
@@ -189,8 +197,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i | |
template <> | ||
template <typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke( | ||
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, | ||
const int as0, const int as1, | ||
const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, | ||
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { | ||
const ScalarType one(1.0), zero(0.0), minus_one(-1.0); | ||
|
||
|
@@ -200,7 +208,6 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::inv | |
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); | ||
else { | ||
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); | ||
if (m <= 0 || n <= 0) return 0; | ||
|
||
InnerTrsmLeftUpperUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, bs1); | ||
InnerTrsmLeftUpperNonUnitDiag<mbAlgo> trsm_n(as0, as1, bs0, bs1); | ||
|
@@ -240,6 +247,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::inv | |
return 0; | ||
} | ||
|
||
} // namespace Impl | ||
} // namespace KokkosBatched | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
#define KOKKOSBATCHED_TRSM_DECL_HPP | ||
|
||
/// \author Kyungjoo Kim ([email protected]) | ||
/// \author Yuuichi Asahi ([email protected]) | ||
|
||
#include "KokkosBatched_Util.hpp" | ||
#include "KokkosBatched_Vector.hpp" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.