diff --git a/cmake/configuring_primitive_list.cmake b/cmake/configuring_primitive_list.cmake index c1e012683a0..30018138e0a 100644 --- a/cmake/configuring_primitive_list.cmake +++ b/cmake/configuring_primitive_list.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2021 Intel Corporation +# Copyright 2021-2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,6 +66,23 @@ else() endif() message(STATUS "Enabled primitive GPU ISA: ${DNNL_ENABLE_PRIMITIVE_GPU_ISA}") +# Map ONEDNN_ENABLE_GEMM_KERNELS_ISA onto the BUILD_GEMM_* switches that are +# substituted into dnnl_config.h (switches left unset are emitted as 0). +if ("${ONEDNN_ENABLE_GEMM_KERNELS_ISA}" STREQUAL "ALL") + set(BUILD_GEMM_KERNELS_ALL TRUE) +elseif ("${ONEDNN_ENABLE_GEMM_KERNELS_ISA}" STREQUAL "NONE") + set(BUILD_GEMM_KERNELS_NONE TRUE) +else() + foreach(isa ${ONEDNN_ENABLE_GEMM_KERNELS_ISA}) + string(TOUPPER "${isa}" uisa) + if(NOT "${uisa}" MATCHES "^(SSE41|AVX2|AVX512)$") + message(FATAL_ERROR "Unsupported GeMM kernels ISA: ${uisa}") + endif() + set(BUILD_GEMM_${uisa} TRUE) + endforeach() +endif() +message(STATUS "Enabled GeMM kernels ISA: ${ONEDNN_ENABLE_GEMM_KERNELS_ISA}") + # When certain primitives or primitive ISA are switched off, some functions may # become unused which is expected. Switch off warning for unused functions in # such cases. diff --git a/cmake/dnnl_compat.cmake b/cmake/dnnl_compat.cmake index 27e136afc37..c600637ed13 100644 --- a/cmake/dnnl_compat.cmake +++ b/cmake/dnnl_compat.cmake @@ -61,6 +61,8 @@ set(COMPAT_CACHE_STRING_VARS "LIBRARY_NAME" "ENABLE_WORKLOAD" "ENABLE_PRIMITIVE" + "ENABLE_PRIMITIVE_CPU_ISA" + "ENABLE_PRIMITIVE_GPU_ISA" "ARCH_OPT_FLAGS" "CPU_RUNTIME" "GPU_RUNTIME" diff --git a/cmake/options.cmake b/cmake/options.cmake index 2db3de7439d..6b295049b96 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -149,6 +149,16 @@ set(DNNL_ENABLE_PRIMITIVE_GPU_ISA "ALL" CACHE STRING - ;;... Includes only selected ISA to be enabled. 
Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC.") +set(ONEDNN_ENABLE_GEMM_KERNELS_ISA "ALL" CACHE STRING + "Specifies an ISA set of GeMM kernels residing in x64/gemm folder to be + available at build time. Valid values: + - ALL (the default). Includes all ISA kernels to be enabled. + - NONE. Removes all kernels and interfaces. + - <ISA_NAME>. Enables all ISA up to ISA_NAME included. + Possible values are: SSE41, AVX2, AVX512. The linear order is + SSE41 < AVX2 < AVX512 < AMX (or ALL). It means that if the user selects, + e.g., AVX2 ISA, SSE41 kernels will also be present at build time.") + # ============= # Optimizations # ============= diff --git a/doc/build/build_options.md b/doc/build/build_options.md index ca8f6630872..d00e29977f3 100644 --- a/doc/build/build_options.md +++ b/doc/build/build_options.md @@ -24,6 +24,7 @@ oneDNN supports the following build-time options. | ONEDNN_ENABLE_PRIMITIVE | **ALL**, PRIMITIVE_NAME | Specifies a set of functionality to be available based on primitives | | ONEDNN_ENABLE_PRIMITIVE_CPU_ISA | **ALL**, CPU_ISA_NAME | Specifies a set of functionality to be available for CPU backend based on CPU ISA | | ONEDNN_ENABLE_PRIMITIVE_GPU_ISA | **ALL**, GPU_ISA_NAME | Specifies a set of functionality to be available for GPU backend based on GPU ISA | +| ONEDNN_ENABLE_GEMM_KERNELS_ISA | **ALL**, NONE, ISA_NAME | Specifies a set of functionality to be available for GeMM kernels for CPU backend based on ISA | | ONEDNN_EXPERIMENTAL | ON, **OFF** | Enables [experimental features](@ref dev_guide_experimental) | | ONEDNN_VERBOSE | **ON**, OFF | Enables [verbose mode](@ref dev_guide_verbose) | | ONEDNN_AARCH64_USE_ACL | ON, **OFF** | Enables integration with Arm Compute Library for AArch64 builds | @@ -109,6 +110,17 @@ always be available. 
Example that enables XeLP and XeHP set: -DONEDNN_ENABLE_PRIMITIVE_GPU_ISA=XELP;XEHP ``` +#### ONEDNN_ENABLE_GEMM_KERNELS_ISA +This option supports several values: `ALL` (the default), which enables all +ISA kernels from the x64/gemm folder; `NONE`, which disables all kernels and +removes the corresponding interfaces; or one of `SSE41`, `AVX2`, or `AVX512`. +Values are linearly ordered as `SSE41` < `AVX2` < `AVX512`. When an ISA is +specified, it and all "smaller" ISAs will be available. Example that keeps the +SSE41 and AVX2 sets, but removes AVX512 and AMX kernels: +``` +-DONEDNN_ENABLE_GEMM_KERNELS_ISA=AVX2 +``` + ## CPU Options Intel Architecture Processors and compatible devices are supported by oneDNN CPU engine. The CPU engine is built by default but can be disabled diff --git a/include/oneapi/dnnl/dnnl_config.h.in b/include/oneapi/dnnl/dnnl_config.h.in index 7d8536ac641..5fb44873d0d 100644 --- a/include/oneapi/dnnl/dnnl_config.h.in +++ b/include/oneapi/dnnl/dnnl_config.h.in @@ -193,4 +193,10 @@ #cmakedefine01 BUILD_XEHP #cmakedefine01 BUILD_XEHPG #cmakedefine01 BUILD_XEHPC +// GeMM kernels ISA controls +#cmakedefine01 BUILD_GEMM_KERNELS_ALL +#cmakedefine01 BUILD_GEMM_KERNELS_NONE +#cmakedefine01 BUILD_GEMM_SSE41 +#cmakedefine01 BUILD_GEMM_AVX2 +#cmakedefine01 BUILD_GEMM_AVX512 #endif diff --git a/src/cpu/gemm/bf16/ref_gemm_bf16.cpp b/src/cpu/gemm/bf16/ref_gemm_bf16.cpp new file mode 100644 index 00000000000..c4c5404979d --- /dev/null +++ b/src/cpu/gemm/bf16/ref_gemm_bf16.cpp @@ -0,0 +1,327 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dnnl/dnnl_types.h" + +#include "common/dnnl_thread.hpp" +#include "common/nstl.hpp" +#include "common/utils.hpp" + +#include "cpu/platform.hpp" + +#include "cpu/gemm/bf16/ref_gemm_bf16.hpp" +#include "cpu/gemm/f32/gemm_utils_f32.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +using namespace dnnl::impl::utils; +using namespace gemm_utils; + +namespace { + +// Packs unroll_factor<bfloat16_t>::m consecutive rows of A, for all K columns, +// into ws (one m-element column after another), undoing the transposition. +void copy_A(bool isTransA, dim_t K, const bfloat16_t *A, const dim_t lda, + bfloat16_t *ws) { + for (dim_t k = 0; k < K; k++) { + PRAGMA_OMP_SIMD() + for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) { + ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; + } + ws += unroll_factor<bfloat16_t>::m; + } +} + +// Computes one m x n tile: C = alpha * A * B + beta * C, accumulating in f32. +// With beta == 0 the old contents of C are overwritten without being read. +template <bool isTransA, bool isTransB> +void kernel_mxn(dim_t K, const bfloat16_t *A, const dim_t lda, + const bfloat16_t *B, const dim_t ldb, float *C, const dim_t ldc, + const float alpha, const float beta) { + float c[unroll_factor<bfloat16_t>::m * unroll_factor<bfloat16_t>::n] + = {0.f}; + for (dim_t k = 0; k < K; k++) { + for (dim_t j = 0; j < unroll_factor<bfloat16_t>::n; j++) { + bfloat16_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; + PRAGMA_OMP_SIMD() + for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) { + bfloat16_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; + c[i + unroll_factor<bfloat16_t>::m * j] += a * b; + } + } + } + for (dim_t j = 0; j < unroll_factor<bfloat16_t>::n; j++) { + PRAGMA_OMP_SIMD() + for (dim_t i = 0; i < unroll_factor<bfloat16_t>::m; i++) { + C[i + j * ldc] = (beta == 0.f) + ? 
alpha * c[i + unroll_factor<bfloat16_t>::m * j] + : alpha * c[i + unroll_factor<bfloat16_t>::m * j] + + beta * C[i + j * ldc]; + } + } +} + +// Multiplies an M x N x K block. Full m x n tiles go through kernel_mxn (reading +// A from its packed copy in ws when do_copy is set); the M and N remainders are +// handled by scalar loops. +template <bool isTransA, bool isTransB> +void block_ker(const dim_t M, const dim_t N, const dim_t K, const bfloat16_t *A, + const dim_t lda, const bfloat16_t *B, const dim_t ldb, float *C, + const dim_t ldc, const float alpha, const float beta, bfloat16_t *ws, + bool do_copy) { + dim_t Nu = rnd_dn(N, unroll_factor<bfloat16_t>::n); + dim_t Mu = rnd_dn(M, unroll_factor<bfloat16_t>::m); + for (dim_t i = 0; i < Mu; i += unroll_factor<bfloat16_t>::m) { + for (dim_t j = 0; j < Nu; j += unroll_factor<bfloat16_t>::n) { + const bfloat16_t *b = isTransB ? &B[j] : &B[j * ldb]; + const bfloat16_t *a = isTransA ? &A[i * lda] : &A[i]; + if (do_copy) { + // The packed strip of A is reused by every tile in this row. + if (j == 0) { copy_A(isTransA, K, a, lda, ws); } + kernel_mxn<false, isTransB>(K, ws, unroll_factor<bfloat16_t>::m, + b, ldb, &C[i + j * ldc], ldc, alpha, beta); + } else { + kernel_mxn<isTransA, isTransB>( + K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); + } + } + } + // tail processing + for (dim_t i = 0; i < M; i++) { + for (dim_t j = Nu; j < N; j++) { + float c = beta == 0.f ? 0.f : beta * C[i + j * ldc]; + for (dim_t p = 0; p < K; p++) { + bfloat16_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + bfloat16_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } + for (dim_t i = Mu; i < M; i++) { + for (dim_t j = 0; j < Nu; j++) { + float c = beta == 0.f ? 0.f : beta * C[i + j * ldc]; + for (dim_t p = 0; p < K; p++) { + bfloat16_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + bfloat16_t a = isTransA ? 
A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } +} + +// Single-threaded GEMM on an M x N x K problem, blocked by BM/BN/BK. beta is +// applied on the first K-block only; later K-blocks accumulate with beta = 1. +template <bool isTransA, bool isTransB> +void gemm_ithr(const dim_t M, const dim_t N, const dim_t K, const float alpha, + const bfloat16_t *A, const dim_t lda, const bfloat16_t *B, + const dim_t ldb, const float beta, float *C, const dim_t ldc, + bool do_copy, bfloat16_t *ws) { + constexpr dim_t BM = gemm_traits<bfloat16_t, isTransA, isTransB>::BM; + constexpr dim_t BN = gemm_traits<bfloat16_t, isTransA, isTransB>::BN; + constexpr dim_t BK = gemm_traits<bfloat16_t, isTransA, isTransB>::BK; + + const bfloat16_t *curA; + const bfloat16_t *curB; + float *curC; + + if ((M <= 0) || (N <= 0)) return; + + if ((K <= 0) || (alpha == 0.f)) { + // Nothing to accumulate: C = beta * C. Walk C column by column with + // stride ldc; the M x N block is not contiguous when ldc > M. + if (beta == 0.f) { + for (dim_t j = 0; j < N; j++) + for (dim_t i = 0; i < M; i++) + C[i + j * ldc] = 0.f; + } else if (beta != 1.f) { + for (dim_t j = 0; j < N; j++) + for (dim_t i = 0; i < M; i++) + C[i + j * ldc] *= beta; + } + return; + } + + for (dim_t Bk = 0; Bk < K; Bk += BK) { + dim_t kb = nstl::min(K - Bk, BK); + for (dim_t Bm = 0; Bm < M; Bm += BM) { + dim_t mb = nstl::min(M - Bm, BM); + for (dim_t Bn = 0; Bn < N; Bn += BN) { + dim_t nb = nstl::min(N - Bn, BN); + curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; + curB = isTransB ? 
B + Bn + Bk * ldb : B + Bk + Bn * ldb; + curC = C + Bm + Bn * ldc; + if (Bk == 0) { + block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB, + ldb, curC, ldc, alpha, beta, ws, do_copy); + } else { + block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB, + ldb, curC, ldc, alpha, 1.f, ws, do_copy); + } + } + } + } +} + +} // namespace + +// Reference bf16 x bf16 -> f32 GEMM. Matrices are indexed column-major with +// leading dimensions lda/ldb/ldc. Returns dnnl_unimplemented for transa/transb +// values other than n/N/t/T, and dnnl_success otherwise. +dnnl_status_t ref_gemm_bf16bf16f32(const char *transa_, const char *transb_, + const dim_t *M_, const dim_t *N_, const dim_t *K_, const float *alpha_, + const bfloat16_t *A, const dim_t *lda_, const bfloat16_t *B, + const dim_t *ldb_, const float *beta_, float *C, const dim_t *ldc_) { + + if (!(utils::one_of(*transa_, 'n', 'N', 't', 'T') + && utils::one_of(*transb_, 'n', 'N', 't', 'T'))) + return dnnl_unimplemented; + + bool isTransA = (*transa_ == 'T' || *transa_ == 't'); + bool isTransB = (*transb_ == 'T' || *transb_ == 't'); + const dim_t M = *M_, N = *N_, K = *K_; + const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; + const float alpha = *alpha_, beta = *beta_; + + // early out and avoid division by zero in partitioning + if (utils::one_of(0, M, N)) return dnnl_success; + + int max_nthr = dnnl_get_current_num_threads(); + int nthr_m, nthr_n, nthr_k; + dim_t MB, NB, KB; + // thread balancing over M, N, K & size of blocking dimensions + calc_nthr_nocopy_avx( + M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!dnnl_thr_syncable(), nthr_k == 1)); + + float *c_buffers = nullptr; + bfloat16_t *ws_buffers = nullptr; + if (nthr_k > 1) { + c_buffers = (float *)malloc( + sizeof(*c_buffers) * nthr_m * nthr_n * (nthr_k - 1) * MB * NB, + PAGE_4K); + if (!c_buffers) { + // No scratch space for partial sums: fall back to no K-split. + nthr_k = 1; + KB = K; + } + } + + bool do_copy = (NB / unroll_factor<bfloat16_t>::n > 3); + const int nthr_mn = nthr_m * nthr_n; + const int nthr_to_use = nthr_mn * nthr_k; + const size_t ws_elems_per_thr = K * unroll_factor<bfloat16_t>::m; + // A is packed as bfloat16_t, so size the workspace in bfloat16_t units. + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(bfloat16_t), PAGE_4K); + if (do_copy) { + ws_buffers + = (bfloat16_t *)malloc(nthr_to_use * 
ws_size_per_thr, PAGE_4K); + if (!ws_buffers) do_copy = false; + } + + auto get_thr_block = [&](dim_t &from, dim_t &to, dim_t &myN, dim_t NB, + dim_t N, int ithr) { + from = NB * (ithr); + to = NB * (ithr + 1); + if (to > N) to = N; + myN = to - from; + }; + + parallel(nthr_to_use, [&](int ithr, int nthr) { + assert(nthr_to_use == nthr); + MAYBE_UNUSED(nthr); + + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_n = ithr_mn / nthr_m; + int ithr_k = ithr / nthr_mn; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + bfloat16_t *ws = do_copy + ? ws_buffers + ithr * ws_size_per_thr / sizeof(bfloat16_t) + : nullptr; + + dim_t m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, + k_from = 0, k_to = 0, myK = 0; + + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(k_from, k_to, myK, KB, K, ithr_k); + + if (myM > 0 && myN > 0) { + float myBeta, *myC; + dim_t ld; + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + } else { + // Other K-partitions write partial sums to scratch (stride MB). + myC = c_buffers + MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0f; + ld = MB; + } + const bfloat16_t *myA = isTransA ? &(A[k_from + m_from * lda]) + : &(A[m_from + k_from * lda]); + const bfloat16_t *myB = isTransB ? 
&(B[n_from + k_from * ldb]) + : &(B[k_from + n_from * ldb]); + + if (!isTransA) { + if (!isTransB) { + gemm_ithr<false, false>(myM, myN, myK, alpha, myA, lda, myB, + ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr<false, true>(myM, myN, myK, alpha, myA, lda, myB, + ldb, myBeta, myC, ld, do_copy, ws); + } + } else { + if (!isTransB) { + gemm_ithr<true, false>(myM, myN, myK, alpha, myA, lda, myB, + ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr<true, true>(myM, myN, myK, alpha, myA, lda, myB, + ldb, myBeta, myC, ld, do_copy, ws); + } + } + } + }); + + if (nthr_k > 1) { + parallel(nthr_to_use, [&](int ithr, int nthr) { + assert(nthr_to_use == nthr); + MAYBE_UNUSED(nthr); + + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_k = ithr / nthr_mn; + int ithr_n = ithr_mn / nthr_m; + + dim_t n_from = 0, n_to = 0, myN = 0; + dim_t m_from = 0, m_to = 0, myM = 0; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + + // sum matrices partitioned along K dimension + dim_t offset = 0, block = 0; + partition_unit_diff(ithr_k, nthr_k, myN, &offset, &block); + for (int ik = 1; ik < nthr_k; ++ik) { + float *myC = c_buffers + + MB * ((dim_t)NB * (cbase + ik - 1) + offset); + + gemm_utils::sum_two_matrices(myM, block, myC, MB, + &C[m_from + (n_from + offset) * ldc], ldc); + } + }); + } + + free(ws_buffers); + free(c_buffers); + + return dnnl_success; +} + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/gemm/bf16/ref_gemm_bf16.hpp b/src/cpu/gemm/bf16/ref_gemm_bf16.hpp new file mode 100644 index 00000000000..a80cd5a08ed --- /dev/null +++ b/src/cpu/gemm/bf16/ref_gemm_bf16.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_GEMM_BF16_REF_GEMM_BF16_HPP +#define CPU_GEMM_BF16_REF_GEMM_BF16_HPP + +#include "oneapi/dnnl/dnnl_types.h" + +#include "common/c_types_map.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +dnnl_status_t ref_gemm_bf16bf16f32(const char *transa, const char *transb, + const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha, + const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B, + const dim_t *ldb, const float *beta, float *C, const dim_t *ldc); + +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_GEMM_BF16_REF_GEMM_BF16_HPP diff --git a/src/cpu/gemm/f32/gemm_utils_f32.hpp b/src/cpu/gemm/f32/gemm_utils_f32.hpp index 36bad397b32..91b1dedbaed 100644 --- a/src/cpu/gemm/f32/gemm_utils_f32.hpp +++ b/src/cpu/gemm/f32/gemm_utils_f32.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2020 Intel Corporation +* Copyright 2018-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,6 +45,15 @@ struct gemm_traits { static constexpr dim_t BK = isTransB ? 96 : 256; }; +template +struct gemm_traits { + static constexpr dim_t m = 32; + static constexpr dim_t n = 6; + static constexpr dim_t BM = 4032; + static constexpr dim_t BN = isTransA ? 96 : 48; + static constexpr dim_t BK = isTransB ? 
96 : 256; +}; + template using unroll_factor = gemm_traits; diff --git a/src/cpu/gemm/gemm.cpp b/src/cpu/gemm/gemm.cpp index 7c6806bcc10..1f51a23c054 100644 --- a/src/cpu/gemm/gemm.cpp +++ b/src/cpu/gemm/gemm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2022 Intel Corporation +* Copyright 2018-2023 Intel Corporation * Copyright 2022 IBM Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,7 @@ #include "cpu/gemm/gemm_msan_unpoison.hpp" #include "cpu/gemm/os_blas.hpp" +#include "cpu/gemm/bf16/ref_gemm_bf16.hpp" #include "cpu/gemm/f32/ref_gemm_f32.hpp" #include "cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp" #include "cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp" @@ -133,13 +134,14 @@ dnnl_status_t extended_sgemm(const char *transa, const char *transb, } #endif -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE if (mayiuse(sse41)) { float *dummy_ao = nullptr; float *dummy_bo = nullptr; - return gemm_driver(transa, transb, bias ? "C" : nullptr, M, N, K, alpha, - A, lda, dummy_ao, B, ldb, dummy_bo, beta, C, ldc, bias, + auto status = gemm_driver(transa, transb, bias ? 
"C" : nullptr, M, N, K, + alpha, A, lda, dummy_ao, B, ldb, dummy_bo, beta, C, ldc, bias, force_jit_nocopy_gemm); + if (status == status::success) return status; } #endif @@ -200,10 +202,12 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb, LDA, ao, B, LDB, bo, beta, C, LDC, co); if (status == dnnl_success) return status; -#if DNNL_X64 - if (mayiuse(sse41)) - return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, - B, LDB, bo, beta, C, LDC, co, false); +#if DNNL_X64 && !__BUILD_GEMM_NONE + if (mayiuse(sse41)) { + auto status = gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, + LDA, ao, B, LDB, bo, beta, C, LDC, co, false); + if (status == status::success) return status; + } #elif DNNL_PPC64 #ifdef __MMA__ int ATflag = (*transa == 'T') || (*transa == 't'); @@ -236,18 +240,23 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb, if (*M == 0 || *N == 0 || *K == 0) return dnnl_success; -#if DNNL_X64 - bool use_jit = mayiuse(avx512_core); +#if DNNL_X64 && !__BUILD_GEMM_NONE + bool use_jit = avx512_gemm_available(); bool use_s8u8 = true && utils::everyone_is(0, *ao, *bo) // so far a requirement && IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41)); - if (use_jit) - return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, - B, LDB, bo, beta, C, LDC, co, false); - else if (use_s8u8) - return simple_gemm_s8s8s32(transa, transb, offsetc, M, N, K, alpha, A, - LDA, ao, B, LDB, bo, beta, C, LDC, co); + if (use_jit) { + auto status = gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, + LDA, ao, B, LDB, bo, beta, C, LDC, co, false); + if (status == status::success) return status; + } + + if (use_s8u8) { + auto status = simple_gemm_s8s8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + if (status == status::success) return status; + } #endif #if DNNL_PPC64 @@ -284,16 +293,18 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb, ldb, C, ldc, alpha, 
beta, false); if (status != dnnl_success) return status; -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE char *dummyOffsetC = nullptr; bfloat16_t *dummy_ao = nullptr; bfloat16_t *dummy_bo = nullptr; float *dummy_co = nullptr; - if (mayiuse(avx512_core)) - return gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha, + if (avx512_gemm_available()) { + auto status = gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha, (const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B, ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false); + if (status == status::success) return status; + } #elif DNNL_PPC64 #if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) bool trA = *transa == 't' || *transa == 'T'; @@ -308,7 +319,8 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb, #endif #endif - return dnnl_unimplemented; + return ref_gemm_bf16bf16f32( + transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } } // namespace cpu diff --git a/src/cpu/gemm/gemm.hpp b/src/cpu/gemm/gemm.hpp index 35082757a11..c24b58981fe 100644 --- a/src/cpu/gemm/gemm.hpp +++ b/src/cpu/gemm/gemm.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2022 Intel Corporation +* Copyright 2018-2023 Intel Corporation * Copyright 2022 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,34 @@ #if DNNL_X64 #include "cpu/x64/cpu_isa_traits.hpp" + +// Kernels ISA section for configuring knobs. 
+#define __BUILD_GEMM_AMX (BUILD_GEMM_KERNELS_ALL) +#define __BUILD_GEMM_AVX512 (__BUILD_GEMM_AMX || BUILD_GEMM_AVX512) +#define __BUILD_GEMM_AVX2 (__BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2) +#define __BUILD_GEMM_SSE41 (__BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41) +#define __BUILD_GEMM_NONE (BUILD_GEMM_KERNELS_NONE) +// A level is enabled by itself or any higher level; parenthesized for use in #if. +#if __BUILD_GEMM_AVX512 +#define avx512_gemm_available() mayiuse(avx512_core) +#define avx512_amx_gemm_available() mayiuse(avx512_core_amx) +#define avx512_bf16_gemm_available() mayiuse(avx512_core_bf16) +#define avx512_vnni_gemm_available() mayiuse(avx512_core_vnni) +#define avx512_bf16_ymm_gemm_available() mayiuse(avx512_core_bf16_ymm) +#else +#define avx512_gemm_available() false +#define avx512_amx_gemm_available() false +#define avx512_bf16_gemm_available() false +#define avx512_vnni_gemm_available() false +#define avx512_bf16_ymm_gemm_available() false +#endif + +#else +#define __BUILD_GEMM_AMX 0 +#define __BUILD_GEMM_AVX512 0 +#define __BUILD_GEMM_AVX2 0 +#define __BUILD_GEMM_SSE41 0 +#define __BUILD_GEMM_NONE 0 #endif #if DNNL_AARCH64 @@ -78,9 +106,9 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb, #if !defined(USE_MKL_IGEMM) && defined(DNNL_X64) #define IGEMM_S8U8S32_ISA_STR \ JIT_IMPL_NAME_HELPER(IGEMM_S8U8S32_IMPL_STR ":", \ - mayiuse(avx512_core_vnni) \ + avx512_vnni_gemm_available() \ ? avx512_core_vnni \ - : (mayiuse(avx512_core) ? avx512_core : isa_undef), \ + : (avx512_gemm_available() ? 
avx512_core : isa_undef), \ "") #else #define IGEMM_S8U8S32_ISA_STR IGEMM_S8U8S32_IMPL_STR diff --git a/src/cpu/gemm/gemm_pack.cpp b/src/cpu/gemm/gemm_pack.cpp index 496a32e5b66..f67549cbf27 100644 --- a/src/cpu/gemm/gemm_pack.cpp +++ b/src/cpu/gemm/gemm_pack.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2020-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "cpu/platform.hpp" +#include "cpu/gemm/gemm.hpp" #include "cpu/gemm/gemm_pack.hpp" #if DNNL_X64 @@ -27,13 +28,13 @@ namespace impl { namespace cpu { bool pack_sgemm_supported() { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::pack_sgemm_supported(); #endif return false; } bool pack_gemm_bf16bf16f32_supported() { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::pack_gemm_bf16bf16f32_supported(); #endif return false; @@ -42,7 +43,7 @@ bool pack_gemm_bf16bf16f32_supported() { dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::sgemm_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -53,7 +54,7 @@ dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_bf16bf16f32_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -64,7 +65,7 @@ dnnl_status_t gemm_s8u8s32_pack_get_size(const char *identifier, const char *transa, const char *transb, const 
dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8u8s32_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -75,7 +76,7 @@ dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8s8s32_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -85,7 +86,7 @@ dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier, dnnl_status_t sgemm_pack(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const float *src, float *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::sgemm_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -96,7 +97,7 @@ dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const bfloat16_t *src, bfloat16_t *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_bf16bf16f32_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -106,7 +107,7 @@ dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa, dnnl_status_t gemm_s8u8s32_pack(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8u8s32_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -116,7 +117,7 @@ dnnl_status_t gemm_s8u8s32_pack(const char 
*identifier, const char *transa, dnnl_status_t gemm_s8s8s32_pack(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8s8s32_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -127,7 +128,7 @@ dnnl_status_t sgemm_compute(const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const float *A, const dim_t *lda, const float *B, const dim_t *ldb, const float *beta, float *C, const dim_t *ldc) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::sgemm_compute( transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc); #endif @@ -138,7 +139,7 @@ dnnl_status_t gemm_bf16bf16f32_compute(const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B, const dim_t *ldb, const float *beta, float *C, const dim_t *ldc) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_bf16bf16f32_compute( transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc); #endif @@ -149,7 +150,7 @@ dnnl_status_t gemm_s8u8s32_compute(const char *transa, const char *transb, const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, const int8_t *A, const dim_t *lda, const uint8_t *B, const dim_t *ldb, const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8u8s32_compute( transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); #endif @@ -160,7 +161,7 @@ dnnl_status_t gemm_s8s8s32_compute(const char *transa, const char *transb, const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, const int8_t *A, const dim_t *lda, const int8_t *B, const dim_t *ldb, const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { -#if DNNL_X64 +#if 
DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8s8s32_compute( transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); #endif diff --git a/src/cpu/rnn/rnn_utils.hpp b/src/cpu/rnn/rnn_utils.hpp index 8af361d63a0..21ca3f11fb4 100644 --- a/src/cpu/rnn/rnn_utils.hpp +++ b/src/cpu/rnn/rnn_utils.hpp @@ -873,7 +873,7 @@ bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, rnn.diff_weights_overwrite = rd.flags & rnn_flags::diff_weights_overwrite; -#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL +#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL || BUILD_GEMM_KERNELS_NONE // XXX: Threadpool runtime may use different number of threads at execute // and create stages. GEMM packed API is not aware of number of threads as // of now. In order to synchronize all layers, GEMM pack API should be diff --git a/src/cpu/x64/CMakeLists.txt b/src/cpu/x64/CMakeLists.txt index 5cacdb215d0..75b007892b6 100644 --- a/src/cpu/x64/CMakeLists.txt +++ b/src/cpu/x64/CMakeLists.txt @@ -56,6 +56,35 @@ else() PROPERTIES COMPILE_FLAGS "${OPT_LEVEL}") endif() +# Discard GeMM kernel files when requested +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(AVX512|AVX2|SSE41|NONE)$") + file(GLOB_RECURSE SOURCES_AMX ${CMAKE_CURRENT_SOURCE_DIR}/gemm/jit*amx*) + foreach(amx_file ${SOURCES_AMX}) + list(REMOVE_ITEM SOURCES "${amx_file}") + endforeach() +endif() + +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(AVX2|SSE41|NONE)$") + file(GLOB_RECURSE SOURCES_AVX512 ${CMAKE_CURRENT_SOURCE_DIR}/gemm/jit*avx512*) + foreach(avx512_file ${SOURCES_AVX512}) + list(REMOVE_ITEM SOURCES "${avx512_file}") + endforeach() +endif() + +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(SSE41|NONE)$") + file(GLOB_RECURSE SOURCES_AVX ${CMAKE_CURRENT_SOURCE_DIR}/gemm/jit*avx*) + foreach(avx_file ${SOURCES_AVX}) + list(REMOVE_ITEM SOURCES "${avx_file}") + endforeach() +endif() + +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(NONE)$") + file(GLOB_RECURSE SOURCES_SSE41 ${CMAKE_CURRENT_SOURCE_DIR}/gemm/*) + foreach(sse41_file 
${SOURCES_SSE41}) + list(REMOVE_ITEM SOURCES "${sse41_file}") + endforeach() +endif() + set(OBJ_LIB ${LIB_PACKAGE_NAME}_cpu_x64) add_library(${OBJ_LIB} OBJECT ${SOURCES}) set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS diff --git a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp index 0d2dd1bd0be..77537a4754e 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp @@ -27,6 +27,8 @@ namespace impl { namespace cpu { namespace x64 { +#define avx512_gemm_available() false + int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const { while (!(((idx / unroll_n_) < std::max(1, um / nelt_per_vecreg_)) || ((idx % unroll_n_) < un))) @@ -36,7 +38,7 @@ int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const { void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload( int um, int un, int k_idx, int n_idx) { - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((n_idx == 0) && (k_idx == 0) && (un == unroll_n_) && (um != 16)) { prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]); offb_ += 16; @@ -46,7 +48,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload( void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA( int um, int un, int k_idx, int n_idx, int m_idx) { - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((um == 16) || (un < unroll_n_)) { if ((k_idx + m_idx + n_idx) == 0) { prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]); @@ -63,7 +65,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA( void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA( int um, int un, int k_idx, int n_idx, int m_idx) { - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { if ((um < unroll_m_) && (m_idx == 0)) { if (((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 0) && (n_idx % 6 == 0)) || ((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 1) @@ -87,7 +89,7 @@ void 
jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA( void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload( int um, int un, int k_idx, int n_idx) { - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((um == unroll_m_) && (un == 2)) { if (k_idx % 3 == 0) { if (n_idx == 1) { @@ -111,7 +113,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload( void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA( int k_idx, int n_idx, int m_idx) { - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { if (((m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) * unroll_m_reg_) == 0) && (n_idx == 1)) { @@ -126,7 +128,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA( void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA( int um, int un, int k_idx, int n_idx, int m_idx) { - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((um == unroll_m_) && (un == unroll_n_)) { if (((k_idx == 0) && (n_idx % 2 == 1) && (m_idx == 0)) || ((k_idx == 1) && (n_idx == 2) && (m_idx == 0)) @@ -160,7 +162,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA( void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload( int um, int un, int k_idx, int n_idx) { - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { if (um == unroll_m_) { if (n_idx == std::min(1, un - 1)) { if (k_idx == unroll_k_ - 1) @@ -173,7 +175,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload( } void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop(int um) { - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { if (um < unroll_m_) { prefetchw(ptr[CO2_ + elt_size_ * 0]); prefetchw(ptr[CO2_ + elt_size_ * 8]); @@ -228,7 +230,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { mov(C_, ptr[rsp + get_size_of_abi_save_regs() + C_off]); mov(LDC_, ptr[rsp + get_size_of_abi_save_regs() + LDC_off]); - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { for (i = zmm_acc_idx_; i < unroll_m_reg_ * unroll_n_ + zmm_acc_idx_; i++) vpxorq(Xbyak::Zmm(i), Xbyak::Zmm(i), Xbyak::Zmm(i)); @@ 
-267,7 +269,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { add(AA_, A_); mov(CO1_, C_); - if ((unroll_x == unroll_m_) || (!mayiuse(avx512_core))) + if ((unroll_x == unroll_m_) || (!avx512_gemm_available())) lea(CO2_, ptr[C_ + LDC_ * 2]); add(C_, unroll_x * elt_size_); @@ -292,12 +294,12 @@ void jit_avx2_kernel_sgemm_kern::generate() { T_NEAR); } - if (!mayiuse(avx512_core)) + if (!avx512_gemm_available()) prefetcht2(ptr[AA_ - addr_off_ * elt_size_]); switch (unroll_x) { case 8: - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastf64x4, @@ -319,7 +321,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; case 4: - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastf32x4, @@ -340,7 +342,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; case 2: - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastsd, @@ -357,7 +359,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { &Xbyak::CodeGenerator::vmovsd); break; case 1: - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vbroadcastss, @@ -377,7 +379,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; default: - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { loop(unroll_x, unroll_y, &Xbyak::CodeGenerator::vmovups, @@ -400,7 +402,7 @@ void jit_avx2_kernel_sgemm_kern::generate() { break; } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { sub(AA_, -16 * elt_size_); } else { if ((unroll_y != unroll_n_) || (unroll_x <= 4)) { diff --git a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp index 766b07bc5c9..1a66bbb9057 100644 --- a/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp +++ b/src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp @@ -26,21 +26,22 @@ namespace 
dnnl { namespace impl { namespace cpu { namespace x64 { +#define avx512_gemm_available() false class jit_avx2_kernel_sgemm_kern : public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_kernel_sgemm_kern); const int elt_size_ = 4; const int elt_size_bin_ = 2; - int nelt_per_vecreg_ = mayiuse(avx512_core) ? 16 : 8; + int nelt_per_vecreg_ = avx512_gemm_available() ? 16 : 8; const int unroll_m_reg_ = 3; int unroll_m_ = unroll_m_reg_ * nelt_per_vecreg_; - const int unroll_n_ = mayiuse(avx512_core) ? 8 : 4; + const int unroll_n_ = avx512_gemm_available() ? 8 : 4; const int unroll_k_ = 4; const int unroll_k_bin_ = 2; - const int unroll_m_bin_ = mayiuse(avx512_core) ? 6 : 5; - const int second_fetch_ = mayiuse(avx512_core) ? 32 : 34; - unsigned int unroll_n_bin_ = mayiuse(avx512_core) ? 3 : 2; + const int unroll_m_bin_ = avx512_gemm_available() ? 6 : 5; + const int second_fetch_ = avx512_gemm_available() ? 32 : 34; + unsigned int unroll_n_bin_ = avx512_gemm_available() ? 3 : 2; bool beta_zero_; Xbyak::Reg64 M_ = rdi, N_ = rsi, K_ = rdx, A_ = r8, B_ = r9, C_ = r10, @@ -48,14 +49,14 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { Xbyak::Reg64 I_ = r12, J_ = r13, AA_ = rcx, KK_ = K_, BO_ = rbp, CO1_ = r14, CO2_ = r15; Xbyak::Reg64 AO_ = rbx, LL_ = rax; - int zmm_a_idx_ = 0, zmm_b_idx_ = mayiuse(avx512_core) ? 6 : 3, - zmm_acc_idx_ = mayiuse(avx512_core) ? 8 : 4; - int nb_zmm_a_ = mayiuse(avx512_core) ? unroll_m_reg_ * 2 : unroll_m_reg_, - nb_zmm_b_ = mayiuse(avx512_core) ? 2 : 1; - - int addr_off_ = mayiuse(avx512_core) ? 128 : 32; - int PREFETCHSIZEB_ = mayiuse(avx512_core) ? (-128 + 16 * 8) : 64; - int PREFETCHSIZEA_ = mayiuse(avx512_core) ? (-128 + 16 * 2) + int zmm_a_idx_ = 0, zmm_b_idx_ = avx512_gemm_available() ? 6 : 3, + zmm_acc_idx_ = avx512_gemm_available() ? 8 : 4; + int nb_zmm_a_ = avx512_gemm_available() ? unroll_m_reg_ * 2 : unroll_m_reg_, + nb_zmm_b_ = avx512_gemm_available() ? 2 : 1; + + int addr_off_ = avx512_gemm_available() ? 
128 : 32; + int PREFETCHSIZEB_ = avx512_gemm_available() ? (-128 + 16 * 8) : 64; + int PREFETCHSIZEA_ = avx512_gemm_available() ? (-128 + 16 * 2) : (PREFETCHSIZEB_ * 2 + 16); int off_ = 0, offb_ = 0; @@ -74,10 +75,10 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { void loadA_betweenFMAs(int um, int un, int k_idx, int n_idx, int m_idx, void (Xbyak::CodeGenerator::*aload)( const T_desta &, const T_srca &)) { - int next_zmm_a = mayiuse(avx512_core) + int next_zmm_a = avx512_gemm_available() ? unroll_m_reg_ : std::max(1, um / nelt_per_vecreg_); - if (!(mayiuse(avx512_core) || (um <= 8) || ((um == 16) && (un == 4)))) { + if (!(avx512_gemm_available() || (um <= 8) || ((um == 16) && (un == 4)))) { if (n_idx == un - 1) { (this->*aload)(T_reg(zmm_a_idx_ + m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) @@ -100,10 +101,10 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { const T_desta &, const T_srca &)) { int i; - int next_zmm_a = mayiuse(avx512_core) + int next_zmm_a = avx512_gemm_available() ? unroll_m_reg_ : std::max(1, um / nelt_per_vecreg_); - if (mayiuse(avx512_core) || (um <= 8) || ((um == 16) && (un == 4))) { + if (avx512_gemm_available() || (um <= 8) || ((um == 16) && (un == 4))) { for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { (this->*aload)(T_reg(zmm_a_idx_ + i + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) @@ -130,31 +131,31 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { Xbyak::Label K_loop_body_label; int i, j, p, b_idx; - int addb_off = ((!mayiuse(avx512_core)) && (nb_zmm_b_ == 2)) ? 1 : 0; + int addb_off = ((!avx512_gemm_available()) && (nb_zmm_b_ == 2)) ? 1 : 0; - int next_zmm_a = mayiuse(avx512_core) + int next_zmm_a = avx512_gemm_available() ? 
unroll_m_reg_ : std::max(1, um / nelt_per_vecreg_); off_ = 0, offb_ = 0; - if (mayiuse(avx512_core)) L_aligned(K_loop_body_label); + if (avx512_gemm_available()) L_aligned(K_loop_body_label); if (cfetch) prefetchC_beforeKloop(um); - if (!mayiuse(avx512_core)) L_aligned(K_loop_body_label); + if (!avx512_gemm_available()) L_aligned(K_loop_body_label); for (p = 0; p < unroll_k_; p++) { - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { if ((um == unroll_m_) && (p == unroll_k_ - 1)) { prefetcht2(ptr[AA_ - elt_size_ * 128]); } } for (j = 0; j < un; j++) { - b_idx = mayiuse(avx512_core) ? j % nb_zmm_b_ : p % nb_zmm_b_; + b_idx = avx512_gemm_available() ? j % nb_zmm_b_ : p % nb_zmm_b_; - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((um == unroll_m_) && (un == unroll_n_)) { if ((j == un - 1) && (p == unroll_k_ - 1)) sub(BO_, -un * unroll_k_ * elt_size_); @@ -182,7 +183,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { prefetchB_beforeBload(um, un, p, j); - if (!mayiuse(avx512_core) && (um == unroll_m_) + if (!avx512_gemm_available() && (um == unroll_m_) && (un == unroll_n_) && (j == un - 1) && (p == unroll_k_ - 1)) { (this->*bload)(T_reg(zmm_b_idx_ + b_idx), @@ -205,14 +206,14 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { if (cfetch) prefetchC_afterBload(um, un, p, j); - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { if ((um == unroll_m_) && (p == unroll_k_ - 1) && (j == std::min(un - 1, 3))) lea(AA_, ptr[AA_ + elt_size_ * unroll_n_]); } } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { for (j = un; j < unroll_n_; j++) { if (um < unroll_m_) { if (((p % (nb_zmm_a_ / unroll_m_reg_) == 0) @@ -228,7 +229,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { loadA_after(um, un, p, aload); } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { lea(AO_, ptr[AO_ + um * unroll_k_ * elt_size_]); lea(BO_, ptr[BO_ + un * unroll_k_ * elt_size_]); } else { @@ -261,7 +262,7 @@ 
class jit_avx2_kernel_sgemm_kern : public jit_generator { T_reg(zmm_b_idx_ + (j % nb_zmm_b_)), T_reg(i + zmm_a_idx_)); - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { if (i == 0) { if (j % 3 == 0) { prefetcht0(ptr[AO_ @@ -290,17 +291,17 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { - j)]); } - if (mayiuse(avx512_core) && (un < 2)) + if (avx512_gemm_available() && (un < 2)) prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_)]); - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { for (i = un; i < 8; i += 4) { prefetcht0(ptr[AO_ + elt_size_ * (PREFETCHSIZEA_ + off_)]); off_ += 16; } } - if (mayiuse(avx512_core) || (um <= nelt_per_vecreg_)) { + if (avx512_gemm_available() || (um <= nelt_per_vecreg_)) { for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { (this->*aload)(T_reg(zmm_a_idx_ + i), ptr[AO_ @@ -310,7 +311,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { lea(AO_, ptr[AO_ + um * elt_size_]); lea(BO_, ptr[BO_ + un * elt_size_]); } else { @@ -334,14 +335,14 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { Xbyak::Label end_K_loop_label, end_main_K_loop_label; Xbyak::Label K_loop_with_prefetch_label, K_loop_with_prefetch_rem_label; - Xbyak::Reg64 A_reg = (mayiuse(avx512_core)) ? AO_ + Xbyak::Reg64 A_reg = (avx512_gemm_available()) ? AO_ : ((um == unroll_m_) && (un == unroll_n_)) ? 
A_ : AO_; - if (mayiuse(avx512_core) || (unroll_m_ != um) || (unroll_n_ != un)) + if (avx512_gemm_available() || (unroll_m_ != um) || (unroll_n_ != un)) mov(AO_, A_); - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { nb_zmm_a_ = unroll_m_reg_; nb_zmm_b_ = 1; @@ -366,10 +367,10 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { zmm_acc_idx_ = zmm_b_idx_ + nb_zmm_b_; acc_idx = 0; - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { j = zmm_b_idx_; for (k = 0; k < nb_zmm_b_; k++) { - if (!mayiuse(avx512_core) && (un > 1)) { + if (!avx512_gemm_available() && (un > 1)) { acc_idx = next_acc(acc_idx, um, un); vxorps(T_reg(zmm_acc_idx_ + acc_idx), T_reg(zmm_acc_idx_ + acc_idx), @@ -383,14 +384,14 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } for (k = 0; k < nb_zmm_a_ / unroll_m_reg_; k++) { - if (mayiuse(avx512_core)) + if (avx512_gemm_available()) j = zmm_a_idx_ + k * unroll_m_reg_; else j = zmm_a_idx_ + k * std::max(1, um / nelt_per_vecreg_); for (i = nelt_per_vecreg_; i <= std::max(um, nelt_per_vecreg_); i += nelt_per_vecreg_) { - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { acc_idx = next_acc(acc_idx, um, un); vxorps(T_reg(zmm_acc_idx_ + acc_idx), T_reg(zmm_acc_idx_ + acc_idx), @@ -405,10 +406,10 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { j = zmm_b_idx_; for (k = 0; k < nb_zmm_b_; k++) { - if (!mayiuse(avx512_core) && (un > 1)) { + if (!avx512_gemm_available() && (un > 1)) { acc_idx = next_acc(acc_idx, um, un); vxorps(T_reg(zmm_acc_idx_ + acc_idx), T_reg(zmm_acc_idx_ + acc_idx), @@ -421,7 +422,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if (un > 1) { if ((um == unroll_m_) @@ -490,15 +491,15 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { vxorps(T_reg(i), T_reg(i), T_reg(i)); } - if (!((mayiuse(avx512_core) || 
(unroll_m_ != um) || (unroll_n_ != un)))) + if (!((avx512_gemm_available() || (unroll_m_ != um) || (unroll_n_ != un)))) mov(AO_, A_); mov(LL_, KK_); sar(LL_, unroll_k_bin_); jle(end_main_K_loop_label, T_NEAR); - if (mayiuse(avx512_core) - || (!mayiuse(avx512_core) && (un == unroll_n_) + if (avx512_gemm_available() + || (!avx512_gemm_available() && (un == unroll_n_) && (um == unroll_m_))) { sub(LL_, second_fetch_); jle(K_loop_with_prefetch_label, T_NEAR); @@ -507,26 +508,26 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { k_loop_body( 0, um, un, aload, bload); - if (mayiuse(avx512_core) - || (!mayiuse(avx512_core) && (un == unroll_n_) + if (avx512_gemm_available() + || (!avx512_gemm_available() && (un == unroll_n_) && (um == unroll_m_))) { L_aligned(K_loop_with_prefetch_label); } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { lea(CO2_, ptr[CO1_ + (nelt_per_vecreg_ - 1) * elt_size_]); add(LL_, un); jle(K_loop_with_prefetch_rem_label, T_NEAR); } - if (mayiuse(avx512_core) - || (!mayiuse(avx512_core) && (un == unroll_n_) + if (avx512_gemm_available() + || (!avx512_gemm_available() && (un == unroll_n_) && (um == unroll_m_))) { k_loop_body( 1, um, un, aload, bload); } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { L_aligned(K_loop_with_prefetch_rem_label); add(LL_, second_fetch_ - un); jle(end_main_K_loop_label, T_NEAR); @@ -537,7 +538,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { L_aligned(end_main_K_loop_label); - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((un == unroll_n_) && ((um == 16) || (um == 8))) { prefetcht2(ptr[AA_ - 16 * elt_size_]); } @@ -568,7 +569,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { if ((um < unroll_m_) && (um >= nelt_per_vecreg_)) offAA = 32 - (un / 2) * 16; - if (mayiuse(avx512_core)) + if (avx512_gemm_available()) lea(CO2_, ptr[CO1_ + LDC_]); else { if ((um == nelt_per_vecreg_) && (un == unroll_n_)) { @@ -578,7 +579,7 @@ class 
jit_avx2_kernel_sgemm_kern : public jit_generator { } for (j = 0; j < un; j++) { - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { reg_C = (j == 0) ? CO1_ : CO2_; if (j >= 2) { add(CO2_, LDC_); } } else @@ -588,7 +589,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { if (!is_beta_zero) { if (sepload) { for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { - if (!mayiuse(avx512_core) && (j % 2 == 1)) { + if (!avx512_gemm_available() && (j % 2 == 1)) { (this->*sload)(vec_reg_t(i), ptr[reg_C + LDC_ + elt_size_ * i @@ -614,7 +615,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if (j > 0) { prefetcht2(ptr[AA_ + elt_size_ * offAA]); offAA += 16; @@ -623,7 +624,7 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { // store accumulated value in C_ for (i = 0; i < std::max(um / nelt_per_vecreg_, 1); i++) { - if (!mayiuse(avx512_core) && (j % 2 == 1)) { + if (!avx512_gemm_available() && (j % 2 == 1)) { (this->*store)(ptr[reg_C + LDC_ + elt_size_ * i * nelt_per_vecreg_], vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_)); @@ -632,21 +633,21 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { ptr[reg_C + elt_size_ * i * nelt_per_vecreg_], vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_)); } - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { vpxorq(vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_), vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_), vec_reg_t(zmm_acc_idx_ + j + i * unroll_n_)); } } - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((um == unroll_m_) && (un == 1)) { prefetcht2(ptr[AA_ + elt_size_ * offAA]); offAA += 16; } } - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if (j == std::min(1, un - 1)) { if (j == 0) add(CO1_, LDC_); @@ -662,9 +663,9 @@ class jit_avx2_kernel_sgemm_kern : public jit_generator { } } - if (mayiuse(avx512_core)) lea(CO1_, ptr[CO2_ + LDC_]); + if (avx512_gemm_available()) lea(CO1_, 
ptr[CO2_ + LDC_]); - if (!mayiuse(avx512_core)) { + if (!avx512_gemm_available()) { if ((um >= nelt_per_vecreg_) && (un < unroll_n_)) { prefetcht2(ptr[AA_ + elt_size_ * offAA]); offAA += 16; diff --git a/src/cpu/x64/gemm/gemm_driver.cpp b/src/cpu/x64/gemm/gemm_driver.cpp index 0b66a9adba7..adbea9bb57e 100644 --- a/src/cpu/x64/gemm/gemm_driver.cpp +++ b/src/cpu/x64/gemm/gemm_driver.cpp @@ -29,6 +29,7 @@ #include "cpu/platform.hpp" #include "cpu/gemm/f32/gemm_utils_f32.hpp" +#include "cpu/gemm/gemm.hpp" #include "cpu/gemm/gemm_msan_unpoison.hpp" #include "cpu/x64/jit_generator.hpp" @@ -73,7 +74,7 @@ template int get_vector_length() { int v_bytes; - if (mayiuse(avx512_core)) + if (avx512_gemm_available()) v_bytes = cpu_isa_traits::vlen; else if (mayiuse(avx)) v_bytes = cpu_isa_traits::vlen; @@ -388,7 +389,7 @@ void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_f32 = data_traits::data_type == data_type::f32; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); + bool is_int8_amx = is_int8 && avx512_amx_gemm_available(); dim_t m_stk = col_offset_ws ? 1 : m; dim_t n_stk = row_offset_ws ? 
1 : n; @@ -532,8 +533,8 @@ static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); + bool is_int8_amx = is_int8 && avx512_amx_gemm_available(); + bool is_bf16_amx = is_bf16 && avx512_amx_gemm_available(); bool is_amx = is_int8_amx || is_bf16_amx; const std::shared_ptr &a_packed = arg->a_packed; @@ -810,8 +811,8 @@ static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); + bool is_int8_amx = is_int8 && avx512_amx_gemm_available(); + bool is_bf16_amx = is_bf16 && avx512_amx_gemm_available(); bool is_amx = is_int8_amx || is_bf16_amx; // B buffer needs to be large due to zero-padding. @@ -1047,7 +1048,7 @@ static inline bool nocopy_checker( if (arg->a_packed || arg->b_packed) return false; - else if (mayiuse(avx512_core)) + else if (avx512_gemm_available()) return nocopy_checker_avx512( nthr, transa, transb, m, n, k, lda, ldb, ldc); else @@ -1085,21 +1086,21 @@ static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, bool condition_2D_bsrc = false; if (isSgemm) { // If m is large and n is small then do 1D partitioning for AVX2. - if (!mayiuse(avx512_core) && n <= N2D_MAX && (m >= nthrs * M2D_MIN)) + if (!avx512_gemm_available() && n <= N2D_MAX && (m >= nthrs * M2D_MIN)) condition_2D_bsrc = false; else condition_2D_bsrc = ((n > nthrs * N2D_MAX) || (n <= nthrs * N2D_MAX / 2)) && (m >= 2 * M2D_MIN); } else { - int scale = mayiuse(avx512_core) ? 
nthrs : 20; + int scale = avx512_gemm_available() ? nthrs : 20; condition_2D_bsrc = (256 * m > scale * n) && (scale * m < 256 * n); } // TODO Check if we should use k-partitioning. int condition_1D_copya = false; - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { const dim_t thresh = isSgemm ? N2D_MAX / 4 : 68; if (m >= 1000 && (n >= nthrs * thresh)) { condition_2D_bsrc = false; @@ -1117,7 +1118,7 @@ static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, // TODO: the reasons seems to be in copy_sum_bx routines. At least, // after simple optimization of copy_sum_ax for avx512, similar // restriction on offset B became unnecessary. Revisit. - if (is_int8 && arg->ao != 0 && (arg->bo != 0 || mayiuse(avx512_core))) { + if (is_int8 && arg->ao != 0 && (arg->bo != 0 || avx512_gemm_available())) { condition_2D_bsrc = false; condition_1D_copya = true; } @@ -1160,7 +1161,7 @@ static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, } else if ((n <= 64 || n >= 256)) { while (((nthrs_n > 1) && (n / nthrs_n < arg->un) && (m / nthrs_m >= 2 * arg->um) - && mayiuse(avx512_core)) + && avx512_gemm_available()) || ((nthrs_n % 2 == 0) && (n / nthrs > N2D_MAX || n / nthrs_n <= N2D_MAX / 2) @@ -1288,7 +1289,7 @@ static inline void set_thread_opts_pack(int nthrs, choose_k_blocking(); // Choose m/n blocking. - auto min_mblk = mayiuse(avx512_core) ? (MBLK / 2) : arg->um; + auto min_mblk = avx512_gemm_available() ? (MBLK / 2) : arg->um; min_mblk = do_m_blocking ? min_mblk : m; min_mblk = do_m_blocking_only ? arg->um : min_mblk; auto min_nblk = do_n_blocking ? 
NBLK / 2 : n; @@ -1342,7 +1343,7 @@ static inline int set_thread_opts(int nthrs, int nthrs_spawn, dim_t BK = 0; auto m = arg->m, n = arg->n, k = arg->k; - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { cpu::gemm_utils::calc_nthr_nocopy_avx512_common(m, n, k, nthrs, &nthrs_m, &nthrs_n, &nthrs_k, &BM, &BN, &BK); } else { @@ -1416,8 +1417,8 @@ static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs, constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); + bool is_int8_amx = is_int8 && avx512_amx_gemm_available(); + bool is_bf16_amx = is_bf16 && avx512_amx_gemm_available(); bool is_amx = is_int8_amx || is_bf16_amx; const std::shared_ptr &a_packed = arg->a_packed; @@ -1573,7 +1574,7 @@ static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) { const bool is_f32 = data_traits::data_type == data_type::f32; - const bool is_avx512 = mayiuse(avx512_core); + const bool is_avx512 = avx512_gemm_available(); const bool is_avx = mayiuse(avx); const bool is_only_avx2 = mayiuse(avx2) && !is_avx512; @@ -1657,21 +1658,29 @@ static dnnl_status_t call_no_copy_sgemm( int nthrs, gemm_info_t *arg) { if (arg->packing == pack_type::none) { +#if __BUILD_GEMM_AVX2 auto transa_char = (arg->transa != do_trans) ? "N" : "T"; auto transb_char = (arg->transb != do_trans) ? 
"N" : "T"; +#endif - if (mayiuse(avx512_core)) + if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 return jit_avx512_common_gemm_f32(nthrs, transa_char, transb_char, &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c, &arg->ldc, (float *)arg->co); - else +#endif + } else { +#if __BUILD_GEMM_AVX2 return jit_avx_gemm_f32(nthrs, transa_char, transb_char, &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c, &arg->ldc, (float *)arg->co); +#endif + } } else return pack_no_copy(arg); + return status::unimplemented; } template @@ -1687,12 +1696,14 @@ static dnnl_status_t gemm_threading_driver( if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success; +#if __BUILD_GEMM_AVX512 if (!is_a_packed && !is_b_packed && jump_to_gemv_s8x8s32(arg)) return dnnl_success; if (!is_a_packed && !is_b_packed && jump_to_gemm_smalln_tn(arg) == dnnl_success) return dnnl_success; +#endif if (!is_a_packed && !is_b_packed && jump_to_gemv(arg) == dnnl_success) return dnnl_success; @@ -1926,7 +1937,8 @@ static dnnl_status_t gemm_threading_driver( == data_type::f32); assert(arg->packing == pack_type::none); - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 thread_arg[ithr].result = avx512_common_gemm_f32:: sgemm_nocopy_driver( arg->transa == no_trans ? "N" : "T", @@ -1935,7 +1947,9 @@ static dnnl_status_t gemm_threading_driver( arg->lda, (float *)b, arg->ldb, &beta_eff, (float *)c_eff, ldc_eff, nullptr); +#endif } else { +#if __BUILD_GEMM_AVX2 thread_arg[ithr].result = avx_gemm_f32::sgemm_nocopy_driver( arg->transa == no_trans ? 
"N" : "T", @@ -1944,6 +1958,7 @@ static dnnl_status_t gemm_threading_driver( arg->lda, (float *)b, arg->ldb, &beta_eff, (float *)c_eff, ldc_eff, nullptr); +#endif } break; } @@ -1999,7 +2014,7 @@ dnnl_status_t gemm_driver(const char *transA, const char *transB, // gemm_driver supports bfloat16 gemm for Intel AVX512 and // Intel AVX512 BF16. assert(IMPLICATION(data_traits::data_type == data_type::bf16, - mayiuse(avx512_core) && !force_nocopy)); + avx512_gemm_available() && !force_nocopy)); // gemm_driver supports 8-bit integer Intel AVX512, Intel AVX2, Intel AVX, // Intel SSE4.1 and Intel DL Boost. diff --git a/src/cpu/x64/gemm/gemm_info.cpp b/src/cpu/x64/gemm/gemm_info.cpp index 849edde47e1..8a3dacaf27e 100644 --- a/src/cpu/x64/gemm/gemm_info.cpp +++ b/src/cpu/x64/gemm/gemm_info.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2022 Intel Corporation +* Copyright 2019-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,8 @@ #include "common/bfloat16.hpp" #include "common/dnnl_traits.hpp" -#include "common/dnnl_sel_build.hpp" + +#include "cpu/gemm/gemm.hpp" #include "cpu/x64/cpu_isa_traits.hpp" #include "cpu/x64/jit_generator.hpp" @@ -75,7 +76,7 @@ void prepare_bo(int32_t &bo_gemm_info, const uint8_t *bo_orig) { template <> void prepare_bo(int32_t &bo_gemm_info, const int8_t *bo_orig) { int bo_s32 = bo_orig ? 
*bo_orig : 0; - if (!mayiuse(avx512_core_amx)) bo_s32 += 128; + if (!avx512_amx_gemm_available()) bo_s32 += 128; bo_gemm_info = bo_s32; } @@ -212,14 +213,14 @@ void gemm_info_t::jit_init(void) { { constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; const bool max_isa_supports_bf16_ymm - = mayiuse(avx512_core_bf16_ymm) && !mayiuse(avx512_core_amx); + = avx512_bf16_ymm_gemm_available() && !avx512_amx_gemm_available(); use_bf16_ymm = is_bf16 && max_isa_supports_bf16_ymm; } switch (data_traits::data_type) { case data_type::s8: - if (mayiuse(avx512_core_amx)) { + if (avx512_amx_gemm_available()) { this->um = 32; this->un = 32; this->uk = 64; @@ -230,13 +231,13 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 0; this->blocking_small_k = 0; this->bn_small_k = 0; - } else if (mayiuse(avx512_core)) { + } else if (avx512_gemm_available()) { this->um = 48; this->un = 8; this->uk = 1; this->bm = 9984; this->bn = 384; - this->bk = mayiuse(avx512_core_vnni) ? 1536 : 768; + this->bk = avx512_vnni_gemm_available() ? 1536 : 768; this->bk_traditional = 384; this->blocking_small_k = 48; @@ -278,7 +279,7 @@ void gemm_info_t::jit_init(void) { break; case data_type::bf16: - if (mayiuse(avx512_core_amx)) { + if (avx512_amx_gemm_available()) { this->um = 32; this->un = 32; this->uk = 32; @@ -289,7 +290,7 @@ void gemm_info_t::jit_init(void) { this->bk_traditional = 0; this->blocking_small_k = 0; this->bn_small_k = 0; - } else if (mayiuse(avx512_core)) { + } else if (avx512_gemm_available()) { this->um = use_bf16_ymm ? 
24 : 48; this->un = 8; this->uk = 1; @@ -304,7 +305,7 @@ void gemm_info_t::jit_init(void) { break; case data_type::f32: - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { this->um = 48; this->un = 8; this->uk = 1; @@ -359,12 +360,15 @@ void gemm_info_t::jit_init(void) { static std::once_flag initialized; static std::atomic st(dnnl_success); std::call_once(initialized, [&, um] { + MAYBE_UNUSED(um); +#if __BUILD_GEMM_AVX512 const bool b_is_s8 = data_traits::data_type == data_type::s8; +#endif constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; - bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); - bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); + bool is_int8_amx = is_int8 && avx512_amx_gemm_available(); + bool is_bf16_amx = is_bf16 && avx512_amx_gemm_available(); bool is_amx = is_int8_amx || is_bf16_amx; static maybe_unique_ptr copy_a[2][2] = {{nullptr}}; @@ -373,7 +377,7 @@ void gemm_info_t::jit_init(void) { switch (data_traits::data_type) { case data_type::s8: if (mayiuse(amx_int8)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_amx_int8) { +#if __BUILD_GEMM_AMX for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( @@ -382,124 +386,124 @@ void gemm_info_t::jit_init(void) { copy_b[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( false, isTrans, sizeof(b_t))); - } - } - } else if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx512_core) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx512_core_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_u8_copy_bn_kern(b_is_s8)); - copy_b[do_trans][no_sum].reset( - new jit_avx512_core_u8_copy_bt_kern(b_is_s8)); - - copy_a[no_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_an_kern()); - 
copy_a[do_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_bn_kern(b_is_s8)); - copy_b[do_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8)); } +#endif + } else if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 + copy_a[no_trans][no_sum].reset( + new jit_avx512_core_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_u8_copy_bn_kern(b_is_s8)); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_u8_copy_bt_kern(b_is_s8)); + + copy_a[no_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_bn_kern(b_is_s8)); + copy_b[do_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8)); +#endif } else if (mayiuse(avx2_vnni)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx2_vnni) { - copy_a[no_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx2_vnni_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx2_vnni_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx2_vnni_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new 
jit_avx2_vnni_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_bt_kern()); +#endif } else if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx2) { - copy_a[no_trans][no_sum].reset( - new jit_avx2_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx2_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx2_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx2_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx2_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx2_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx2_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx2_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_bt_kern()); +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx) { - copy_a[no_trans][no_sum].reset( - new jit_avx_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new 
jit_avx_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_avx_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_avx_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_avx_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_avx_u8_copy_sum_bt_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_sse41) { - copy_a[no_trans][no_sum].reset( - new jit_sse41_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_sse41_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_sse41_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_sse41_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_SSE41 + copy_a[no_trans][no_sum].reset( + new jit_sse41_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_sse41_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_sse41_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_sse41_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_at_kern()); + + 
copy_b[no_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_bt_kern()); +#endif } break; case data_type::bf16: if (mayiuse(amx_bf16)) { - DNNL_CSCOPE(jit_init_copy_kern_bf16_amx_bf16) { +#if __BUILD_GEMM_AMX for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( @@ -508,213 +512,215 @@ void gemm_info_t::jit_init(void) { copy_b[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( false, isTrans, sizeof(b_t))); - } - } - } else if (mayiuse(avx512_core) && !use_bf16_ymm) { - DNNL_CSCOPE(jit_init_copy_kern_bf16_avx512_core_not_use_bf16_ymm) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_bt_kern()); - } - } else if (mayiuse(avx512_core) && use_bf16_ymm) { - DNNL_CSCOPE(jit_init_copy_kern_bf16_avx512_core_use_bf16_ymm) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_bt_kern()); } +#endif + } else if (avx512_gemm_available() && !use_bf16_ymm) { +#if __BUILD_GEMM_AVX512 + copy_a[no_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_bt_kern()); +#endif + } else if (avx512_gemm_available() && use_bf16_ymm) { +#if __BUILD_GEMM_AVX512 + 
copy_a[no_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_bt_kern()); +#endif } break; case data_type::f32: - if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_avx512_core) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx512_core_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx512_core_f32_copy_bt_kern()); - } + if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 + copy_a[no_trans][no_sum].reset( + new jit_avx512_core_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_f32_copy_bt_kern()); +#endif } else if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_avx2) { - copy_a[no_trans][no_sum].reset( - new jit_avx2_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx2_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx2_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx2_f32_copy_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx2_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx2_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx2_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx2_f32_copy_bt_kern()); +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_avx) { - copy_a[no_trans][no_sum].reset( - new jit_avx_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new 
jit_avx_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx_f32_copy_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx_f32_copy_bt_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_sse41) { - copy_a[no_trans][no_sum].reset( - new jit_sse41_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_sse41_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_sse41_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_sse41_f32_copy_bt_kern()); - } +#if __BUILD_GEMM_SSE41 /* sse41 copy kernels use the same guard as the other sse41 branches */ + copy_a[no_trans][no_sum].reset( + new jit_sse41_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_sse41_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_sse41_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_sse41_f32_copy_bt_kern()); +#endif } break; default: break; } +#if __BUILD_GEMM_AMX constexpr bool is_a_s8 = data_traits::data_type == data_type::s8; constexpr bool is_b_s8 = data_traits::data_type == data_type::s8; constexpr bool is_c_s32 = data_traits::data_type == data_type::s32; +#endif static maybe_unique_ptr kernel[2][2][2][2] = {{{{nullptr}}}}; switch (data_traits::data_type) { case data_type::s8: - if (mayiuse(avx512_core_amx)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx512_core_bf16_amx_int8) { - for (int isBeta0 : {no_beta0, do_beta0}) { - kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx512_core_amx_gemm_kern( - is_a_s8, is_b_s8, is_c_s32, isBeta0)); - } - } - } else if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx512_core) { - for (int isBeta0 : {no_beta0, do_beta0}) - for (int doColSum : {no_sum, do_sum}) - for
(int doRowSum : {no_sum, do_sum}) { - kernel[isBeta0][do_alpha1][doColSum][doRowSum].reset( - new jit_avx512_core_gemm_s8u8s32_kern( - isBeta0, doColSum, doRowSum)); - } + if (avx512_amx_gemm_available()) { +#if __BUILD_GEMM_AMX + for (int isBeta0 : {no_beta0, do_beta0}) { + kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx512_core_amx_gemm_kern( + is_a_s8, is_b_s8, is_c_s32, isBeta0)); } +#endif + } else if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 + for (int isBeta0 : {no_beta0, do_beta0}) + for (int doColSum : {no_sum, do_sum}) + for (int doRowSum : {no_sum, do_sum}) { + kernel[isBeta0][do_alpha1][doColSum][doRowSum].reset( + new jit_avx512_core_gemm_s8u8s32_kern( + isBeta0, doColSum, doRowSum)); + } +#endif } else if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx2) { - for (int isBeta0 : {no_beta0, do_beta0}) - for (int doColSum : {no_sum, do_sum}) - for (int doRowSum : {no_sum, do_sum}) { - kernel[isBeta0][do_alpha1][doColSum][doRowSum] - .reset(new jit_avx2_gemm_s8u8s32_kern( - isBeta0, doColSum, doRowSum, - um)); - } - } +#if __BUILD_GEMM_AVX2 + for (int isBeta0 : {no_beta0, do_beta0}) + for (int doColSum : {no_sum, do_sum}) + for (int doRowSum : {no_sum, do_sum}) { + kernel[isBeta0][do_alpha1][doColSum][doRowSum] + .reset(new jit_avx2_gemm_s8u8s32_kern( + isBeta0, doColSum, doRowSum, + um)); + } +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_avx_kernel_c_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_avx_kernel_r_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_avx_kernel_b_gemm_s8u8s32_kern()); - - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_b0_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( - new 
jit_avx_kernel_b0_c_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_avx_kernel_b0_r_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_avx_kernel_b0_b_gemm_s8u8s32_kern()); - } +#if __BUILD_GEMM_AVX2 + kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_avx_kernel_c_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_avx_kernel_r_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_avx_kernel_b_gemm_s8u8s32_kern()); + + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_b0_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_avx_kernel_b0_c_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_avx_kernel_b0_r_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_avx_kernel_b0_b_gemm_s8u8s32_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_sse41) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_sse41_kernel_c_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_sse41_kernel_r_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_sse41_kernel_b_gemm_s8u8s32_kern()); - - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_b0_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_sse41_kernel_b0_b_gemm_s8u8s32_kern()); - } +#if __BUILD_GEMM_SSE41 + 
kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_sse41_kernel_c_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_sse41_kernel_r_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_sse41_kernel_b_gemm_s8u8s32_kern()); + + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_b0_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_sse41_kernel_b0_b_gemm_s8u8s32_kern()); +#endif } break; case data_type::bf16: - if (mayiuse(avx512_core_amx)) { - DNNL_CSCOPE(jit_init_gemm_kern_bf16_avx512_core_bf16_amx_bf16) { - for (int isBeta0 : {no_beta0, do_beta0}) { - kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx512_core_amx_gemm_kern( - is_a_s8, is_b_s8, is_c_s32, isBeta0)); - } - } - } else if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemm_kern_bf16_avx512_core) { - for (int isBeta0 : {no_beta0, do_beta0}) - for (int isAlpha1 : {no_alpha1, do_alpha1}) { - kernel[isBeta0][isAlpha1][no_sum][no_sum].reset( - new jit_avx512_core_gemm_bf16bf16f32_kern( - isBeta0, isAlpha1, !use_bf16_ymm)); - } + if (avx512_amx_gemm_available()) { +#if __BUILD_GEMM_AMX + for (int isBeta0 : {no_beta0, do_beta0}) { + kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx512_core_amx_gemm_kern( + is_a_s8, is_b_s8, is_c_s32, isBeta0)); } +#endif + } else if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 + for (int isBeta0 : {no_beta0, do_beta0}) + for (int isAlpha1 : {no_alpha1, do_alpha1}) { + kernel[isBeta0][isAlpha1][no_sum][no_sum].reset( + new jit_avx512_core_gemm_bf16bf16f32_kern( + isBeta0, isAlpha1, !use_bf16_ymm)); + } +#endif } 
break; case data_type::f32: if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_gemm_kern_f32_avx2) { - for (int isBeta0 : {no_beta0, do_beta0}) { - kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx2_kernel_sgemm_kern(isBeta0)); - } +#if __BUILD_GEMM_AVX2 + for (int isBeta0 : {no_beta0, do_beta0}) { + kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx2_kernel_sgemm_kern(isBeta0)); } +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_gemm_kern_f32_avx) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_sgemm_kern()); - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_b0_sgemm_kern()); - } +#if __BUILD_GEMM_AVX2 + kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_sgemm_kern()); + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_b0_sgemm_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_gemm_kern_f32_sse41) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_sgemm_kern()); - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_b0_sgemm_kern()); - } +#if __BUILD_GEMM_SSE41 + kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_sgemm_kern()); + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_b0_sgemm_kern()); +#endif } break; @@ -727,43 +733,43 @@ void gemm_info_t::jit_init(void) { static maybe_unique_ptr gemv_u8s8s32_kernel = nullptr; switch (data_traits::data_type) { case data_type::s8: - if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemv_kern_s8_avx512_core) { - gemv_s8s8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8)); - gemv_s8u8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8u8)); - gemv_u8s8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::u8s8)); - } + if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 + gemv_s8s8s32_kernel.reset( + new 
jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8)); + gemv_s8u8s32_kernel.reset( + new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8u8)); + gemv_u8s8s32_kernel.reset( + new jit_avx512_core_gemv_s8x8s32_kern(ver_t::u8s8)); +#endif } break; case data_type::bf16: - if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemv_kern_bf16_avx512_core) { - for (int isTrans : {no_trans, do_trans}) - gemv_kernel[isTrans].reset( - new jit_avx512_core_gemv_bf16bf16f32_kern( - isTrans)); - } + if (avx512_gemm_available()) { +#if __BUILD_GEMM_AVX512 + for (int isTrans : {no_trans, do_trans}) + gemv_kernel[isTrans].reset( + new jit_avx512_core_gemv_bf16bf16f32_kern( + isTrans)); +#endif } break; case data_type::f32: if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_gemv_kern_f32_avx) { - gemv_kernel[no_trans].reset( - new jit_sse41_gemv_n_f32_kern()); - gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern()); - } +#if __BUILD_GEMM_AVX2 + gemv_kernel[no_trans].reset( + new jit_sse41_gemv_n_f32_kern()); + gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_gemv_kern_f32_sse41) { - gemv_kernel[no_trans].reset( - new jit_sse41_gemv_n_f32_kern()); - gemv_kernel[do_trans].reset( - new jit_sse41_gemv_t_f32_kern()); - } +#if __BUILD_GEMM_SSE41 + gemv_kernel[no_trans].reset( + new jit_sse41_gemv_n_f32_kern()); + gemv_kernel[do_trans].reset( + new jit_sse41_gemv_t_f32_kern()); +#endif } break; default: assert(!"unsupported data type!"); @@ -919,7 +925,7 @@ bool gemm_info_t::hasKernels(void) { if (!this->copyA || !this->copyB) return false; - if (mayiuse(avx512_core)) + if (avx512_gemm_available()) if (!this->gemv_s8u8s32_kernel || !this->gemv_u8s8s32_kernel || !this->gemv_s8s8s32_kernel) return false; @@ -927,7 +933,7 @@ bool gemm_info_t::hasKernels(void) { break; case data_type::bf16: - if (mayiuse(avx512_core)) { + if (avx512_gemm_available()) { for (int isBeta0 : {no_beta0, do_beta0}) if (!this->kernel[isBeta0][no_sum][no_sum]) 
return false; diff --git a/src/cpu/x64/gemm/gemm_pack.cpp b/src/cpu/x64/gemm/gemm_pack.cpp index 4e1cc59731d..53675e25f8b 100644 --- a/src/cpu/x64/gemm/gemm_pack.cpp +++ b/src/cpu/x64/gemm/gemm_pack.cpp @@ -42,7 +42,7 @@ bool pack_sgemm_supported() { } bool pack_gemm_bf16bf16f32_supported() { - return mayiuse(avx512_core); + return avx512_gemm_available(); } #if USE_MKL_PACKED_GEMM @@ -84,7 +84,7 @@ static inline bool use_reference_igemm(void) { if (is_s8u8) return !mayiuse(sse41); else - return !mayiuse(avx512_core); + return !avx512_gemm_available(); } #else diff --git a/src/cpu/x64/gemm/gemv_driver.cpp b/src/cpu/x64/gemm/gemv_driver.cpp index b83c40ce387..d4e2e3d82ce 100644 --- a/src/cpu/x64/gemm/gemv_driver.cpp +++ b/src/cpu/x64/gemm/gemv_driver.cpp @@ -266,7 +266,7 @@ static inline int thread_checker( } #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL if (is_f32) { - static const bool is_avx2 = mayiuse(avx2) && !mayiuse(avx512_core); + static const bool is_avx2 = mayiuse(avx2) && !avx512_gemm_available(); static auto l2_cache_per_thread = platform::get_per_core_cache_size(2); static int n_cores_per_socket diff --git a/tests/gtests/in/gemm_in.h b/tests/gtests/in/gemm_in.h index bcabd886ce9..1ec1bf9b64a 100644 --- a/tests/gtests/in/gemm_in.h +++ b/tests/gtests/in/gemm_in.h @@ -128,6 +128,7 @@ CPU_INST_TEST_CASE(TestGEMM_stkmem, test_params {'n', 'n', 2, 16, 256, 1.0f, 0.0f, 256, 16, 16}); #if defined(FP32) || defined(BF16BF16F32) +#if !BUILD_GEMM_KERNELS_NONE INST_TEST_CASE(TestGEMM_packed, test_params {'t', 'n', 3, 2, 1, 1.0, 0.0, 2, 5, 8, {}, {false, true}, true, dnnl_invalid_arguments}, @@ -198,6 +199,7 @@ INST_TEST_CASE(TestGEMM_packed, make_test_params_pack({false, true}, 't', 'n', 200, 300, 8000, 1.0f, 3.0f, 200, 300, 300)); #endif +#endif #elif defined(BF16BF16BF16) @@ -254,6 +256,7 @@ constexpr test_igemm_params fix_no_offsets = {'F', false, false, false}; constexpr test_igemm_params col_no_offsets = {'C', false, false, false}; constexpr test_igemm_params 
row_no_offsets = {'R', false, false, false}; +#if !BUILD_GEMM_KERNELS_NONE INST_TEST_CASE(TestGEMM_expected_failures, test_params {'t', 'n', 3, 2, 1, 1.0, 0.0, 2, 5, 8, {}, {}, true, dnnl_invalid_arguments}, @@ -290,6 +293,7 @@ INST_TEST_CASE(TestGEMM_expected_failures, true, dnnl_invalid_arguments}, test_params {'n', 'd', 3, 2, 1, 1.0, 0.0, 3, 3, 3, {}, {false, true}, true, dnnl_invalid_arguments}); +#endif CPU_INST_TEST_CASE(TestGEMM_stkmem, test_params {'n', 'n', 10, 4000, 2, 1.0, 0.0, 2, 4000, 4000, @@ -733,6 +737,7 @@ CPU_INST_TEST_CASE(TestGEMV_kblocking, test_params {'t', 'n', 1, 550, 7000, 1.0f, 1.0f, 7000, 550, 550, fix_no_offsets}); +#if !BUILD_GEMM_KERNELS_NONE CPU_INST_TEST_CASE(TestGEMM_packed, make_test_params_pack({false, true}, 'N', 'n', 30, 20, 10, 1.0f, 1.0f, 60, 50, 80, fix_use_oc), @@ -830,6 +835,7 @@ CPU_INST_TEST_CASE(TestGEMM_packed, make_test_params_pack({false, true}, 'n', 'T', 1, 200, 200, 1.0f, 0.0f, 200, 200, 200, fix_no_offsets)); #endif +#endif CPU_INST_TEST_CASE(TestGEMM_heavy, test_params {'n', 'n', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, @@ -841,6 +847,7 @@ CPU_INST_TEST_CASE(TestGEMM_heavy, test_params {'t', 't', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, fix_use_oc}); +#if !BUILD_GEMM_KERNELS_NONE CPU_INST_TEST_CASE(TestGEMM_packed_heavy, make_test_params_pack({false, true}, 'n', 'n', 3000, 3000, 3000, 1.0f, 0.0f, 3000, 3000, 3000, fix_use_oc), @@ -874,5 +881,5 @@ CPU_INST_TEST_CASE(TestGEMM_packed_heavy, 3.0f, 8000, 8000, 200, row_use_oc), make_test_params_pack({false, true}, 't', 'n', 200, 300, 8000, 1.0f, 0.0f, 200, 300, 300, col_use_oc)); - +#endif #endif