Skip to content

Commit

Permalink
Improve batched serial trsm implementation and testing (#2432)
Browse files Browse the repository at this point in the history
* Use Trsv instead of Trsm if X is a rank 1 matrix

Signed-off-by: Yuuichi Asahi <[email protected]>

* Add missing specialization of Trsm serial implementation

Signed-off-by: Yuuichi Asahi <[email protected]>

* Add missing tests for serial Trsm

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: format

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: initialization order based on codeQL

Signed-off-by: Yuuichi Asahi <[email protected]>

* Allow trsm serial to work on rank 1 matrix

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: MKL interface of Serial Trsm

Signed-off-by: Yuuichi Asahi <[email protected]>

* simplify do_conj logic in trsm serial internal

Signed-off-by: Yuuichi Asahi <[email protected]>

---------

Signed-off-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Dec 4, 2024
1 parent 18eaa5b commit 8def092
Show file tree
Hide file tree
Showing 8 changed files with 1,494 additions and 300 deletions.
650 changes: 581 additions & 69 deletions batched/dense/impl/KokkosBatched_Trsm_Serial_Impl.hpp

Large diffs are not rendered by default.

52 changes: 30 additions & 22 deletions batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#define KOKKOSBATCHED_TRSM_SERIAL_INTERNAL_HPP

/// \author Kyungjoo Kim ([email protected])
/// \author Yuuichi Asahi ([email protected])

#include "KokkosBatched_Util.hpp"

Expand All @@ -26,6 +27,7 @@
#include "KokkosBatched_InnerTrsm_Serial_Impl.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
Expand All @@ -34,34 +36,35 @@ namespace KokkosBatched {
template <typename AlgoType>
struct SerialTrsmInternalLeftLower {
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, const int n,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1);
};

template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1,
const bool use_unit_diag, const bool do_conj, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) {
const ScalarType one(1.0), zero(0.0);

if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

for (int p = 0; p < m; ++p) {
const int iend = m - p - 1, jend = n;

const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : NULL;
const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : nullptr;

ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : NULL;
ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : nullptr;

if (!use_unit_diag) {
const ValueType alpha11 = A[p * as0 + p * as1];
const ValueType alpha11 =
(do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand All @@ -74,17 +77,20 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::i
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) B2[i * bs0 + j * bs1] -= a21[i * as0] * b1t[j * bs1];
for (int j = 0; j < jend; ++j)
B2[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1]
: a21[i * as0] * b1t[j * bs1]);
}
}

return 0;
}

template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1,
const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) {
constexpr int mbAlgo = Algo::Trsm::Blocked::mb();

Expand All @@ -94,7 +100,6 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::inv
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

InnerTrsmLeftLowerUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, bs1);
InnerTrsmLeftLowerNonUnitDiag<mbAlgo> trsm_n(as0, as1, bs0, bs1);
Expand Down Expand Up @@ -137,24 +142,24 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::inv
template <typename AlgoType>
struct SerialTrsmInternalLeftUpper {
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, const int n,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1);
};

template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1,
const bool use_unit_diag, const bool do_conj, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) {
const ScalarType one(1.0), zero(0.0);

if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

ValueType *KOKKOS_RESTRICT B0 = B;
for (int p = (m - 1); p >= 0; --p) {
Expand All @@ -164,7 +169,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i
ValueType *KOKKOS_RESTRICT b1t = B + p * bs0;

if (!use_unit_diag) {
const ValueType alpha11 = A[p * as0 + p * as1];
const ValueType alpha11 =
(do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand All @@ -179,7 +185,9 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= a01[i * as0] * b1t[j * bs1];
for (int j = 0; j < jend; ++j)
B0[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a01[i * as0]) * b1t[j * bs1]
: a01[i * as0] * b1t[j * bs1]);
}
}
}
Expand All @@ -189,8 +197,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i
template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1,
const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) {
const ScalarType one(1.0), zero(0.0), minus_one(-1.0);

Expand All @@ -200,7 +208,6 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::inv
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

InnerTrsmLeftUpperUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, bs1);
InnerTrsmLeftUpperNonUnitDiag<mbAlgo> trsm_n(as0, as1, bs0, bs1);
Expand Down Expand Up @@ -240,6 +247,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::inv
return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif
1 change: 1 addition & 0 deletions batched/dense/src/KokkosBatched_Trsm_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#define KOKKOSBATCHED_TRSM_DECL_HPP

/// \author Kyungjoo Kim ([email protected])
/// \author Yuuichi Asahi ([email protected])

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Vector.hpp"
Expand Down
45 changes: 45 additions & 0 deletions batched/dense/unit_test/Test_Batched_DenseUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,51 @@ void banded_to_full(InViewType& in, OutViewType& out, int k = 1) {
Kokkos::deep_copy(out, h_out);
}

/// \brief Create a triangular matrix from an input matrix:
/// Copies the input matrix into the upper/lower triangular of the output matrix specified
/// by the parameter k. Zero out elements below/above the k-th diagonal.
///
/// \tparam InViewType: Input type for the matrix, needs to be a 3D view
/// \tparam OutViewType: Output type for the matrix, needs to be a 3D view
/// \tparam UploType: Type indicating whether the matrix is upper or lower triangular
/// \tparam DiagType: Type indicating whether the matrix is unit or non-unit diagonal
///
/// \param in [in]: Input batched matrix, a rank 3 view
/// \param out [out]: Output batched matrix, where the upper or lower
/// triangular components are kept, a rank 3 view
/// \param k [in]: The diagonal offset to be zero out (default is 0).
///
template <typename InViewType, typename OutViewType, typename UploType, typename DiagType>
void create_triangular_matrix(InViewType& in, OutViewType& out, int k = 0) {
auto h_in = Kokkos::create_mirror_view(in);
auto h_out = Kokkos::create_mirror_view(out);
const int N = in.extent(0), BlkSize = in.extent(1);

Kokkos::deep_copy(h_in, in);
Kokkos::deep_copy(h_out, 0.0);
for (int i0 = 0; i0 < N; i0++) {
for (int i1 = 0; i1 < BlkSize; i1++) {
for (int i2 = 0; i2 < BlkSize; i2++) {
if constexpr (std::is_same_v<UploType, KokkosBatched::Uplo::Upper>) {
// Upper
// Zero out elements below the k-th diagonal
h_out(i0, i1, i2) = i2 < i1 + k ? 0.0 : h_in(i0, i1, i2);
} else {
// Lower
// Zero out elements above the k-th diagonal
h_out(i0, i1, i2) = i2 > i1 + k ? 0.0 : h_in(i0, i1, i2);
}
}

if constexpr (std::is_same_v<DiagType, KokkosBatched::Diag::Unit>) {
// Unit diagonal
h_out(i0, i1, i1) = 1.0;
}
}
}
Kokkos::deep_copy(out, h_out);
}

} // namespace KokkosBatched

#endif // TEST_BATCHED_DENSE_HELPER_HPP
Loading

0 comments on commit 8def092

Please sign in to comment.