Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve batched serial gemm #2469

Merged
merged 7 commits into from
Jan 9, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
Add ConjTrans to Serial Gemm
Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
  • Loading branch information
Yuuichi Asahi committed Jan 7, 2025
commit 0e5d2ca224f365d360f3f14278772e083d784004
551 changes: 507 additions & 44 deletions batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp

Large diffs are not rendered by default.

34 changes: 18 additions & 16 deletions batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -26,26 +26,28 @@
#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ====================

template <typename ArgAlgo>
struct SerialGemmInternal {
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(OpA opA, OpB opB, const int m, const int n, const int k,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0,
const int bs1, const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1);
};

template <>
template <typename ScalarType, typename ValueType>
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha opA(A) opB(B), where opA/opB are applied to each element as it is read
// C (m x n), A(m x k), B(k x n)
@@ -58,28 +60,27 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);

if (alpha != zero) {
if (m <= 0 || n <= 0 || k <= 0) return 0;

ValueType *KOKKOS_RESTRICT pC = C;
for (int p = 0; p < k; ++p) {
const ValueType *KOKKOS_RESTRICT pA = A + p * as1, *KOKKOS_RESTRICT pB = B + p * bs0;
for (int i = 0; i < m; ++i) {
const ValueType tA(alpha * pA[i * as0]);
const ValueType tA(alpha * opA(pA[i * as0]));
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * pB[j * bs1];
for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * opB(pB[j * bs1]);
}
}
}
return 0;
}

template <>
template <typename ScalarType, typename ValueType>
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha opA(A) opB(B); opA/opB are forwarded to the inner kernel
// C (m x n), A(m x k), B(k x n)
@@ -105,7 +106,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
const int mb = mbAlgo, nb = nbAlgo;
for (int i = 0; i < ib; i += mb)
for (int j = 0; j < jb; j += nb)
inner.serial_invoke(alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb,
inner.serial_invoke(opA, opB, alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb,
(j + nb) > jb ? (jb - j) : nb, pb, CC + i * cs0 + j * cs1);
};

@@ -138,6 +139,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif
4 changes: 2 additions & 2 deletions batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp
Original file line number Diff line number Diff line change
@@ -122,8 +122,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Blocked>::invoke(
i = ij / nq * mb;
j = ij % nq * nb;
}
inner.serial_invoke(alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb,
CC + i * cs0 + j * cs1);
// The call this replaces passed no element operators, so both operands must use
// the identity to keep the team path's results unchanged. Passing OpConj for B
// would presumably conjugate B's entries for complex value types.
inner.serial_invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), alpha, AA + i * as0, BB + j * bs1,
(i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb, CC + i * cs0 + j * cs1);
});
};

Loading