[Snippets][CPU][Port to 2025.0] Disable MHA tokenization in LLM (#28611)
### Details:
- *The second inference in an LLM is usually single-token inference. This
means that the `M` dimension of the MatMuls in the SDPA pattern has the value
`1` (during model compilation this dimension is dynamic, i.e. unknown).
Snippets cannot provide efficient execution for single-token inference,
so we decided to disable MHA tokenization by Snippets in the CPU Plugin for
LLMs. We treat the presence of a
`ScaledDotProductAttentionWithKVCache` op in the model as a sign that
the model is an LLM.*
- *Cherry-picked from #28601*
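
The detection heuristic described above can be illustrated with a small standalone sketch. It is not the plugin's code: `Node`, `Model`, `SdpaWithKVCache`, `PagedAttention`, `has_node_of_type` and `looks_like_llm` are hypothetical stand-ins chosen for illustration, and the scan only mimics the idea of checking whether an op of a given type is present in the model. The second op type is included because the diff in this commit also checks `PagedAttentionExtension`.

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-ins for graph nodes; these are NOT OpenVINO classes.
struct Node {
    virtual ~Node() = default;
};
struct MatMul : Node {};
struct Softmax : Node {};
struct SdpaWithKVCache : Node {};  // plays the role of ScaledDotProductAttentionWithKVCache
struct PagedAttention : Node {};   // plays the role of PagedAttentionExtension

// A toy "model": just a flat list of nodes.
using Model = std::vector<std::shared_ptr<Node>>;

// Returns true if any node in the model is of type T (or derived from T).
template <typename T>
bool has_node_of_type(const Model& model) {
    for (const auto& node : model) {
        if (dynamic_cast<const T*>(node.get()) != nullptr) {
            return true;
        }
    }
    return false;
}

// Heuristic from the commit: a model containing either attention op is treated as an LLM.
bool looks_like_llm(const Model& model) {
    return has_node_of_type<SdpaWithKVCache>(model) || has_node_of_type<PagedAttention>(model);
}
```

Under these assumptions, a model made only of `MatMul` and `Softmax` nodes is not classified as an LLM, while adding a single `SdpaWithKVCache` node flips the result.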

### Tickets:
 - *160634*
 - *160978 (contains performance validation results)*
a-sidorova authored Jan 23, 2025
1 parent 19f68e8 commit 8d7e7ca
Showing 1 changed file with 14 additions and 5 deletions.
@@ -1031,18 +1031,27 @@ void Transformations::MainSnippets(void) {
  }
  CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);

- // - CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
- const bool isMHASupported =
- #if defined(OPENVINO_ARCH_ARM64)
- false;
- #else
+ #if defined(OPENVINO_ARCH_X86_64)
+ // Currently, Snippets don't provide efficient execution for single token inference in LLM case.
+ // To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs'.
+ // We consider the presence of `ScaledDotProductAttentionWithKVCache` and `PagedAttentionExtension` ops
+ // in the model as a sign that this model is LLM.
+ const auto is_LLM = ov::op::util::has_op_with_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(model) ||
+ ov::op::util::has_op_with_type<ov::op::PagedAttentionExtension>(model);
+
+ // CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
+ const auto is_infer_prc_supported_by_MHA =
  (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
  one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
  (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
  one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
  (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
  one_of(config.inferencePrecision, ov::element::f16));
+ const bool isMHASupported = !is_LLM && is_infer_prc_supported_by_MHA;
+ #else
+ const bool isMHASupported = false;
  #endif

  if (!isMHASupported) {
  CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
  CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::ExtractReshapesFromMHA);
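
The gating condition in the hunk above combines two checks: the model must not look like an LLM, and the requested inference precision must be supported on the available ISA. A minimal standalone sketch of that decision table follows. It assumes simplified, hypothetical types (`Precision`, `CpuFeatures`) and hypothetical function names in place of the oneDNN `mayiuse` query and the OpenVINO element types, so it only mirrors the boolean structure of the change, not the plugin's real API.

```cpp
#include <initializer_list>

// Hypothetical, simplified stand-ins; not the oneDNN or OpenVINO types.
enum class Precision { undefined, f32, bf16, f16 };

// Hypothetical replacement for querying ISA support at runtime.
struct CpuFeatures {
    bool avx2 = false;
    bool avx512_core = false;
    bool avx512_core_amx_fp16 = false;
};

// Returns true if value equals any of the candidates.
bool one_of(Precision value, std::initializer_list<Precision> candidates) {
    for (Precision candidate : candidates) {
        if (candidate == value) {
            return true;
        }
    }
    return false;
}

// Mirrors the ISA/precision pairs listed in the diff.
bool is_precision_supported_by_mha(const CpuFeatures& cpu, Precision prc) {
    return (cpu.avx2 && one_of(prc, {Precision::f32, Precision::undefined})) ||
           (cpu.avx512_core && one_of(prc, {Precision::bf16, Precision::f32, Precision::undefined})) ||
           (cpu.avx512_core_amx_fp16 && one_of(prc, {Precision::f16}));
}

// MHA tokenization stays enabled only for non-LLM models with a supported precision.
bool is_mha_tokenization_enabled(bool is_llm, const CpuFeatures& cpu, Precision prc) {
    return !is_llm && is_precision_supported_by_mha(cpu, prc);
}
```

In this sketch an AVX2-only machine keeps MHA tokenization for `f32` on a non-LLM model, but not for `bf16`, and never for a model flagged as an LLM, whatever the precision.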
