Skip to content

Commit

Permalink
apply Alexandra comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 24, 2025
1 parent f336769 commit 654635d
Show file tree
Hide file tree
Showing 14 changed files with 382 additions and 264 deletions.
2 changes: 1 addition & 1 deletion cmake/features.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run
ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF)
ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF)

ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU" OFF)
ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND (X86_64 OR AARCH64)" OFF)

ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF)

Expand Down
156 changes: 26 additions & 130 deletions src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"

#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
#define DTYPE_CAST(X) static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(X))
Expand All @@ -31,16 +29,32 @@ bool BrgemmBaseKernelConfig::is_empty() const {
}

bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const {
return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) &&
(EQ(get_static_params()) || *get_static_params() == *(rhs.get_static_params()));
return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC);
}

void BrgemmBaseKernelConfig::update(dnnl_dim_t M,
dnnl_dim_t N,
dnnl_dim_t K,
dnnl_dim_t LDA,
dnnl_dim_t LDB,
dnnl_dim_t LDC,
// Updates the runtime shape (M, N, K) and beta stored in this config.
// Only m_M, m_N, m_K and m_beta are written; the leading dimensions are not touched here.
void BrgemmBaseKernelConfig::update(int64_t M, int64_t N, int64_t K, float beta) {
    // A zero in any of M, N, K means Brgemm won't be executed (in a Loop with work_amount = 0,
    // for example). To handle that case the config is made empty: all runtime parameters
    // set by this overload are nullified.
    const bool nullify = utils::one_of(0, M, N, K);

    m_M = nullify ? 0 : M;
    m_N = nullify ? 0 : N;
    m_K = nullify ? 0 : K;
    m_beta = nullify ? 0 : beta;

    // NOTE(review): the hash refresh is disabled, so m_hash keeps its previous value after
    // this call — confirm that callers / derived configs recompute it.
    // m_hash = compute_hash();
}

void BrgemmBaseKernelConfig::update(int64_t M,
int64_t N,
int64_t K,
int64_t LDA,
int64_t LDB,
int64_t LDC,
float beta) {
// If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example)
// To process this case, we have to make this Config as empty (nullify runtime parameters)
Expand All @@ -65,7 +79,7 @@ void BrgemmBaseKernelConfig::update(dnnl_dim_t M,
}

size_t BrgemmBaseKernelConfig::compute_hash() const {
size_t seed = get_static_params()->hash();
size_t seed = 0;
HASH(m_M);
HASH(m_N);
HASH(m_K);
Expand All @@ -76,48 +90,12 @@ size_t BrgemmBaseKernelConfig::compute_hash() const {
return seed;
}

// Builds the static (non-runtime) parameters of a brgemm config.
//  - in0_dtype / in1_dtype: element types of the two inputs; each is converted to a
//    dnnl_data_type_t through the DTYPE_CAST macro (DnnlExtensionUtils::ElementTypeToDataType).
//  - primitive_isa: stored as-is in `isa`.
//  - hash_seed: initial seed forwarded to compute_hash().
// The cached hash is computed once here from hash_seed and the three stored values.
// NOTE(review): m_hash is initialized from dt_in0/dt_in1/isa, so it relies on those members
// being declared before m_hash in the struct — keep that declaration order.
BrgemmBaseKernelConfig::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
cpu_isa_t primitive_isa,
size_t hash_seed)
: dt_in0(DTYPE_CAST(in0_dtype)),
dt_in1(DTYPE_CAST(in1_dtype)),
isa(primitive_isa),
m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {}

// Equality of two static parameter sets: the cached hash, both input data types and the ISA
// must all match. The hash is checked first, ahead of the field-by-field comparison.
// NOTE(review): EQ is a macro defined outside this view; presumably it expands to
// `X == rhs.X` — confirm before renaming `rhs`.
bool BrgemmBaseKernelConfig::StaticBaseParams::operator==(const StaticBaseParams& rhs) const {
return EQ(hash()) && EQ(dt_in0) && EQ(dt_in1) && EQ(isa);
}

// Computes the hash of the static parameters.
// Starts from hash_seed and applies HASH to dt_in0, dt_in1 and isa, in that order;
// returns the resulting seed.
// NOTE(review): HASH is a macro defined outside this view; presumably it folds its argument
// into the local variable `seed` — confirm before renaming that local.
size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed,
dnnl_data_type_t dt_in0,
dnnl_data_type_t dt_in1,
cpu_isa_t isa) {
size_t seed = hash_seed;
HASH(dt_in0);
HASH(dt_in1);
HASH(isa);
return seed;
}

#ifdef SNIPPETS_DEBUG_CAPS
std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const {
std::stringstream ss;
PRINT(dt_in0);
PRINT(dt_in1);
PRINT(isa);
return ss.str();
}

std::string BrgemmBaseKernelConfig::to_string() const {
std::stringstream ss;
ss << get_static_params()->to_string() << "\n";
PRINT(m_M);
PRINT(m_N);
PRINT(m_K);
PRINT(m_LDA);
PRINT(m_LDB);
PRINT(m_LDC);
PRINT(m_beta);
return ss.str();
}
Expand Down Expand Up @@ -248,89 +226,7 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
}

#ifndef OPENVINO_ARCH_X86_64
config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), 0, 0, 0, beta);
return;
#endif

const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}

// Creates a oneDNN brgemm kernel for the given data types, ISA and GEMM dimensions and
// stores it in the out-parameter `kernel`.
//  - kernel: receives ownership of the newly created kernel.
//  - dt0 / dt1: data types passed to brgemm_desc_init for the two inputs.
//  - isa: target ISA for the descriptor.
//  - M, N, K, LDA, LDB, LDC, beta: forwarded to brgemm_desc_init.
//  - with_amx: when true, tile configuration is additionally initialized into `palette`.
//  - palette: must be non-null when with_amx is true (asserted below).
// Each oneDNN call is checked with OV_CPU_JIT_EMITTER_ASSERT against dnnl_success.
void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr<brgemm_kernel_t>& kernel,
dnnl_data_type_t dt0,
dnnl_data_type_t dt1,
cpu_isa_t isa,
dnnl_dim_t M,
dnnl_dim_t N,
dnnl_dim_t K,
dnnl_dim_t LDA,
dnnl_dim_t LDB,
dnnl_dim_t LDC,
float beta,
bool with_amx,
char* palette) {
// Step 1: fill the descriptor. Uses the brgemm_strd batch kind and row-major layout.
// NOTE(review): the two `false` arguments and `1.f` are positional; presumably they are
// transA/transB and alpha — confirm against the oneDNN brgemm_desc_init declaration.
cpu::x64::brgemm_desc_t desc;
OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc,
isa,
cpu::x64::brgemm_strd,
dt0,
dt1,
false,
false,
cpu::x64::brgemm_row_major,
1.f,
beta,
LDA,
LDB,
LDC,
M,
N,
K,
nullptr) == dnnl_success,
"Cannot initialize brgemm descriptor due to invalid params");

// Step 2 (AMX only): initialize the tile palette from the descriptor.
if (with_amx) {
OV_CPU_JIT_EMITTER_ASSERT(palette && brgemm_init_tiles(desc, palette) == dnnl_success,
"Cannot initialize brgemm tiles due to invalid params");
}

// Step 3: create the kernel from the descriptor and hand the raw pointer over to `kernel`.
cpu::x64::brgemm_kernel_t* kernel_ = nullptr;
OV_CPU_JIT_EMITTER_ASSERT(brgemm_kernel_create(&kernel_, desc) == dnnl_success,
"Cannot create brgemm kernel due to invalid params");
kernel = std::unique_ptr<brgemm_kernel_t>(kernel_);
}

void BrgemmBaseKernelExecutor::execute_brgemm_kernel(
const std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
const void* src,
const void* wei,
void* dst,
void* scratch,
bool with_comp) {
cpu::x64::brgemm_kernel_params_t brgemm_p;
brgemm_p.batch = nullptr; // default value
brgemm_p.ptr_A = src;
brgemm_p.ptr_B = wei;
brgemm_p.ptr_C = dst;
brgemm_p.ptr_D = dst;
brgemm_p.ptr_buf = scratch;
brgemm_p.ptr_bias = nullptr;
brgemm_p.do_post_ops = with_comp;
brgemm_p.do_apply_comp = with_comp;
brgemm_p.skip_accm = 0;
brgemm_p.BS = 1; // default value
OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel");
(*kernel)(&brgemm_p);
config.update(static_cast<int64_t>(M), static_cast<int64_t>(N), static_cast<int64_t>(K), beta);
}

#undef DIM_CAST
Expand Down
92 changes: 13 additions & 79 deletions src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,49 +23,35 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf
BrgemmBaseKernelConfig() = default;

bool is_completed() const override;
size_t hash() const override {
return m_hash;
}

bool is_empty() const;
void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta);

void update(int64_t M, int64_t N, int64_t K, float beta);
void update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta);

bool operator==(const BrgemmBaseKernelConfig& rhs) const;
bool operator!=(const BrgemmBaseKernelConfig& rhs) const {
return !(*this == rhs);
}

dnnl_data_type_t get_dt_in0() const {
return get_static_params()->dt_in0;
}
dnnl_data_type_t get_dt_in1() const {
return get_static_params()->dt_in1;
}

dnnl::impl::cpu::x64::cpu_isa_t get_isa() const {
return get_static_params()->isa;
}
float get_beta() const {
return m_beta;
}

dnnl_dim_t get_M() const {
int64_t get_M() const {
return m_M;
}
dnnl_dim_t get_N() const {
int64_t get_N() const {
return m_N;
}
dnnl_dim_t get_K() const {
int64_t get_K() const {
return m_K;
}

dnnl_dim_t get_LDA() const {
float get_beta() const {
return m_beta;
}
int64_t get_LDA() const {
return m_LDA;
}
dnnl_dim_t get_LDB() const {
int64_t get_LDB() const {
return m_LDB;
}
dnnl_dim_t get_LDC() const {
int64_t get_LDC() const {
return m_LDC;
}

Expand All @@ -74,40 +60,9 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf
#endif

protected:
struct StaticBaseParams {
StaticBaseParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa,
size_t hash_seed);
virtual ~StaticBaseParams() = default;

const dnnl_data_type_t dt_in0{dnnl_f32}, dt_in1{dnnl_f32};
const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::isa_undef};

size_t hash() const {
return m_hash;
}

bool operator==(const StaticBaseParams& rhs) const;
bool operator!=(const StaticBaseParams& rhs) const {
return !(*this == rhs);
}
#ifdef SNIPPETS_DEBUG_CAPS
std::string to_string() const;
#endif
protected:
static size_t compute_hash(size_t hash_seed,
dnnl_data_type_t dt_in0,
dnnl_data_type_t dt_in1,
dnnl::impl::cpu::x64::cpu_isa_t isa);

const size_t m_hash{0};
};

virtual std::shared_ptr<StaticBaseParams> get_static_params() const = 0;
size_t compute_hash() const;

dnnl_dim_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0};
int64_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0};
float m_beta{0};
size_t m_hash{SIZE_MAX};
};
Expand All @@ -124,27 +79,6 @@ class BrgemmBaseKernelExecutor {
static void update_config(const ov::snippets::lowered::ExpressionPtr& expr,
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
BrgemmBaseKernelConfig& config);

static void create_brgemm_kernel(std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
dnnl_data_type_t dt0,
dnnl_data_type_t dt1,
dnnl::impl::cpu::x64::cpu_isa_t isa,
dnnl_dim_t M,
dnnl_dim_t N,
dnnl_dim_t K,
dnnl_dim_t LDA,
dnnl_dim_t LDB,
dnnl_dim_t LDC,
float beta,
bool with_amx = false,
char* palette = nullptr);

static void execute_brgemm_kernel(const std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
const void* src,
const void* wei,
void* dst,
void* scratch,
bool with_comp);
};

} // namespace intel_cpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
bool is_with_comp,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(),
: BrgemmBaseKernelConfig_x64(),
m_static_params(std::make_shared<StaticParams>(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) {
m_hash = compute_hash();
}
Expand Down Expand Up @@ -78,7 +78,7 @@ std::shared_ptr<BrgemmCompiledKernel> BrgemmKernelExecutor::compile_kernel(const
void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
BrgemmKernelConfig& config) const {
return BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config);
return BrgemmBaseKernelExecutor_x64::update_config(expr, linear_ir, config);
}

void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_args* args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

#pragma once

#include "emitters/snippets/brgemm_base.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp"

namespace ov {
namespace intel_cpu {

struct BrgemmKernelConfig : public BrgemmBaseKernelConfig {
struct BrgemmKernelConfig : public BrgemmBaseKernelConfig_x64 {
public:
BrgemmKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
Expand Down Expand Up @@ -59,7 +59,7 @@ struct BrgemmCompiledKernel {
std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgemm_kernel = nullptr;
};

class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor,
class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor_x64,
public CPUKernelExecutor<BrgemmKernelConfig, BrgemmCompiledKernel> {
public:
struct call_args {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace intel_cpu {
BrgemmAMXKernelConfig::BrgemmAMXKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(),
: BrgemmBaseKernelConfig_x64(),
m_static_params(std::make_shared<StaticParams>(in0_dtype, in1_dtype, primitive_isa)) {
m_hash = compute_hash();
}
Expand Down Expand Up @@ -117,7 +117,7 @@ std::shared_ptr<BrgemmAMXCompiledKernel> BrgemmAMXKernelExecutor::compile_kernel
const auto& cache = m_kernel_cache.lock();
OPENVINO_ASSERT(cache, "Invalid kernel cache pointer in BrgemmAMXKernelExecutor::compile_kernel()");

auto brgemm_key = [&config](dnnl_dim_t K, dnnl_dim_t LDA, float beta) {
auto brgemm_key = [&config](int64_t K, int64_t LDA, float beta) {
auto key = config;
key.update(config.get_M(), config.get_N(), K, LDA, config.get_LDB(), config.get_LDC(), beta);
return key;
Expand Down
Loading

0 comments on commit 654635d

Please sign in to comment.