From 654635d18a9725f4964b6e62211ef806ddd3a94d Mon Sep 17 00:00:00 2001 From: chenhu-wang Date: Fri, 24 Jan 2025 17:33:49 +0800 Subject: [PATCH] apply Alexandra comments --- cmake/features.cmake | 2 +- .../src/emitters/snippets/brgemm_base.cpp | 156 +++------------- .../src/emitters/snippets/brgemm_base.hpp | 92 ++------- .../snippets/x64/kernel_executors/brgemm.cpp | 4 +- .../snippets/x64/kernel_executors/brgemm.hpp | 6 +- .../x64/kernel_executors/brgemm_amx.cpp | 4 +- .../x64/kernel_executors/brgemm_amx.hpp | 5 +- .../x64/kernel_executors/brgemm_base.cpp | 175 ++++++++++++++++++ .../x64/kernel_executors/brgemm_base.hpp | 117 ++++++++++++ .../tpp/aarch64/jit_brgemm_emitter.cpp | 3 +- .../tpp/aarch64/jit_brgemm_emitter.hpp | 4 +- .../kernel_executors/brgemm.cpp | 35 ++-- .../kernel_executors/brgemm.hpp | 39 ++-- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 4 - 14 files changed, 382 insertions(+), 264 deletions(-) create mode 100644 src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp create mode 100644 src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp rename src/plugins/intel_cpu/src/emitters/tpp/{aarch64 => common}/kernel_executors/brgemm.cpp (83%) rename src/plugins/intel_cpu/src/emitters/tpp/{aarch64 => common}/kernel_executors/brgemm.hpp (78%) diff --git a/cmake/features.cmake b/cmake/features.cmake index f1ec371fc58f49..dc8ebeeb9371ad 100644 --- a/cmake/features.cmake +++ b/cmake/features.cmake @@ -52,7 +52,7 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF) ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF) -ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU" OFF) +ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND (X86_64 OR AARCH64)" OFF) ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF) diff --git a/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp index 9bc99888463f24..c694f732978c1b 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp @@ -6,8 +6,6 @@ #include "common/utils.hpp" #include "dnnl_extension_utils.h" -#include "transformations/snippets/x64/op/brgemm_cpu.hpp" -#include "transformations/snippets/x64/op/brgemm_utils.hpp" #define DIM_CAST(X) static_cast(X) #define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) @@ -31,16 +29,32 @@ bool BrgemmBaseKernelConfig::is_empty() const { } bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const { - return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) && - (EQ(get_static_params()) || *get_static_params() == *(rhs.get_static_params())); + return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC); } -void BrgemmBaseKernelConfig::update(dnnl_dim_t M, - dnnl_dim_t N, - dnnl_dim_t K, - dnnl_dim_t LDA, - dnnl_dim_t LDB, - dnnl_dim_t LDC, +void BrgemmBaseKernelConfig::update(int64_t M, int64_t N, int64_t K, float beta) { + // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) + // To process this case, we have to make this Config as empty (nullify runtime parameters) + if (utils::one_of(0, M, N, K)) { + m_M = 0; + m_N = 0; + m_K = 0; + m_beta = 0; + } else { + m_M = M; + m_N = N; + m_K = K; + m_beta = beta; + } + // m_hash = compute_hash(); +} + +void BrgemmBaseKernelConfig::update(int64_t M, + int64_t N, + int64_t K, + int64_t LDA, + int64_t LDB, + int64_t LDC, float beta) { // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) // To process this case, we have to make this Config as empty (nullify runtime parameters) @@ -65,7 +79,7 @@ void BrgemmBaseKernelConfig::update(dnnl_dim_t M, } size_t BrgemmBaseKernelConfig::compute_hash() const { - size_t seed = get_static_params()->hash(); + size_t seed = 0; HASH(m_M); HASH(m_N); HASH(m_K); @@ -76,48 +90,12 @@ size_t BrgemmBaseKernelConfig::compute_hash() const { return seed; } -BrgemmBaseKernelConfig::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype, - const element::Type& in1_dtype, - cpu_isa_t primitive_isa, - size_t hash_seed) - : dt_in0(DTYPE_CAST(in0_dtype)), - dt_in1(DTYPE_CAST(in1_dtype)), - isa(primitive_isa), - m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {} - -bool BrgemmBaseKernelConfig::StaticBaseParams::operator==(const StaticBaseParams& rhs) const { - return EQ(hash()) && EQ(dt_in0) && EQ(dt_in1) && EQ(isa); -} - -size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed, - dnnl_data_type_t dt_in0, - dnnl_data_type_t dt_in1, - cpu_isa_t isa) { - size_t seed = hash_seed; - HASH(dt_in0); - HASH(dt_in1); - HASH(isa); - return seed; -} - #ifdef SNIPPETS_DEBUG_CAPS -std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const { - std::stringstream ss; - PRINT(dt_in0); - PRINT(dt_in1); - PRINT(isa); - return ss.str(); -} - std::string BrgemmBaseKernelConfig::to_string() const { std::stringstream ss; - ss << get_static_params()->to_string() << "\n"; PRINT(m_M); PRINT(m_N); PRINT(m_K); - PRINT(m_LDA); - PRINT(m_LDB); - PRINT(m_LDC); PRINT(m_beta); return ss.str(); } @@ -248,89 +226,7 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres beta = get_beta(loop_manager, static_cast(loop_ids.back()), current_expanded_loop_info); } -#ifndef OPENVINO_ARCH_X86_64 - config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), 0, 0, 0, beta); - return; -#endif - - const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0))); - const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0))); - auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1))); - - const auto& brgemm_node = as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); - // In case of data repacking LDB is chosen in accordance with repacking buffer size - if (with_repacking(brgemm_node->get_type())) - LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1))); - - config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta); -} - -void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr& kernel, - dnnl_data_type_t dt0, - dnnl_data_type_t dt1, - cpu_isa_t isa, - dnnl_dim_t M, - dnnl_dim_t N, - dnnl_dim_t K, - dnnl_dim_t LDA, - dnnl_dim_t LDB, - dnnl_dim_t LDC, - float beta, - bool with_amx, - char* palette) { - cpu::x64::brgemm_desc_t desc; - OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc, - isa, - cpu::x64::brgemm_strd, - dt0, - dt1, - false, - false, - cpu::x64::brgemm_row_major, - 1.f, - beta, - LDA, - LDB, - LDC, - M, - N, - K, - nullptr) == dnnl_success, - "Cannot initialize brgemm descriptor due to invalid params"); - - if (with_amx) { - OV_CPU_JIT_EMITTER_ASSERT(palette && brgemm_init_tiles(desc, palette) == dnnl_success, - "Cannot initialize brgemm tiles due to invalid params"); - } - - cpu::x64::brgemm_kernel_t* kernel_ = nullptr; - OV_CPU_JIT_EMITTER_ASSERT(brgemm_kernel_create(&kernel_, desc) == dnnl_success, - "Cannot create brgemm kernel due to invalid params"); - kernel = std::unique_ptr(kernel_); -} - -void BrgemmBaseKernelExecutor::execute_brgemm_kernel( - const std::shared_ptr& kernel, - const void* src, - const void* wei, - void* dst, - void* scratch, - bool with_comp) { - cpu::x64::brgemm_kernel_params_t brgemm_p; - brgemm_p.batch = nullptr; // default value - brgemm_p.ptr_A = src; - brgemm_p.ptr_B = wei; - brgemm_p.ptr_C = dst; - brgemm_p.ptr_D = dst; - brgemm_p.ptr_buf = scratch; - brgemm_p.ptr_bias = nullptr; - brgemm_p.do_post_ops = with_comp; - brgemm_p.do_apply_comp = with_comp; - brgemm_p.skip_accm = 0; - brgemm_p.BS = 1; // default value - OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel"); - (*kernel)(&brgemm_p); + config.update(static_cast(M), static_cast(N), static_cast(K), beta); } #undef DIM_CAST diff --git a/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp index a0a55e58df75b7..34ebcfc340dd2a 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp @@ -23,49 +23,35 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf BrgemmBaseKernelConfig() = default; bool is_completed() const override; - size_t hash() const override { - return m_hash; - } - bool is_empty() const; - void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta); + + void update(int64_t M, int64_t N, int64_t K, float beta); + void update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta); bool operator==(const BrgemmBaseKernelConfig& rhs) const; bool operator!=(const BrgemmBaseKernelConfig& rhs) const { return !(*this == rhs); } - dnnl_data_type_t get_dt_in0() const { - return get_static_params()->dt_in0; - } - dnnl_data_type_t get_dt_in1() const { - return get_static_params()->dt_in1; - } - - dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { - return get_static_params()->isa; - } - float get_beta() const { - return m_beta; - } - - dnnl_dim_t get_M() const { + int64_t get_M() const { return m_M; } - dnnl_dim_t get_N() const { + int64_t get_N() const { return m_N; } - dnnl_dim_t get_K() const { + int64_t get_K() const { return m_K; } - - dnnl_dim_t get_LDA() const { + float get_beta() const { + return m_beta; + } + int64_t get_LDA() const { return m_LDA; } - dnnl_dim_t get_LDB() const { + int64_t get_LDB() const { return m_LDB; } - dnnl_dim_t get_LDC() const { + int64_t get_LDC() const { return m_LDC; } @@ -74,40 +60,9 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf #endif protected: - struct StaticBaseParams { - StaticBaseParams(const element::Type& in0_dtype, - const element::Type& in1_dtype, - dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, - size_t hash_seed); - virtual ~StaticBaseParams() = default; - - const dnnl_data_type_t dt_in0{dnnl_f32}, dt_in1{dnnl_f32}; - const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::isa_undef}; - - size_t hash() const { - return m_hash; - } - - bool operator==(const StaticBaseParams& rhs) const; - bool operator!=(const StaticBaseParams& rhs) const { - return !(*this == rhs); - } -#ifdef SNIPPETS_DEBUG_CAPS - std::string to_string() const; -#endif - protected: - static size_t compute_hash(size_t hash_seed, - dnnl_data_type_t dt_in0, - dnnl_data_type_t dt_in1, - dnnl::impl::cpu::x64::cpu_isa_t isa); - - const size_t m_hash{0}; - }; - - virtual std::shared_ptr get_static_params() const = 0; size_t compute_hash() const; - dnnl_dim_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0}; + int64_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0}; float m_beta{0}; size_t m_hash{SIZE_MAX}; }; @@ -124,27 +79,6 @@ class BrgemmBaseKernelExecutor { static void update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmBaseKernelConfig& config); - - static void create_brgemm_kernel(std::shared_ptr& kernel, - dnnl_data_type_t dt0, - dnnl_data_type_t dt1, - dnnl::impl::cpu::x64::cpu_isa_t isa, - dnnl_dim_t M, - dnnl_dim_t N, - dnnl_dim_t K, - dnnl_dim_t LDA, - dnnl_dim_t LDB, - dnnl_dim_t LDC, - float beta, - bool with_amx = false, - char* palette = nullptr); - - static void execute_brgemm_kernel(const std::shared_ptr& kernel, - const void* src, - const void* wei, - void* dst, - void* scratch, - bool with_comp); }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index 58a31a1804782a..8d03c023989d08 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -21,7 +21,7 @@ BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) - : BrgemmBaseKernelConfig(), + : BrgemmBaseKernelConfig_x64(), m_static_params(std::make_shared(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) { m_hash = compute_hash(); } @@ -78,7 +78,7 @@ std::shared_ptr BrgemmKernelExecutor::compile_kernel(const void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmKernelConfig& config) const { - return BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); + return BrgemmBaseKernelExecutor_x64::update_config(expr, linear_ir, config); } void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_args* args) { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp index 651a9704c47b05..69c8ca114c7912 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp @@ -4,12 +4,12 @@ #pragma once -#include "emitters/snippets/brgemm_base.hpp" +#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp" namespace ov { namespace intel_cpu { -struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { +struct BrgemmKernelConfig : public BrgemmBaseKernelConfig_x64 { public: BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, @@ -59,7 +59,7 @@ struct BrgemmCompiledKernel { std::shared_ptr brgemm_kernel = nullptr; }; -class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor, +class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor_x64, public CPUKernelExecutor { public: struct call_args { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp index 12c52d43b2c4b8..e3d7455ebd4eeb 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp @@ -24,7 +24,7 @@ namespace intel_cpu { BrgemmAMXKernelConfig::BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) - : BrgemmBaseKernelConfig(), + : BrgemmBaseKernelConfig_x64(), m_static_params(std::make_shared(in0_dtype, in1_dtype, primitive_isa)) { m_hash = compute_hash(); } @@ -117,7 +117,7 @@ std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel const auto& cache = m_kernel_cache.lock(); OPENVINO_ASSERT(cache, "Invalid kernel cache pointer in BrgemmAMXKernelExecutor::compile_kernel()"); - auto brgemm_key = [&config](dnnl_dim_t K, dnnl_dim_t LDA, float beta) { + auto brgemm_key = [&config](int64_t K, int64_t LDA, float beta) { auto key = config; key.update(config.get_M(), config.get_N(), K, LDA, config.get_LDB(), config.get_LDC(), beta); return key; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp index 43bf38f1930a64..a785eca8e7069e 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp @@ -11,11 +11,12 @@ #include "emitters/snippets/brgemm_base.hpp" #include "emitters/snippets/cpu_kernel_executor_table.hpp" #include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp" namespace ov { namespace intel_cpu { -struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig { +struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig_x64 { public: BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, @@ -75,7 +76,7 @@ struct BrgemmAMXCompiledKernel { std::shared_ptr brgemm_copy_a_kernel{nullptr}; }; -class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor, +class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor_x64, public CPUKernelExecutor { public: struct call_args { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp new file mode 100644 index 00000000000000..f3f1bf3f7e6628 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp @@ -0,0 +1,175 @@ +// Copyright (C) 2020-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "brgemm_base.hpp" + +#include "common/utils.hpp" +#include "dnnl_extension_utils.h" +#include "transformations/snippets/x64/op/brgemm_cpu.hpp" +#include "transformations/snippets/x64/op/brgemm_utils.hpp" + +#define DIM_CAST(X) static_cast(X) +#define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) +#define PRINT(X) ss << #X << " = " << X << "\n" +#define EQ(X) X == rhs.X +#define HASH(X) seed = hash_combine(seed, X) + +using namespace Xbyak; +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::x64; + +namespace ov { +namespace intel_cpu { + +bool BrgemmBaseKernelConfig_x64::operator==(const BrgemmBaseKernelConfig_x64& rhs) const { + return BrgemmBaseKernelConfig::operator==(rhs) && + (EQ(get_static_params()) || *get_static_params() == *(rhs.get_static_params())); +} + +size_t BrgemmBaseKernelConfig_x64::compute_hash() const { + size_t seed = get_static_params()->hash(); + HASH(BrgemmBaseKernelConfig::compute_hash()); + return seed; +} + +BrgemmBaseKernelConfig_x64::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + cpu_isa_t primitive_isa, + size_t hash_seed) + : dt_in0(DTYPE_CAST(in0_dtype)), + dt_in1(DTYPE_CAST(in1_dtype)), + isa(primitive_isa), + m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {} + +bool BrgemmBaseKernelConfig_x64::StaticBaseParams::operator==(const StaticBaseParams& rhs) const { + return EQ(hash()) && EQ(dt_in0) && EQ(dt_in1) && EQ(isa); +} + +size_t BrgemmBaseKernelConfig_x64::StaticBaseParams::compute_hash(size_t hash_seed, + dnnl_data_type_t dt_in0, + dnnl_data_type_t dt_in1, + cpu_isa_t isa) { + size_t seed = hash_seed; + HASH(dt_in0); + HASH(dt_in1); + HASH(isa); + return seed; +} + +#ifdef SNIPPETS_DEBUG_CAPS +std::string BrgemmBaseKernelConfig_x64::StaticBaseParams::to_string() const { + std::stringstream ss; + PRINT(dt_in0); + PRINT(dt_in1); + PRINT(isa); + return ss.str(); +} + +std::string BrgemmBaseKernelConfig_x64::to_string() const { + std::stringstream ss; + ss << get_static_params()->to_string() << "\n"; + PRINT(m_M); + PRINT(m_N); + PRINT(m_K); + PRINT(m_LDA); + PRINT(m_LDB); + PRINT(m_LDC); + PRINT(m_beta); + return ss.str(); +} +#endif + +void BrgemmBaseKernelExecutor_x64::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmBaseKernelConfig_x64& config) { + BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); + + const auto LDA = snippets::utils::get_dim_stride(expr->get_input_port(0)); + const auto LDC = snippets::utils::get_dim_stride(expr->get_output_port(0)); + auto LDB = snippets::utils::get_dim_stride(expr->get_input_port(1)); + + const auto& brgemm_node = as_type_ptr(expr->get_node()); + OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); + // In case of data repacking LDB is chosen in accordance with repacking buffer size + if (with_repacking(brgemm_node->get_type())) + LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1))); + + config.update(config.get_M(), config.get_N(), config.get_K(), LDA, LDB, LDC, config.get_beta()); +} + +void BrgemmBaseKernelExecutor_x64::create_brgemm_kernel(std::shared_ptr& kernel, + dnnl_data_type_t dt0, + dnnl_data_type_t dt1, + cpu_isa_t isa, + dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K, + dnnl_dim_t LDA, + dnnl_dim_t LDB, + dnnl_dim_t LDC, + float beta, + bool with_amx, + char* palette) { + cpu::x64::brgemm_desc_t desc; + OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc, + isa, + cpu::x64::brgemm_strd, + dt0, + dt1, + false, + false, + cpu::x64::brgemm_row_major, + 1.f, + beta, + LDA, + LDB, + LDC, + M, + N, + K, + nullptr) == dnnl_success, + "Cannot initialize brgemm descriptor due to invalid params"); + + if (with_amx) { + OV_CPU_JIT_EMITTER_ASSERT(palette && brgemm_init_tiles(desc, palette) == dnnl_success, + "Cannot initialize brgemm tiles due to invalid params"); + } + + cpu::x64::brgemm_kernel_t* kernel_ = nullptr; + OV_CPU_JIT_EMITTER_ASSERT(brgemm_kernel_create(&kernel_, desc) == dnnl_success, + "Cannot create brgemm kernel due to invalid params"); + kernel = std::unique_ptr(kernel_); +} + +void BrgemmBaseKernelExecutor_x64::execute_brgemm_kernel( + const std::shared_ptr& kernel, + const void* src, + const void* wei, + void* dst, + void* scratch, + bool with_comp) { + cpu::x64::brgemm_kernel_params_t brgemm_p; + brgemm_p.batch = nullptr; // default value + brgemm_p.ptr_A = src; + brgemm_p.ptr_B = wei; + brgemm_p.ptr_C = dst; + brgemm_p.ptr_D = dst; + brgemm_p.ptr_buf = scratch; + brgemm_p.ptr_bias = nullptr; + brgemm_p.do_post_ops = with_comp; + brgemm_p.do_apply_comp = with_comp; + brgemm_p.skip_accm = 0; + brgemm_p.BS = 1; // default value + OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel"); + (*kernel)(&brgemm_p); +} + +#undef DIM_CAST +#undef DTYPE_CAST +#undef PRINT +#undef EQ +#undef HASH + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp new file mode 100644 index 00000000000000..24869089a3e43f --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp @@ -0,0 +1,117 @@ +// Copyright (C) 2020-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "cpu/x64/cpu_isa_traits.hpp" +#include "emitters/snippets/brgemm_base.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/utils.hpp" +#include "openvino/core/type/element_type.hpp" +#include "snippets/lowered/loop_info.hpp" +#include "snippets/lowered/loop_manager.hpp" +#include "utils/general_utils.h" + +namespace ov { +namespace intel_cpu { + +struct BrgemmBaseKernelConfig_x64 : public BrgemmBaseKernelConfig { +public: + BrgemmBaseKernelConfig_x64() = default; + + size_t hash() const override { + return m_hash; + } + + bool operator==(const BrgemmBaseKernelConfig_x64& rhs) const; + bool operator!=(const BrgemmBaseKernelConfig_x64& rhs) const { + return !(*this == rhs); + } + + dnnl_data_type_t get_dt_in0() const { + return get_static_params()->dt_in0; + } + dnnl_data_type_t get_dt_in1() const { + return get_static_params()->dt_in1; + } + + dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { + return get_static_params()->isa; + } + +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const override; +#endif + +protected: + struct StaticBaseParams { + StaticBaseParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, + size_t hash_seed); + virtual ~StaticBaseParams() = default; + + const dnnl_data_type_t dt_in0{dnnl_f32}, dt_in1{dnnl_f32}; + const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::isa_undef}; + + size_t hash() const { + return m_hash; + } + + bool operator==(const StaticBaseParams& rhs) const; + bool operator!=(const StaticBaseParams& rhs) const { + return !(*this == rhs); + } +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const; +#endif + protected: + static size_t compute_hash(size_t hash_seed, + dnnl_data_type_t dt_in0, + dnnl_data_type_t dt_in1, + dnnl::impl::cpu::x64::cpu_isa_t isa); + + const size_t m_hash{0}; + }; + + virtual std::shared_ptr get_static_params() const = 0; + size_t compute_hash() const; +}; + +class BrgemmBaseKernelExecutor_x64 : public BrgemmBaseKernelExecutor { +public: + virtual ~BrgemmBaseKernelExecutor_x64() = default; + +protected: + static void update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmBaseKernelConfig_x64& config); + + static void create_brgemm_kernel(std::shared_ptr& kernel, + dnnl_data_type_t dt0, + dnnl_data_type_t dt1, + dnnl::impl::cpu::x64::cpu_isa_t isa, + dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K, + dnnl_dim_t LDA, + dnnl_dim_t LDB, + dnnl_dim_t LDC, + float beta, + bool with_amx = false, + char* palette = nullptr); + + static void execute_brgemm_kernel(const std::shared_ptr& kernel, + const void* src, + const void* wei, + void* dst, + void* scratch, + bool with_comp); +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp index ac57ebbad42ab7..22764248442000 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp @@ -7,6 +7,7 @@ #include "snippets/utils/utils.hpp" #include "transformations/tpp/common/op/brgemm.hpp" +using namespace ov::intel_cpu::tpp; using namespace Xbyak_aarch64; namespace ov { @@ -27,7 +28,7 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, const auto& brgemm_node = as_type_ptr(expr->get_node()); const auto& brg0Prc = brgemm_node->get_input_element_type(0); const auto& brg1Prc = brgemm_node->get_input_element_type(1); - BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, isa); + BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc); m_kernel_executor = kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); } diff --git a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.hpp index d98e97800e4b6e..855771a702b6f7 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.hpp @@ -5,7 +5,7 @@ #pragma once #include "emitters/plugin/aarch64/jit_emitter.hpp" -#include "emitters/tpp/aarch64/kernel_executors/brgemm.hpp" +#include "emitters/tpp/common/kernel_executors/brgemm.hpp" namespace ov { namespace intel_cpu { @@ -38,7 +38,7 @@ class jit_brgemm_emitter : public jit_emitter { const uintptr_t get_execute_function_ptr() const; const uintptr_t get_compiled_kernel_ptr() const; - std::shared_ptr m_kernel_executor = nullptr; + std::shared_ptr m_kernel_executor = nullptr; }; } // namespace aarch64 diff --git a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp similarity index 83% rename from src/plugins/intel_cpu/src/emitters/tpp/aarch64/kernel_executors/brgemm.cpp rename to src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp index ae3a77022f3a31..c1fa163b23b6ea 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp @@ -12,7 +12,7 @@ using namespace dnnl::impl; namespace ov { namespace intel_cpu { -namespace aarch64 { +namespace tpp { #define COMPILE_BRGEMM_TPP_KERNEL(...) \ [&]() { \ setenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX", "1", 1); \ @@ -23,35 +23,36 @@ namespace aarch64 { return res; \ }() -BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, - const element::Type& in1_dtype, - dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa) +BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype) : BrgemmBaseKernelConfig(), - m_static_params(std::make_shared(in0_dtype, in1_dtype, primitive_isa)) { + m_static_params(std::make_shared(in0_dtype, in1_dtype)) { m_hash = compute_hash(); } -BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, - const element::Type& in1_dtype, - dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa) - : StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)) { +size_t BrgemmKernelConfig::compute_hash() const { + size_t seed = get_static_params()->hash(); + return hash_combine(seed, BrgemmBaseKernelConfig::compute_hash()); +} + +BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype) { m_type_in0 = tpp::ov_to_xsmm_dtype(in0_dtype); m_type_in1 = tpp::ov_to_xsmm_dtype(in1_dtype); m_type_exec = LIBXSMM_DATATYPE_F32; m_type_out0 = LIBXSMM_DATATYPE_F32; m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); m_prefetching_flags = false; - isa = primitive_isa; + m_hash = compute_hash(in0_dtype, in1_dtype); } -size_t BrgemmKernelConfig::StaticParams::compute_hash(dnnl::impl::cpu::aarch64::cpu_isa_t aarch_isa) { - return hash_combine(0, aarch_isa); +size_t BrgemmKernelConfig::StaticParams::compute_hash(const element::Type& in0_dtype, const element::Type& in1_dtype) { + return hash_combine(hash_combine(0, in0_dtype), in1_dtype); } + bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { - return StaticBaseParams::operator==(rhs) && isa == rhs.isa && m_type_in0 == rhs.m_type_in0 && - m_type_in1 == rhs.m_type_in1 && m_type_exec == rhs.m_type_exec && m_type_out0 == rhs.m_type_out0 && - m_compile_flags == rhs.m_compile_flags && m_prefetching_flags == rhs.m_prefetching_flags; + return m_type_in0 == rhs.m_type_in0 && m_type_in1 == rhs.m_type_in1 && m_type_exec == rhs.m_type_exec && + m_type_out0 == rhs.m_type_out0 && m_compile_flags == rhs.m_compile_flags && + m_prefetching_flags == rhs.m_prefetching_flags; } BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config) @@ -97,7 +98,7 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression const auto num_ins = expr->get_node()->get_input_size(); const auto num_outs = expr->get_node()->get_output_size(); - size_t io_strides[num_ins + num_outs]; + std::vector io_strides(num_ins + num_outs); for (size_t i = 0; i < num_ins; i++) { io_strides[i] = @@ -133,6 +134,6 @@ void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, void* i (*(brg_kernel->brgemm_kernel))(&gemm_p); } -} // namespace aarch64 +} // namespace tpp } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp similarity index 78% rename from src/plugins/intel_cpu/src/emitters/tpp/aarch64/kernel_executors/brgemm.hpp rename to src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp index aa54e57cc178df..ed5efa90451ab0 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp @@ -11,29 +11,22 @@ namespace ov { namespace intel_cpu { -namespace aarch64 { +namespace tpp { struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { public: - BrgemmKernelConfig( - const element::Type& in0_dtype, - const element::Type& in1_dtype, - dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa = dnnl::impl::cpu::aarch64::cpu_isa_t::isa_undef); + BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype); BrgemmKernelConfig() = delete; std::unique_ptr get_clone_ptr() const override { return std::unique_ptr(new BrgemmKernelConfig(*this)); } - dnnl_data_type_t get_dt_in0() const { - return get_static_params()->dt_in0; - } - dnnl_data_type_t get_dt_in1() const { - return get_static_params()->dt_in1; - } - dnnl::impl::cpu::aarch64::cpu_isa_t get_isa() const { - return m_static_params->isa; + size_t hash() const override { + return m_hash; } + size_t compute_hash() const; + libxsmm_bitfield get_static_compile_flags() const { return m_static_params->m_compile_flags; } @@ -64,32 +57,36 @@ struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { } private: - struct StaticParams : public StaticBaseParams { - StaticParams(const element::Type& in0_dtype, - const element::Type& in1_dtype, - dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa); + struct StaticParams { + StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype); virtual ~StaticParams() = default; bool operator==(const StaticParams& rhs) const; bool operator!=(const StaticParams& rhs) const { return !(*this == rhs); } - size_t compute_hash(dnnl::impl::cpu::aarch64::cpu_isa_t aarch_isa); + size_t hash() const { + return m_hash; + } + size_t compute_hash(const element::Type& in0_dtype, const element::Type& in1_dtype); - dnnl::impl::cpu::aarch64::cpu_isa_t isa; libxsmm_datatype m_type_in0; libxsmm_datatype m_type_in1; libxsmm_datatype m_type_out0; libxsmm_datatype m_type_exec; libxsmm_bitfield m_compile_flags; bool m_prefetching_flags; + + size_t m_hash{SIZE_MAX}; }; - std::shared_ptr get_static_params() const override { + std::shared_ptr get_static_params() const { return m_static_params; } libxsmm_bitfield m_compile_flags{0}; std::shared_ptr m_static_params{nullptr}; + + size_t m_hash{SIZE_MAX}; }; // The `update_kernel` method verifies that a compiled kernel is not nullptr. @@ -118,6 +115,6 @@ class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor, }; #define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmKernelExecutor::call_args, field) -} // namespace aarch64 +} // namespace tpp } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index e4a080e2e94fa4..6d9d189e7ee744 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -465,13 +465,10 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { #endif // OPENVINO_ARCH_X86_64 #if defined(OPENVINO_ARCH_ARM64) -# define SNIPPETS_REGISTER_PASS_ABSOLUTE_ARM64(PASS_PLACE, PASS, ...) \ - backend_passes.emplace_back(PassPosition(PASS_PLACE), std::make_shared(__VA_ARGS__)) # define SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(PASS_PLACE, TARGET_PASS, PASS, ...) \ backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \ std::make_shared(__VA_ARGS__)) #else -# define SNIPPETS_REGISTER_PASS_ABSOLUTE_ARM64(PASS_PLACE, PASS, ...) # define SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(PASS_PLACE, TARGET_PASS, PASS, ...) #endif // OPENVINO_ARCH_ARM64 @@ -523,7 +520,6 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { #undef SNIPPETS_REGISTER_PASS_RELATIVE_COMMON #undef SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64 #undef SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 -#undef SNIPPETS_REGISTER_PASS_ABSOLUTE_ARM64 #undef SNIPPETS_REGISTER_PASS_RELATIVE_ARM64 return backend_passes;