[Snippets][CPU] Disable MHA tokenization in LLM (#28601)
### Details:
- *The second inference request on an LLM is usually a single-token inference,
which means that the `M` dimension of the MatMuls in the SDPA pattern will have
the value `1` (during model compilation this dimension is dynamic, i.e. unknown).
Snippets cannot provide efficient execution for single-token inference, so we
decided to disable MHA tokenization by Snippets in the CPU Plugin for LLMs. We
consider the presence of a `ScaledDotProductAttentionWithKVCache` op in the
model as a sign that the model is an LLM (see the sketch below).*
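
A minimal standalone sketch of this detection heuristic follows. It is not the plugin code itself: the actual change (see the diff below) uses `ov::op::util::has_op_with_type` and also checks the CPU-plugin-internal `ScaledDotProductAttentionWithKVCache` op, whose header is plugin-private and is therefore omitted here.

```cpp
// Sketch of the "is this model an LLM?" check that gates MHA tokenization.
// It mirrors ov::op::util::has_op_with_type<T>(model): walk all ops of the
// ov::Model and report whether any of them has the requested type.
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/paged_attention.hpp"

template <typename OpType>
bool model_has_op(const std::shared_ptr<const ov::Model>& model) {
    for (const auto& op : model->get_ops()) {
        if (ov::is_type<OpType>(op)) {
            return true;
        }
    }
    return false;
}

// The plugin treats the presence of KV-cache / paged-attention ops as a sign
// that the model is an LLM and then skips TokenizeMHASnippets for it.
// (The plugin-internal ScaledDotProductAttentionWithKVCache check is omitted.)
bool looks_like_llm(const std::shared_ptr<const ov::Model>& model) {
    return model_has_op<ov::op::PagedAttentionExtension>(model);
}
```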

### Tickets:
 - *160634*
 - *160978*

### TODO:
- [x] Performance validation on LLMs (results are in ticket CVS-160978)
a-sidorova authored Jan 24, 2025
1 parent 9232859 commit 62e8e08
Showing 1 changed file with 14 additions and 5 deletions.
@@ -1031,18 +1031,27 @@ void Transformations::MainSnippets(void) {
     }
     CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);
 
-    // - CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
-    const bool isMHASupported =
-#if defined(OPENVINO_ARCH_ARM64)
-        false;
-#else
+#if defined(OPENVINO_ARCH_X86_64)
+    // Currently, Snippets don't provide efficient execution for single token inference in LLM case.
+    // To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs'.
+    // We consider the presence of `ScaledDotProductAttentionWithKVCache` and `PagedAttentionExtension` ops
+    // in the model as a sign that this model is LLM.
+    const auto is_LLM = ov::op::util::has_op_with_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(model) ||
+                        ov::op::util::has_op_with_type<ov::op::PagedAttentionExtension>(model);
+
+    // CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
+    const auto is_infer_prc_supported_by_MHA =
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
          one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
          one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
          one_of(config.inferencePrecision, ov::element::f16));
+    const bool isMHASupported = !is_LLM && is_infer_prc_supported_by_MHA;
+#else
+    const bool isMHASupported = false;
 #endif
 
     if (!isMHASupported) {
         CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
         CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::ExtractReshapesFromMHA);
