From f6389142390ada5f385a7f123c295659a4b8182d Mon Sep 17 00:00:00 2001 From: yasahi-hpc <57478230+yasahi-hpc@users.noreply.github.com> Date: Thu, 9 Jan 2025 19:09:09 +0100 Subject: [PATCH] Improve batched serial gemm (#2469) * Add ConjTrans to Serial Gemm Signed-off-by: Yuuichi Asahi * improve checks in serial Gemm Signed-off-by: Yuuichi Asahi * improve selective interface of batched gemm Signed-off-by: Yuuichi Asahi * check info in serial gemm testing Signed-off-by: Yuuichi Asahi * fix: op type of serial invoke Signed-off-by: Yuuichi Asahi * format Signed-off-by: Yuuichi Asahi * remove the global namespace Signed-off-by: Yuuichi Asahi --------- Signed-off-by: Yuuichi Asahi Co-authored-by: Yuuichi Asahi --- .../impl/KokkosBatched_Gemm_Serial_Impl.hpp | 558 ++++++++++++++++-- .../KokkosBatched_Gemm_Serial_Internal.hpp | 48 +- .../impl/KokkosBatched_Gemm_Team_Internal.hpp | 4 +- ...okkosBatched_InnerGemmFixC_Serial_Impl.hpp | 492 +++++++-------- .../impl/KokkosBatched_LU_Serial_Internal.hpp | 5 +- batched/dense/src/KokkosBatched_Gemm_Decl.hpp | 6 +- .../src/KokkosBatched_InnerGemmFixC_Decl.hpp | 12 +- .../unit_test/Test_Batched_SerialGemm.hpp | 204 ++++--- .../Test_Batched_SerialGemm_Complex.hpp | 271 +++++++-- .../Test_Batched_SerialGemm_Real.hpp | 101 ++-- ...osBlas2_serial_gemv_inner_multiple_dot.hpp | 17 +- blas/impl/KokkosBlas_util.hpp | 16 + common/src/KokkosKernels_BlockUtils.hpp | 6 +- sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp | 12 +- .../impl/KokkosSparse_spmv_bsrmatrix_impl.hpp | 11 +- 15 files changed, 1218 insertions(+), 545 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp index fae44c8f83..266cc5bb33 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp @@ -16,22 +16,70 @@ #ifndef KOKKOSBATCHED_GEMM_SERIAL_IMPL_HPP #define KOKKOSBATCHED_GEMM_SERIAL_IMPL_HPP +#include "KokkosBlas_util.hpp" #include "KokkosBatched_Util.hpp" #include "KokkosBatched_Gemm_Serial_Internal.hpp" namespace KokkosBatched { +namespace Impl { +template +KOKKOS_INLINE_FUNCTION static int checkGemmInput([[maybe_unused]] const AViewType &A, + [[maybe_unused]] const BViewType &B, + [[maybe_unused]] const CViewType &C) { + static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: AViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: BViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::gemm: CViewType is not a Kokkos::View."); + + static_assert(AViewType::rank <= 2, "KokkosBatched::gemm: AViewType must have rank 0, 1 or 2."); + static_assert(BViewType::rank <= 2, "KokkosBatched::gemm: BViewType must have rank 0, 1 or 2."); + static_assert(CViewType::rank <= 2, "KokkosBatched::gemm: CViewType must have rank 0, 1 or 2."); + +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) + const int m = C.extent(0), n = C.extent(1); + const int lda = A.extent(0); + const int ldb = B.extent(0); + + const int ka = std::is_same_v ? A.extent(1) : A.extent(0); + const int kb = std::is_same_v ? B.extent(0) : B.extent(1); + + if (ka != kb) { + Kokkos::printf( + "KokkosBatched::gemm: Dimensions of A and B do not match: A: %d x %d, " + "B: %d x %d\n", + A.extent(0), A.extent(1), B.extent(0), B.extent(1)); + return 1; + } + + const int nrowa = std::is_same_v ? m : ka; + const int nrowb = std::is_same_v ? kb : n; + + if (lda < Kokkos::max(1, nrowa)) { + Kokkos::printf( + "KokkosBatched::gemm: leading dimension of A must not be smaller than " + "max(1, nrowa): " + "lda = %d, nrowa = %d\n", + lda, nrowa); + return 1; + } + if (ldb < Kokkos::max(1, nrowb)) { + Kokkos::printf( + "KokkosBatched::gemm: leading dimension of B must not be smaller than " + "max(1, nrowb): " + "ldb = %d, nrowb = %d\n", + ldb, nrowb); + return 1; + } + +#endif + + return 0; +} +} // namespace Impl + /// /// Serial Impl /// =========== -/// -/// Implemented: -/// NT/NT, T/NT, NT/T, T/T -/// -/// Not yet immplemented (ConjTranspose): -/// CT/NT, NT/CT, CT/CT -/// - /// /// NT/NT /// @@ -73,22 +121,36 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_0(), - B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_0(), - B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } /// @@ -132,22 +194,109 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_0(), - B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_0(), - B.stride_1(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// C/NT +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_CONJTRANS, MKL_NOTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_CONJTRANS, MKL_NOTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B + // C (m x n), A(k x m), B(k x n) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1), beta, C.data(), C.stride(0), C.stride(1)); } /// @@ -191,22 +340,36 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^T + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), - A.stride_0(), A.stride_1(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^T + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } /// @@ -250,23 +413,330 @@ template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } template <> template KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { - // C = beta C + alpha A B - // C (m x n), A(m x k), B(k x n) - return SerialGemmInternal::invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), - A.stride_1(), A.stride_0(), B.data(), B.stride_1(), - B.stride_0(), beta, C.data(), C.stride_0(), C.stride_1()); + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// C/T +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_CONJTRANS, MKL_TRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_CONJTRANS, MKL_TRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^T + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpID(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// NT/C +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_NOTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_NOTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^H + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(1); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A B^H + // C (m x n), A(m x k), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), + A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// T/C +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_TRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_1(), + (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), format, + (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_TRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), A.stride_0(), + (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), format, + (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); } + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^T B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +/// +/// C/C +/// + +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + typedef typename CViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) { + mkl_dgemm_compact(MKL_COL_MAJOR, MKL_CONJTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), + A.stride_1(), (const double *)B.data(), B.stride_1(), beta, (double *)C.data(), C.stride_1(), + format, (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) { + mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_CONJTRANS, MKL_CONJTRANS, m, n, k, alpha, (const double *)A.data(), + A.stride_0(), (const double *)B.data(), B.stride_0(), beta, (double *)C.data(), C.stride_0(), + format, (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; +} +#endif + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGemm::invoke( + const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { + // Quick return if possible + const int m = C.extent(0), n = C.extent(1), k = A.extent(0); + if (m == 0 || n == 0 || ((alpha == ScalarType(0) || k == 0) && beta == ScalarType(1))) return 0; + + auto info = KokkosBatched::Impl::checkGemmInput(A, B, C); + if (info) return info; + + // C = beta C + alpha A^H B^H + // C (m x n), A(k x m), B(n x k) + return KokkosBatched::Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpConj(), C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), + A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0), beta, C.data(), C.stride(0), C.stride(1)); +} + } // namespace KokkosBatched #endif diff --git a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp index 1a83a27112..09ce343fa6 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp @@ -26,6 +26,7 @@ #include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp" namespace KokkosBatched { +namespace Impl { /// /// Serial Internal Impl @@ -33,19 +34,20 @@ namespace KokkosBatched { template struct SerialGemmInternal { - template - KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, - const ScalarType beta, + template + KOKKOS_INLINE_FUNCTION static int invoke(OpA opA, OpB opB, const int m, const int n, const int k, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, + const int bs1, const ScalarType beta, /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1); }; template <> -template +template KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( - const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta, + OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, + const ScalarType beta, /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) @@ -58,17 +60,15 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1); if (alpha != zero) { - if (m <= 0 || n <= 0 || k <= 0) return 0; - ValueType *KOKKOS_RESTRICT pC = C; for (int p = 0; p < k; ++p) { const ValueType *KOKKOS_RESTRICT pA = A + p * as1, *KOKKOS_RESTRICT pB = B + p * bs0; for (int i = 0; i < m; ++i) { - const ValueType tA(alpha * pA[i * as0]); + const ValueType tA(alpha * opA(pA[i * as0])); #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * pB[j * bs1]; + for (int j = 0; j < n; ++j) pC[i * cs0 + j * cs1] += tA * opB(pB[j * bs1]); } } } @@ -76,10 +76,11 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( } template <> -template +template KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( - const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, const ScalarType beta, + OpA opA, OpB opB, const int m, const int n, const int k, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, + const ScalarType beta, /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) @@ -105,7 +106,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( const int mb = mbAlgo, nb = nbAlgo; for (int i = 0; i < ib; i += mb) for (int j = 0; j < jb; j += nb) - inner.serial_invoke(alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb, + inner.serial_invoke(opA, opB, alpha_value, AA + i * as0, BB + j * bs1, (i + mb) > ib ? (ib - i) : mb, (j + nb) > jb ? (jb - j) : nb, pb, CC + i * cs0 + j * cs1); }; @@ -138,6 +139,21 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal::invoke( return 0; } +} // namespace Impl + +template +struct [[deprecated("Use KokkosBatched::SerialGemm instead")]] SerialGemmInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int m, const int n, const int k, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + const ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1, + const ScalarType beta, + /**/ ValueType *KOKKOS_RESTRICT C, const int cs0, const int cs1) { + return Impl::SerialGemmInternal::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), m, n, k, alpha, + A, as0, as1, B, bs0, bs1, beta, C, cs0, cs1); + } +}; + } // namespace KokkosBatched #endif diff --git a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp index b8647f5205..ff4882b548 100644 --- a/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp @@ -122,8 +122,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal::invoke( i = ij / nq * mb; j = ij % nq * nb; } - inner.serial_invoke(alpha, AA + i * as0, BB + j * bs1, (i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb, - CC + i * cs0 + j * cs1); + inner.serial_invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), alpha, AA + i * as0, BB + j * bs1, + (i + mb) > ib ? mp : mb, (j + nb) > jb ? np : nb, pb, CC + i * cs0 + j * cs1); }); }; diff --git a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp index e090ce57bd..b31bc895e2 100644 --- a/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_InnerGemmFixC_Serial_Impl.hpp @@ -28,8 +28,8 @@ namespace KokkosBatched { /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -47,16 +47,16 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - a_4p = A[i4 + p * _as1]; - b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + a_4p = opA(A[i4 + p * _as1]); + b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -115,8 +115,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -133,15 +133,15 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -190,8 +190,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -207,14 +207,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -253,8 +253,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 3>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -269,13 +269,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -304,8 +304,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -319,12 +319,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; - a_4p = A[i4 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); + a_4p = opA(A[i4 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -343,8 +343,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -362,15 +362,15 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -419,8 +419,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -438,14 +438,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -484,8 +484,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -503,13 +503,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -538,8 +538,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -557,12 +557,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; - /**/ b_p4 = B[p * _bs0 + j4]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); + /**/ b_p4 = opB(B[p * _bs0 + j4]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -584,8 +584,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 5>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -603,14 +603,14 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; - b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); + b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -651,8 +651,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -668,13 +668,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -707,8 +707,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 3>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -723,12 +723,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -753,8 +753,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -769,11 +769,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; - a_3p = A[i3 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); + a_3p = opA(A[i3 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -790,8 +790,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -808,13 +808,13 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -847,8 +847,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -865,12 +865,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -895,8 +895,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -913,11 +913,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; - /**/ b_p3 = B[p * _bs0 + j3]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); + /**/ b_p3 = opB(B[p * _bs0 + j3]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -938,8 +938,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 4>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -955,12 +955,12 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; - b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); + b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -987,8 +987,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1003,11 +1003,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - a_2p = A[i2 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + a_2p = opA(A[i2 + p * _as1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1028,8 +1028,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1043,10 +1043,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - a_2p = A[i2 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + a_2p = opA(A[i2 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -1061,8 +1061,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1078,11 +1078,11 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1102,8 +1102,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 3>::serial_invoke(const ScalarType a return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1119,10 +1119,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /**/ b_p1 = B[p * _bs0 + j1]; - /**/ b_p2 = B[p * _bs0 + j2]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /**/ b_p1 = opB(B[p * _bs0 + j1]); + /**/ b_p2 = opB(B[p * _bs0 + j2]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1141,8 +1141,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 3>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1156,10 +1156,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; - b_p1 = B[p * _bs0 + j1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); + b_p1 = opB(B[p * _bs0 + j1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; @@ -1176,8 +1176,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1191,9 +1191,9 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - a_1p = A[i1 + p * _as1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + a_1p = opA(A[i1 + p * _as1]); c_00 += a_0p * b_p0; c_10 += a_1p * b_p0; @@ -1206,8 +1206,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1221,9 +1221,9 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; - /* */ b_p1 = B[p * _bs0 + j1]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); + /* */ b_p1 = opB(B[p * _bs0 + j1]); c_00 += a_0p * b_p0; c_01 += a_0p * b_p1; } @@ -1239,8 +1239,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 2>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C) { @@ -1254,8 +1254,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType a #pragma unroll #endif for (int p = 0; p < k; ++p) { - a_0p = A[i0 + p * _as1]; - b_p0 = B[p * _bs0 + j0]; + a_0p = opA(A[i0 + p * _as1]); + b_p0 = opB(B[p * _bs0 + j0]); c_00 += a_0p * b_p0; } C[0 * _cs0 + 0 * _cs1] += alpha * c_00; @@ -1264,8 +1264,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int k, @@ -1275,27 +1275,27 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke(const ScalarType a switch (m) { case 5: { InnerGemmFixC<5, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 4: { InnerGemmFixC<4, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 3: { InnerGemmFixC<3, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 2: { InnerGemmFixC<2, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 1: { InnerGemmFixC<1, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { @@ -1307,8 +1307,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<0, 1>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1319,52 +1319,52 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 55: { InnerGemmFixC<5, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 54: { InnerGemmFixC<5, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 53: { InnerGemmFixC<5, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 52: { InnerGemmFixC<5, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 51: { InnerGemmFixC<5, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 45: { InnerGemmFixC<4, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 35: { InnerGemmFixC<3, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 25: { InnerGemmFixC<2, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 15: { InnerGemmFixC<1, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<4, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1372,8 +1372,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<5, 5>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1384,42 +1384,42 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 44: { InnerGemmFixC<4, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 43: { InnerGemmFixC<4, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 42: { InnerGemmFixC<4, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 41: { InnerGemmFixC<4, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 34: { InnerGemmFixC<3, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 24: { InnerGemmFixC<2, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 14: { InnerGemmFixC<1, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<3, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1427,8 +1427,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<4, 4>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1439,32 +1439,32 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 33: { InnerGemmFixC<3, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 32: { InnerGemmFixC<3, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 31: { InnerGemmFixC<3, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 23: { InnerGemmFixC<2, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 13: { InnerGemmFixC<1, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } default: { InnerGemmFixC<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, m, n, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, m, n, k, C); break; } } @@ -1472,8 +1472,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<3, 3>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1484,22 +1484,22 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a switch (m * 10 + n) { case 22: { InnerGemmFixC<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 21: { InnerGemmFixC<2, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 12: { InnerGemmFixC<1, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } case 11: { InnerGemmFixC<1, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, k, C); + inner.serial_invoke(opA, opB, alpha, A, B, k, C); break; } } @@ -1507,8 +1507,8 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<2, 2>::serial_invoke(const ScalarType a } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1516,7 +1516,7 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixC<1, 1>::serial_invoke(const ScalarType a if (m <= 0 || n <= 0 || k <= 0) return 0; if (!(m <= 1 && n <= 1)) Kokkos::abort("InnerGemmFixC<1,1>::serial_invoke, assert failure (m<=1 && n<=1)"); - return serial_invoke(alpha, A, B, k, C); + return serial_invoke(opA, opB, alpha, A, B, k, C); ; } diff --git a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp index 52002ad473..4b7166f0b9 100644 --- a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp @@ -115,8 +115,9 @@ KOKKOS_INLINE_FUNCTION int SerialLU_Internal::invoke( trsm_run.serial_invoke(Ap, pb, m_abr, Ap + mb * as0); // gemm update - SerialGemmInternal::invoke(m_abr, n_abr, pb, minus_one, Ap + mb * as0, as0, as1, - Ap + mb * as1, as0, as1, one, Ap + mb * as0 + mb * as1, as0, as1); + Impl::SerialGemmInternal::invoke( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), m_abr, n_abr, pb, minus_one, Ap + mb * as0, as0, as1, + Ap + mb * as1, as0, as1, one, Ap + mb * as0 + mb * as1, as0, as1); } }; diff --git a/batched/dense/src/KokkosBatched_Gemm_Decl.hpp b/batched/dense/src/KokkosBatched_Gemm_Decl.hpp index eabd5c42c2..1f3ba6095d 100644 --- a/batched/dense/src/KokkosBatched_Gemm_Decl.hpp +++ b/batched/dense/src/KokkosBatched_Gemm_Decl.hpp @@ -61,10 +61,12 @@ struct Gemm { KOKKOS_FORCEINLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { int r_val = 0; - if (std::is_same::value) { + if constexpr (std::is_same_v) { r_val = SerialGemm::invoke(alpha, A, B, beta, C); - } else if (std::is_same::value) { + } else if constexpr (std::is_same_v) { r_val = TeamGemm::invoke(member, alpha, A, B, beta, C); + } else if constexpr (std::is_same_v) { + r_val = TeamVectorGemm::invoke(member, alpha, A, B, beta, C); } return r_val; } diff --git a/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp b/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp index 31ba2a03d9..ca55816fe4 100644 --- a/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp +++ b/batched/dense/src/KokkosBatched_InnerGemmFixC_Decl.hpp @@ -29,20 +29,20 @@ struct InnerGemmFixC { : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} // serial rank update - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int k, /**/ ValueType *KOKKOS_RESTRICT C); // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int k, /**/ ValueType *KOKKOS_RESTRICT C); // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C); diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp index 0b2ed4a162..adc9a51aac 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm.hpp @@ -14,133 +14,133 @@ // //@HEADER /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) #include "gtest/gtest.h" #include "Kokkos_Core.hpp" #include "Kokkos_Random.hpp" - -// #include "KokkosBatched_Vector.hpp" - #include "KokkosBatched_Gemm_Decl.hpp" #include "KokkosBatched_Gemm_Serial_Impl.hpp" #include "KokkosKernels_TestUtils.hpp" #include "KokkosKernels_TestVanilla.hpp" -using namespace KokkosBatched; - namespace Test { namespace Gemm { template struct ParamTag { - typedef TA transA; - typedef TB transB; + using transA = TA; + using transB = TB; }; template struct Functor_TestBatchedSerialGemm { using execution_space = typename DeviceType::execution_space; - ViewType _a, _b, _c; - - ScalarType _alpha, _beta; + ViewType m_a, m_b, m_c; + ScalarType m_alpha, m_beta; KOKKOS_INLINE_FUNCTION Functor_TestBatchedSerialGemm(const ScalarType alpha, const ViewType &a, const ViewType &b, const ScalarType beta, const ViewType &c) - : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} + : m_a(a), m_b(b), m_c(c), m_alpha(alpha), m_beta(beta) {} KOKKOS_INLINE_FUNCTION - void operator()(const ParamTagType &, const int k) const { - auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); - auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); - auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); - - SerialGemm::invoke(_alpha, aa, bb, _beta, - cc); + void operator()(const ParamTagType &, const int k, int &info) const { + auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(m_b, k, Kokkos::ALL(), Kokkos::ALL()); + auto cc = Kokkos::subview(m_c, k, Kokkos::ALL(), Kokkos::ALL()); + + info += + KokkosBatched::SerialGemm::invoke( + m_alpha, aa, bb, m_beta, cc); } - inline void run() { - typedef typename ViewType::value_type value_type; + inline int run() { + using value_type = typename ViewType::non_const_value_type; std::string name_region("KokkosBatched::Test::SerialGemm"); const std::string name_value_type = Test::value_type_name(); std::string name = name_region + name_value_type; + int info_sum = 0; Kokkos::Profiling::pushRegion(name.c_str()); - Kokkos::RangePolicy policy(0, _c.extent(0)); - Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::RangePolicy policy(0, m_c.extent(0)); + Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum); Kokkos::Profiling::popRegion(); + return info_sum; } }; -template +/// \brief Implementation details of batched gemm test +/// \param N [in] Batch size of matrices +/// \param matAdim1 [in] Number of rows of matrix A +/// \param matAdim2 [in] Number of columns of matrix A +/// \param matBdim1 [in] Number of rows of matrix B +/// \param matBdim2 [in] Number of columns of matrix B +/// \param matCdim1 [in] Number of rows of matrix C +/// \param matCdim2 [in] Number of columns of matrix C +template void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, const int matBdim1, const int matBdim2, const int matCdim1, const int matCdim2) { using execution_space = typename DeviceType::execution_space; using transA = typename ParamTagType::transA; using transB = typename ParamTagType::transB; - using value_type = typename ViewType::value_type; - using ats = Kokkos::ArithTraits; + using ats = Kokkos::ArithTraits; + using ViewType = Kokkos::View; /// randomized input testing views ScalarType alpha = ScalarType(1.5); ScalarType beta = ScalarType(3.0); - ViewType a_expected("a_expected", N, matAdim1, matAdim2), a_actual("a_actual", N, matAdim1, matAdim2), - b_expected("b_expected", N, matBdim1, matBdim2), b_actual("b_actual", N, matBdim1, matBdim2), - c_expected("c_expected", N, matCdim1, matCdim2), c_actual("c_actual", N, matCdim1, matCdim2); + ViewType A("A", N, matAdim1, matAdim2), B("B", N, matBdim1, matBdim2), C("C", N, matCdim1, matCdim2), + C_ref("C_ref", N, matCdim1, matCdim2); - Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::Random_XorShift64_Pool rand_pool(13718); - Kokkos::fill_random(a_expected, random, value_type(1.0)); - Kokkos::fill_random(b_expected, random, value_type(1.0)); - Kokkos::fill_random(c_expected, random, value_type(1.0)); + ScalarType randStart, randEnd; + KokkosKernels::Impl::getRandomBounds(1.0, randStart, randEnd); + Kokkos::fill_random(A, rand_pool, randStart, randEnd); + Kokkos::fill_random(B, rand_pool, randStart, randEnd); + Kokkos::fill_random(C, rand_pool, randStart, randEnd); - Kokkos::fence(); - - Kokkos::deep_copy(a_actual, a_expected); - Kokkos::deep_copy(b_actual, b_expected); - Kokkos::deep_copy(c_actual, c_expected); + Kokkos::deep_copy(C_ref, C); Functor_BatchedVanillaGEMM vgemm; - vgemm.A_t = std::is_same::value; - vgemm.B_t = std::is_same::value; - vgemm.A_c = vgemm.B_c = false; - vgemm.A = a_expected; - vgemm.B = b_expected; - vgemm.C = c_expected; - vgemm.alpha = alpha; - vgemm.beta = beta; - vgemm.run(); // Compute c_expected - Functor_TestBatchedSerialGemm(alpha, a_actual, b_actual, - beta, c_actual) - .run(); - - typename ViewType::HostMirror c_expected_host = Kokkos::create_mirror_view(c_expected); - typename ViewType::HostMirror c_actual_host = Kokkos::create_mirror_view(c_actual); - - // Copy to host for comparison - Kokkos::deep_copy(c_expected_host, c_expected); - Kokkos::deep_copy(c_actual_host, c_actual); - - Kokkos::fence(); - - // check c_expected = c_actual - // std::conditional<, float, + vgemm.A_t = !std::is_same_v; + vgemm.B_t = !std::is_same_v; + vgemm.A_c = std::is_same_v; + vgemm.B_c = std::is_same_v; + vgemm.A = A; + vgemm.B = B; + vgemm.C = C_ref; + vgemm.alpha = alpha; + vgemm.beta = beta; + vgemm.run(); // Compute C_ref + + // Compute using gemm API + auto info = + Functor_TestBatchedSerialGemm(alpha, A, B, beta, C) + .run(); + EXPECT_EQ(info, 0); + + auto h_C = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C); + auto h_C_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), C_ref); + + // check C = C_ref using mag_type = typename ats::mag_type; mag_type sum(1), diff(0); mag_type eps = ats::epsilon(); - - eps *= std::is_same::value || - std::is_same::value + eps *= std::is_same_v || + std::is_same_v ? 4 : 1e3; for (int k = 0; k < N; ++k) for (int i = 0; i < matCdim1; ++i) for (int j = 0; j < matCdim2; ++j) { - sum += ats::abs(c_expected_host(k, i, j)); - diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); + sum += ats::abs(h_C_ref(k, i, j)); + diff += ats::abs(h_C_ref(k, i, j) - h_C(k, i, j)); } EXPECT_NEAR_KK(diff / sum, 0, eps); } @@ -151,37 +151,35 @@ template ViewType; - Test::Gemm::impl_test_batched_gemm(0, 10, 10, 10, 10, - 10, 10); + using LayoutType = Kokkos::LayoutLeft; + Test::Gemm::impl_test_batched_gemm( + 0, 10, 10, 10, 10, 10, 10); for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); - Test::Gemm::impl_test_batched_gemm(1024, i, i, i, i, - i, i); + Test::Gemm::impl_test_batched_gemm( + 1024, i, i, i, i, i, i); } for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); int dimM = i; int dimN = 2 * i; int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimN, dimK, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimN, dimK, dimM, dimN); } } @@ -189,37 +187,35 @@ int test_batched_gemm() { #endif #if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) { - typedef Kokkos::View ViewType; - Test::Gemm::impl_test_batched_gemm(0, 10, 10, 10, 10, - 10, 10); + using LayoutType = Kokkos::LayoutRight; + Test::Gemm::impl_test_batched_gemm( + 0, 10, 10, 10, 10, 10, 10); for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutRight, Blksize %d\n", i); - Test::Gemm::impl_test_batched_gemm(1024, i, i, i, i, - i, i); + Test::Gemm::impl_test_batched_gemm( + 1024, i, i, i, i, i, i); } for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d\n", i); int dimM = i; int dimN = 2 * i; int dimK = 3 * i; - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimM, dimK, dimN, dimK, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimK, dimN, dimM, dimN); } - if ((std::is_same::value) && - (std::is_same::value)) { - Test::Gemm::impl_test_batched_gemm( + if ((!std::is_same_v)&&( + !std::is_same_v)) { + Test::Gemm::impl_test_batched_gemm( 1024, dimK, dimM, dimN, dimK, dimM, dimN); } } diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp index f785965602..ab97732238 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm_Complex.hpp @@ -13,72 +13,249 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //@HEADER +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) + +/// fcomplex, fcomplex + +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_fcomplex_fcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_fcomplex_fcomplex) { + using param_tag_type = + ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} + +/// fcomplex, float +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_fcomplex_float) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_fcomplex_float) { + using param_tag_type = + ::Test::Gemm::ParamTag; + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, float, param_tag_type, KokkosBatched::Algo::Gemm::Unblocked>(); +} + +#endif + #if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) /// dcomplex, dcomplex TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_dcomplex) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, Kokkos::complex, param_tag_type, algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_serial_gemm_ct_nt_dcomplex_dcomplex ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_serial_gemm_nt_ct_dcomplex_dcomplex ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,Kokkos::complex,param_tag_type,algo_tag_type>(); -// } + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_dcomplex_dcomplex) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_dcomplex_dcomplex) { + using param_tag_type = + ::Test::Gemm::ParamTag; + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, Kokkos::complex, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} /// dcomplex, double - TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_nt_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_dcomplex_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm, double, param_tag_type, algo_tag_type>(); -} -// TEST_F( TestCategory, batched_scalar_serial_gemm_ct_nt_dcomplex_double ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,double,param_tag_type,algo_tag_type>(); -// } -// TEST_F( TestCategory, batched_scalar_serial_gemm_nt_ct_dcomplex_double ) { -// typedef ::Test::Gemm::ParamTag -// param_tag_type; typedef Algo::Gemm::Blocked algo_tag_type; -// test_batched_gemm,double,param_tag_type,algo_tag_type>(); -// } + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_t_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_nt_c_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_t_c_dcomplex_double) { + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} +TEST_F(TestCategory, batched_scalar_serial_gemm_c_c_dcomplex_double) { + using param_tag_type = + ::Test::Gemm::ParamTag; + test_batched_gemm, double, param_tag_type, KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm, double, param_tag_type, + KokkosBatched::Algo::Gemm::Unblocked>(); +} #endif diff --git a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp b/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp index afe5744688..0192b61b0f 100644 --- a/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialGemm_Real.hpp @@ -15,112 +15,117 @@ //@HEADER #if defined(KOKKOS_BHALF_T_IS_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_bhalf_bhalf) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Unblocked>(); } #endif // KOKKOS_BHALF_T_IS_FLOAT #if defined(KOKKOS_HALF_T_IS_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_half_half) { - typedef ::Test::Gemm::ParamTag param_tag_type; + using param_tag_type = ::Test::Gemm::ParamTag; - test_batched_gemm(); test_batched_gemm(); + KokkosBatched::Algo::Gemm::Blocked>(); + test_batched_gemm(); } #endif // KOKKOS_HALF_T_IS_FLOAT #if defined(KOKKOSKERNELS_INST_FLOAT) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_float_float) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_float_float) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_float_float) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_float_float) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } #endif #if defined(KOKKOSKERNELS_INST_DOUBLE) TEST_F(TestCategory, batched_scalar_serial_gemm_nt_nt_double_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } + TEST_F(TestCategory, batched_scalar_serial_gemm_t_nt_double_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_nt_t_double_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } TEST_F(TestCategory, batched_scalar_serial_gemm_t_t_double_double) { - typedef ::Test::Gemm::ParamTag param_tag_type; - typedef Algo::Gemm::Blocked algo_tag_type; - test_batched_gemm(); + using param_tag_type = ::Test::Gemm::ParamTag; + test_batched_gemm(); + test_batched_gemm(); } #endif diff --git a/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp b/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp index 1a41ff4db3..2b9789cc02 100644 --- a/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp +++ b/blas/impl/KokkosBlas2_serial_gemv_inner_multiple_dot.hpp @@ -16,26 +16,13 @@ #ifndef KOKKOSBLAS_INNER_MULTIPLE_DOT_PRODUCT_SERIAL_IMPL_HPP #define KOKKOSBLAS_INNER_MULTIPLE_DOT_PRODUCT_SERIAL_IMPL_HPP +#include "KokkosBlas_util.hpp" + /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace KokkosBlas { namespace Impl { -struct OpID { - template - KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { - return v; - } -}; - -struct OpConj { - template - KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { - using KAT = Kokkos::ArithTraits; - return KAT::conj(v); - } -}; - template struct InnerMultipleDotProduct { const int _as0, _as1, _xs0, _ys0; diff --git a/blas/impl/KokkosBlas_util.hpp b/blas/impl/KokkosBlas_util.hpp index c0777ac9ea..35661caef7 100644 --- a/blas/impl/KokkosBlas_util.hpp +++ b/blas/impl/KokkosBlas_util.hpp @@ -20,6 +20,22 @@ #include "Kokkos_ArithTraits.hpp" namespace KokkosBlas { +namespace Impl { +struct OpID { + template + KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { + return v; + } +}; + +struct OpConj { + template + KOKKOS_INLINE_FUNCTION ValueType operator()(ValueType v) const { + using KAT = Kokkos::ArithTraits; + return KAT::conj(v); + } +}; +} // namespace Impl //////// Tags for BLAS //////// diff --git a/common/src/KokkosKernels_BlockUtils.hpp b/common/src/KokkosKernels_BlockUtils.hpp index 64309372ac..26a0baac67 100644 --- a/common/src/KokkosKernels_BlockUtils.hpp +++ b/common/src/KokkosKernels_BlockUtils.hpp @@ -52,13 +52,13 @@ KOKKOS_INLINE_FUNCTION void kk_block_add(const size_type block_dim, value_type * // Note: block is assumed to be row-major, dense matrix (no extra padding) // Note: set clear=true to set C = 0 before increment template > + typename DGEMM = KokkosBatched::Impl::SerialGemmInternal> KOKKOS_INLINE_FUNCTION void kk_block_dgemm(const size_type block_dim, value_type *dst, const value_type *valA, const value_type *valB, const bool clear = false) { const auto ZERO = static_cast(0); const auto ONE = static_cast(1); - DGEMM::invoke(block_dim, block_dim, block_dim, ONE, valA, block_dim, 1, valB, block_dim, 1, clear ? ZERO : ONE, dst, - block_dim, 1); + DGEMM::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), block_dim, block_dim, block_dim, ONE, valA, + block_dim, 1, valB, block_dim, 1, clear ? ZERO : ONE, dst, block_dim, 1); } // dgemm: C = A * B diff --git a/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp b/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp index 98501a5814..b3c870e8cc 100644 --- a/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp +++ b/sparse/impl/KokkosSparse_bspgemm_impl_seq.hpp @@ -64,10 +64,10 @@ void bspgemm_debug_numeric(KernelHandle* /* handle */, typename KernelHandle::nn typename cscalar_nnz_view_t_::HostMirror h_valc = Kokkos::create_mirror_view(valuesC); Kokkos::fence(); - typedef typename KernelHandle::nnz_lno_t lno_t; - typedef typename KernelHandle::size_type size_type; - typedef typename KernelHandle::nnz_scalar_t scalar_t; - typedef KokkosBatched::SerialGemmInternal GEMM; + using lno_t = typename KernelHandle::nnz_lno_t; + using size_type = typename KernelHandle::size_type; + using scalar_t = typename KernelHandle::nnz_scalar_t; + using GEMM = KokkosBatched::Impl::SerialGemmInternal; const auto block_size = block_dim * block_dim; const auto ZERO = static_cast(0); @@ -106,8 +106,8 @@ void bspgemm_debug_numeric(KernelHandle* /* handle */, typename KernelHandle::nn } // accumulator(b_col) += a_val * b_val auto acc = get_block(accumulator, b_col, block_size); - GEMM::invoke(block_dim, block_dim, block_dim, ONE, a_val, block_dim, 1, b_val, block_dim, 1, ONE, acc.data(), - block_dim, 1); + GEMM::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), block_dim, block_dim, block_dim, ONE, a_val, + block_dim, 1, b_val, block_dim, 1, ONE, acc.data(), block_dim, 1); } } diff --git a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp index d9702af900..3fae741c94 100644 --- a/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp +++ b/sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp @@ -19,6 +19,7 @@ #include "KokkosKernels_Error.hpp" #include "KokkosKernels_ExecSpaceUtils.hpp" +#include "KokkosBlas_util.hpp" #if defined(KOKKOS_ENABLE_CUDA) && (defined(KOKKOS_ARCH_VOLTA) || defined(KOKKOS_ARCH_AMPERE)) @@ -1028,10 +1029,12 @@ struct BSR_GEMM_Functor { for (ordinal_type ic = 0; ic < count; ++ic) { const auto Aview = row.block(ic); const auto xstart = row.block_colidx(ic) * block_dim; - KokkosBatched::SerialGemmInternal::invoke( - static_cast(block_dim), static_cast(num_rhs), - static_cast(block_dim), alpha, Aview.data(), Aview.stride_0(), Aview.stride_1(), - &m_x(xstart, 0), m_x.stride_0(), ldx, beta1, &m_y(ystart, 0), m_y.stride_0(), ldy); + KokkosBatched::Impl::SerialGemmInternal::invoke< + KokkosBlas::Impl::OpID, KokkosBlas::Impl::OpID, value_type, value_type>( + KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), static_cast(block_dim), + static_cast(num_rhs), static_cast(block_dim), alpha, Aview.data(), + Aview.stride_0(), Aview.stride_1(), &m_x(xstart, 0), m_x.stride_0(), ldx, beta1, &m_y(ystart, 0), + m_y.stride_0(), ldy); } } }