Skip to content

Commit

Permalink
apply Alexandra comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 26, 2025
1 parent 9a52cf2 commit cebafe9
Show file tree
Hide file tree
Showing 14 changed files with 425 additions and 288 deletions.
2 changes: 1 addition & 1 deletion cmake/features.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run
ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF)
ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF)

ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU" OFF)
ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND (X86_64 OR AARCH64)" OFF)

ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF)

Expand Down
174 changes: 32 additions & 142 deletions src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp
Original file line number Diff line number Diff line change
@@ -1,50 +1,58 @@
// Copyright (C) 2020-2024 Intel Corporation
// Copyright (C) 2020-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm_base.hpp"

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "utils/general_utils.h"

#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
#define DTYPE_CAST(X) static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(X))
#define PRINT(X) ss << #X << " = " << X << "\n"
#define EQ(X) X == rhs.X
#define HASH(X) seed = hash_combine(seed, X)

using namespace Xbyak;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;
#define PRINT(X) ss << #X << " = " << X << "\n"
#define EQ(X) X == rhs.X
#define HASH(X) seed = dnnl::impl::hash_combine(seed, X)

namespace ov {
namespace intel_cpu {

bool BrgemmBaseKernelConfig::is_completed() const {
return !utils::one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC) || is_empty();
return !one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC) || is_empty();
}

bool BrgemmBaseKernelConfig::is_empty() const {
return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta);
}

bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const {
return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) &&
(EQ(get_static_params()) || *get_static_params() == *(rhs.get_static_params()));
return EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC);
}

void BrgemmBaseKernelConfig::update(int64_t M, int64_t N, int64_t K, float beta) {
// If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example)
// To process this case, we have to make this Config as empty (nullify runtime parameters)
if (one_of(0, M, N, K)) {
m_M = 0;
m_N = 0;
m_K = 0;
m_beta = 0;
} else {
m_M = M;
m_N = N;
m_K = K;
m_beta = beta;
}
}

void BrgemmBaseKernelConfig::update(dnnl_dim_t M,
dnnl_dim_t N,
dnnl_dim_t K,
dnnl_dim_t LDA,
dnnl_dim_t LDB,
dnnl_dim_t LDC,
void BrgemmBaseKernelConfig::update(int64_t M,
int64_t N,
int64_t K,
int64_t LDA,
int64_t LDB,
int64_t LDC,
float beta) {
// If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example)
// To process this case, we have to make this Config as empty (nullify runtime parameters)
if (utils::one_of(0, M, N, K)) {
if (one_of(0, M, N, K)) {
m_M = 0;
m_N = 0;
m_K = 0;
Expand All @@ -61,11 +69,10 @@ void BrgemmBaseKernelConfig::update(dnnl_dim_t M,
m_LDC = LDC;
m_beta = beta;
}
m_hash = compute_hash();
}

size_t BrgemmBaseKernelConfig::compute_hash() const {
size_t seed = get_static_params()->hash();
size_t seed = 0;
HASH(m_M);
HASH(m_N);
HASH(m_K);
Expand All @@ -76,42 +83,9 @@ size_t BrgemmBaseKernelConfig::compute_hash() const {
return seed;
}

BrgemmBaseKernelConfig::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
cpu_isa_t primitive_isa,
size_t hash_seed)
: dt_in0(DTYPE_CAST(in0_dtype)),
dt_in1(DTYPE_CAST(in1_dtype)),
isa(primitive_isa),
m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {}

bool BrgemmBaseKernelConfig::StaticBaseParams::operator==(const StaticBaseParams& rhs) const {
return EQ(hash()) && EQ(dt_in0) && EQ(dt_in1) && EQ(isa);
}

size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed,
dnnl_data_type_t dt_in0,
dnnl_data_type_t dt_in1,
cpu_isa_t isa) {
size_t seed = hash_seed;
HASH(dt_in0);
HASH(dt_in1);
HASH(isa);
return seed;
}

#ifdef SNIPPETS_DEBUG_CAPS
std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const {
std::stringstream ss;
PRINT(dt_in0);
PRINT(dt_in1);
PRINT(isa);
return ss.str();
}

std::string BrgemmBaseKernelConfig::to_string() const {
std::stringstream ss;
ss << get_static_params()->to_string() << "\n";
PRINT(m_M);
PRINT(m_N);
PRINT(m_K);
Expand Down Expand Up @@ -248,93 +222,9 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
}

#ifndef OPENVINO_ARCH_X86_64
config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), 0, 0, 0, beta);
return;
#endif

const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}

void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr<brgemm_kernel_t>& kernel,
dnnl_data_type_t dt0,
dnnl_data_type_t dt1,
cpu_isa_t isa,
dnnl_dim_t M,
dnnl_dim_t N,
dnnl_dim_t K,
dnnl_dim_t LDA,
dnnl_dim_t LDB,
dnnl_dim_t LDC,
float beta,
bool with_amx,
char* palette) {
cpu::x64::brgemm_desc_t desc;
OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc,
isa,
cpu::x64::brgemm_strd,
dt0,
dt1,
false,
false,
cpu::x64::brgemm_row_major,
1.f,
beta,
LDA,
LDB,
LDC,
M,
N,
K,
nullptr) == dnnl_success,
"Cannot initialize brgemm descriptor due to invalid params");

if (with_amx) {
OV_CPU_JIT_EMITTER_ASSERT(palette && brgemm_init_tiles(desc, palette) == dnnl_success,
"Cannot initialize brgemm tiles due to invalid params");
}

cpu::x64::brgemm_kernel_t* kernel_ = nullptr;
OV_CPU_JIT_EMITTER_ASSERT(brgemm_kernel_create(&kernel_, desc) == dnnl_success,
"Cannot create brgemm kernel due to invalid params");
kernel = std::unique_ptr<brgemm_kernel_t>(kernel_);
}

void BrgemmBaseKernelExecutor::execute_brgemm_kernel(
const std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
const void* src,
const void* wei,
void* dst,
void* scratch,
bool with_comp) {
cpu::x64::brgemm_kernel_params_t brgemm_p;
brgemm_p.batch = nullptr; // default value
brgemm_p.ptr_A = src;
brgemm_p.ptr_B = wei;
brgemm_p.ptr_C = dst;
brgemm_p.ptr_D = dst;
brgemm_p.ptr_buf = scratch;
brgemm_p.ptr_bias = nullptr;
brgemm_p.do_post_ops = with_comp;
brgemm_p.do_apply_comp = with_comp;
brgemm_p.skip_accm = 0;
brgemm_p.BS = 1; // default value
OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel");
(*kernel)(&brgemm_p);
config.update(static_cast<int64_t>(M), static_cast<int64_t>(N), static_cast<int64_t>(K), beta);
}

#undef DIM_CAST
#undef DTYPE_CAST
#undef PRINT
#undef EQ
#undef HASH
Expand Down
99 changes: 14 additions & 85 deletions src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
// Copyright (C) 2020-2024 Intel Corporation
// Copyright (C) 2020-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cpu/x64/brgemm/brgemm.hpp>

#include "cpu/x64/cpu_isa_traits.hpp"
#include "emitters/snippets/cpu_kernel_executor_table.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/type/element_type.hpp"
#include "snippets/lowered/loop_info.hpp"
Expand All @@ -23,49 +19,35 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf
BrgemmBaseKernelConfig() = default;

bool is_completed() const override;
size_t hash() const override {
return m_hash;
}

bool is_empty() const;
void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta);

void update(int64_t M, int64_t N, int64_t K, float beta);
void update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta);

bool operator==(const BrgemmBaseKernelConfig& rhs) const;
bool operator!=(const BrgemmBaseKernelConfig& rhs) const {
return !(*this == rhs);
}

dnnl_data_type_t get_dt_in0() const {
return get_static_params()->dt_in0;
}
dnnl_data_type_t get_dt_in1() const {
return get_static_params()->dt_in1;
}

dnnl::impl::cpu::x64::cpu_isa_t get_isa() const {
return get_static_params()->isa;
}
float get_beta() const {
return m_beta;
}

dnnl_dim_t get_M() const {
int64_t get_M() const {
return m_M;
}
dnnl_dim_t get_N() const {
int64_t get_N() const {
return m_N;
}
dnnl_dim_t get_K() const {
int64_t get_K() const {
return m_K;
}

dnnl_dim_t get_LDA() const {
float get_beta() const {
return m_beta;
}
int64_t get_LDA() const {
return m_LDA;
}
dnnl_dim_t get_LDB() const {
int64_t get_LDB() const {
return m_LDB;
}
dnnl_dim_t get_LDC() const {
int64_t get_LDC() const {
return m_LDC;
}

Expand All @@ -74,42 +56,10 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf
#endif

protected:
struct StaticBaseParams {
StaticBaseParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa,
size_t hash_seed);
virtual ~StaticBaseParams() = default;

const dnnl_data_type_t dt_in0{dnnl_f32}, dt_in1{dnnl_f32};
const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::isa_undef};

size_t hash() const {
return m_hash;
}

bool operator==(const StaticBaseParams& rhs) const;
bool operator!=(const StaticBaseParams& rhs) const {
return !(*this == rhs);
}
#ifdef SNIPPETS_DEBUG_CAPS
std::string to_string() const;
#endif
protected:
static size_t compute_hash(size_t hash_seed,
dnnl_data_type_t dt_in0,
dnnl_data_type_t dt_in1,
dnnl::impl::cpu::x64::cpu_isa_t isa);

const size_t m_hash{0};
};

virtual std::shared_ptr<StaticBaseParams> get_static_params() const = 0;
size_t compute_hash() const;

dnnl_dim_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0};
int64_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0};
float m_beta{0};
size_t m_hash{SIZE_MAX};
};

class BrgemmBaseKernelExecutor {
Expand All @@ -124,27 +74,6 @@ class BrgemmBaseKernelExecutor {
static void update_config(const ov::snippets::lowered::ExpressionPtr& expr,
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
BrgemmBaseKernelConfig& config);

static void create_brgemm_kernel(std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
dnnl_data_type_t dt0,
dnnl_data_type_t dt1,
dnnl::impl::cpu::x64::cpu_isa_t isa,
dnnl_dim_t M,
dnnl_dim_t N,
dnnl_dim_t K,
dnnl_dim_t LDA,
dnnl_dim_t LDB,
dnnl_dim_t LDC,
float beta,
bool with_amx = false,
char* palette = nullptr);

static void execute_brgemm_kernel(const std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
const void* src,
const void* wei,
void* dst,
void* scratch,
bool with_comp);
};

} // namespace intel_cpu
Expand Down
Loading

0 comments on commit cebafe9

Please sign in to comment.