Skip to content

Commit

Permalink
[Snippets][CPU] Disable VNNI requirement for i8 brgemm
Browse files Browse the repository at this point in the history
  • Loading branch information
aobolensk committed Jan 23, 2025
1 parent 8dde87c commit a37faf9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {
} else { \
Y \
}
#define SUPPORT_ONE(X, MESSAGE) SUPPORT(X, OV_CPU_JIT_EMITTER_THROW(MESSAGE);)
#define SUPPORT_TWO(X, Y, MESSAGE) SUPPORT(X, SUPPORT_ONE(Y, MESSAGE))
#define SUPPORT_THREE(X, Y, Z, MESSAGE) SUPPORT(X, SUPPORT_TWO(Y, Z, MESSAGE))
#define SUPPORT_ONE(X, MESSAGE) SUPPORT(X, OV_CPU_JIT_EMITTER_THROW(MESSAGE);)
#define SUPPORT_TWO(X, Y, MESSAGE) SUPPORT(X, SUPPORT_ONE(Y, MESSAGE))
#define SUPPORT_THREE(X, Y, Z, MESSAGE) SUPPORT(X, SUPPORT_TWO(Y, Z, MESSAGE))
#define SUPPORT_FOUR(A, B, C, D, MESSAGE) SUPPORT(A, SUPPORT_THREE(B, C, D, MESSAGE))

// Note: AMX might be not used even if it's supported by the hardware, check the BrgemmToBrgemmCPU pass for details
if (is_with_amx) {
Expand All @@ -44,10 +45,11 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {
} else if (dt_in0 == ov::element::bf16) {
SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms")
} else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) {
SUPPORT_THREE(avx512_core_vnni,
avx2_vnni_2,
avx2_vnni,
"Unsupported hardware configuration: int8 is supported only on vnni platforms")
SUPPORT_FOUR(avx512_core,
avx512_core_vnni,
avx2_vnni_2,
avx2_vnni,
"Unsupported hardware configuration: int8 is supported only on vnni platforms")
} else {
SUPPORT_TWO(avx512_core,
cpu::x64::avx2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1063,21 +1063,22 @@ void Transformations::MainSnippets(void) {
const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) ||
((in_type0 == element::f32 && in_type1 == ov::element::f32 &&
config.inferencePrecision == ov::element::bf16));
const auto is_int8 = in_type0 == ov::element::i8;
const auto is_int8 = (in_type0 == element::i8 || in_type0 == element::u8) && (in_type1 == element::i8);
if (matmul->get_transpose_a())
return false;
if (is_fp32)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2);
if (is_int8)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) ||
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni) ||
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ||
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni);
if (is_bf16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) ||
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
if (is_fp16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16);
return true;
return false;
};
auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr<const ov::Node>& n,
const ov::PartialShape& shape) {
Expand Down

0 comments on commit a37faf9

Please sign in to comment.