Skip to content

Commit

Permalink
Improve batched serial gemm (#2469)
Browse files Browse the repository at this point in the history
* Add ConjTrans to Serial Gemm

Signed-off-by: Yuuichi Asahi <[email protected]>

* improve checks in serial Gemm

Signed-off-by: Yuuichi Asahi <[email protected]>

* improve selective interface of batched gemm

Signed-off-by: Yuuichi Asahi <[email protected]>

* check info in serial gemm testing

Signed-off-by: Yuuichi Asahi <[email protected]>

* fix: op type of serial invoke

Signed-off-by: Yuuichi Asahi <[email protected]>

* format

Signed-off-by: Yuuichi Asahi <[email protected]>

* remove the global namespace

Signed-off-by: Yuuichi Asahi <[email protected]>

---------

Signed-off-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jan 9, 2025
1 parent 535c697 commit f638914
Show file tree
Hide file tree
Showing 15 changed files with 1,218 additions and 545 deletions.
558 changes: 514 additions & 44 deletions batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp

Large diffs are not rendered by default.

48 changes: 32 additions & 16 deletions batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,28 @@
#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ====================

template <typename ArgAlgo>
struct SerialGemmInternal {
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(OpA opA, OpB opB, const int m, const int n, const int k,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0,
const int bs1, const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1);
};

template <>
template <typename ScalarType, typename ValueType>
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)
Expand All @@ -58,28 +60,27 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);

if (alpha != zero) {
if (m <= 0 || n <= 0 || k <= 0) return 0;

ValueType *KOKKOS_RESTRICT pC = C;
for (int p = 0; p < k; ++p) {
const ValueType *KOKKOS_RESTRICT pA = A + p * as1, *KOKKOS_RESTRICT pB = B + p * bs0;
for (int i = 0; i < m; ++i) {
const ValueType tA(alpha * pA[i * as0]);
const ValueType tA(alpha * opA(pA[i * as0]));
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * pB[j * bs1];
for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * opB(pB[j * bs1]);
}
}
}
return 0;
}

template <>
template <typename ScalarType, typename ValueType>
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)
Expand All @@ -105,7 +106,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
const int mb = mbAlgo, nb = nbAlgo;
for (int i = 0; i < ib; i += mb)
for (int j = 0; j < jb; j += nb)
inner.serial_invoke(alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb,
inner.serial_invoke(opA, opB, alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb,
(j + nb) > jb ? (jb - j) : nb, pb, CC + i * cs0 + j * cs1);
};

Expand Down Expand Up @@ -138,6 +139,21 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
return 0;
}

} // namespace Impl

template <typename ArgAlgo>
struct [[deprecated("Use KokkosBatched::SerialGemm instead")]] SerialGemmInternal {
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
return Impl::SerialGemmInternal<ArgAlgo>::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), m, n, k, alpha,
A, as0, as1, B, bs0, bs1, beta, C, cs0, cs1);
}
};

} // namespace KokkosBatched

#endif
4 changes: 2 additions & 2 deletions batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Blocked>::invoke(
i = ij / nq * mb;
j = ij % nq * nb;
}
inner.serial_invoke(alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb,
CC + i * cs0 + j * cs1);
inner.serial_invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), alpha, AA + i * as0, BB + j * bs1,
(i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb, CC + i * cs0 + j * cs1);
});
};

Expand Down
Loading

0 comments on commit f638914

Please sign in to comment.