[Snippets][CPU] Disable MHA tokenization in LLM
a-sidorova committed Jan 22, 2025
1 parent: d757efd · commit: 08aeea7
Showing 1 changed file with 21 additions and 5 deletions.
@@ -1031,18 +1031,34 @@ void Transformations::MainSnippets(void) {
     }
     CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);
 
-    // - CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
-    const bool isMHASupported =
-#if defined(OPENVINO_ARCH_ARM64)
-        false;
-#else
+#if defined(OPENVINO_ARCH_X86_64)
+    // Currently, Snippets don't provide efficient execution for single-token inference in the LLM case.
+    // To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs.
+    // We treat the presence of a `ScaledDotProductAttentionWithKVCache` op in the model as a sign that this
+    // model is an LLM.
+    const auto is_LLM = [this]() {
+        // Note: the variable `ops` must not be alive during the `SnippetsTokenization` execution.
+        // Otherwise, it would extend the lifetime of the ops (since they are stored as shared pointers)
+        // and they would stay visible in the model during the tokenization passes even after being removed or replaced.
+        const auto ops = model->get_ops();
+        return std::any_of(ops.cbegin(), ops.cend(), [](const std::shared_ptr<ov::Node>& op) {
+            return ov::is_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(op);
+        });
+    }();
+
+    // CPU Plugin Subgraph supports f32, bf16, quantized and fp16 (on the avx512_core_amx_fp16 target) BRGEMM
+    const auto is_infer_prc_supported_by_MHA =
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
          one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
          one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
          one_of(config.inferencePrecision, ov::element::f16));
+    const bool isMHASupported = !is_LLM && is_infer_prc_supported_by_MHA;
+#else
+    const bool isMHASupported = false;
+#endif
+
     if (!isMHASupported) {
         CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
         CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::ExtractReshapesFromMHA);
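The LLM check above is an instance of a more general "does the model contain an op of type T" query. Below is a minimal sketch of that idiom as a reusable helper over OpenVINO's public C++ API; `model_has_op` is a hypothetical name introduced here for illustration and is not part of the plugin.

#include <algorithm>
#include <memory>

#include <openvino/openvino.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>

template <typename OpType>
bool model_has_op(const std::shared_ptr<const ov::Model>& model) {
    // get_ops() hands out shared_ptrs, so keep the vector's lifetime short:
    // a long-lived copy would pin nodes alive even after they are removed
    // from the model (see the note inside the is_LLM lambda above).
    const auto ops = model->get_ops();
    return std::any_of(ops.cbegin(), ops.cend(), [](const std::shared_ptr<ov::Node>& op) {
        return ov::is_type<OpType>(op);
    });
}

// Example use with a public opset op (the plugin-internal
// ScaledDotProductAttentionWithKVCache type is not visible outside the CPU plugin):
//   const bool maybe_llm = model_has_op<ov::op::v13::ScaledDotProductAttention>(model);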

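The lifetime note inside the lambda is also worth unpacking, since it explains why `ops` is scoped inside an immediately invoked lambda. The sketch below illustrates the pitfall with plain std::shared_ptr and no OpenVINO dependency; the names Node, graph, and snapshot are invented for the example.

#include <iostream>
#include <memory>
#include <vector>

struct Node {
    ~Node() { std::cout << "node destroyed\n"; }
};

int main() {
    std::vector<std::shared_ptr<Node>> graph;
    graph.push_back(std::make_shared<Node>());

    {
        const auto snapshot = graph;  // analogous to `const auto ops = model->get_ops();`
        graph.clear();                // "remove" the node from the graph
        // The node is still alive here: the snapshot copy owns it.
        std::cout << "after clear, use_count = " << snapshot.front().use_count() << "\n";  // prints 1
    }  // snapshot goes out of scope -> "node destroyed" is printed only now

    return 0;
}

Scoping the ops vector the same way ensures that, by the time the tokenization passes run, removed or replaced nodes are genuinely gone from the model rather than lingering through a stale owning copy.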