Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve batched serial gemm #2469

Merged
merged 7 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
558 changes: 514 additions & 44 deletions batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp

Large diffs are not rendered by default.

48 changes: 32 additions & 16 deletions batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,28 @@
#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ====================

template <typename ArgAlgo>
struct SerialGemmInternal {
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(OpA opA, OpB opB, const int m, const int n, const int k,
const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0,
const int bs1, const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1);
};

template <>
template <typename ScalarType, typename ValueType>
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)
Expand All @@ -58,28 +60,27 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);

if (alpha != zero) {
if (m <= 0 || n <= 0 || k <= 0) return 0;

ValueType *KOKKOS_RESTRICT pC = C;
for (int p = 0; p < k; ++p) {
const ValueType *KOKKOS_RESTRICT pA = A + p * as1, *KOKKOS_RESTRICT pB = B + p * bs0;
for (int i = 0; i < m; ++i) {
const ValueType tA(alpha * pA[i * as0]);
const ValueType tA(alpha * opA(pA[i * as0]));
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * pB[j * bs1];
for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * opB(pB[j * bs1]);
}
}
}
return 0;
}

template <>
template <typename ScalarType, typename ValueType>
template <typename OpA, typename OpB, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta,
OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)
Expand All @@ -105,7 +106,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
const int mb = mbAlgo, nb = nbAlgo;
for (int i = 0; i < ib; i += mb)
for (int j = 0; j < jb; j += nb)
inner.serial_invoke(alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb,
inner.serial_invoke(opA, opB, alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb,
(j + nb) > jb ? (jb - j) : nb, pb, CC + i * cs0 + j * cs1);
};

Expand Down Expand Up @@ -138,6 +139,21 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
return 0;
}

} // namespace Impl

template <typename ArgAlgo>
struct [[deprecated("Use KokkosBatched::SerialGemm instead")]] SerialGemmInternal {
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1,
const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) {
return Impl::SerialGemmInternal<ArgAlgo>::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), m, n, k, alpha,
A, as0, as1, B, bs0, bs1, beta, C, cs0, cs1);
}
};

} // namespace KokkosBatched

#endif
4 changes: 2 additions & 2 deletions batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Blocked>::invoke(
i = ij / nq * mb;
j = ij % nq * nb;
}
inner.serial_invoke(alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb,
CC + i * cs0 + j * cs1);
inner.serial_invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), alpha, AA + i * as0, BB + j * bs1,
(i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb, CC + i * cs0 + j * cs1);
});
};

Expand Down
Loading
Loading