From d3557a75ba66900572318537a9ce56644a1bfc58 Mon Sep 17 00:00:00 2001 From: LiangGao Date: Mon, 9 Dec 2024 15:16:47 +0800 Subject: [PATCH 1/7] Add Group Query Attention support with OV base OPs --- .../com.microsoft/group_query_attention.cpp | 365 ++++++++++++++++++ 1 file changed, 365 insertions(+) create mode 100644 src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp new file mode 100644 index 00000000000000..953d5d585697c9 --- /dev/null +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -0,0 +1,365 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "core/null_node.hpp" +#include "core/operator_set.hpp" +#include "openvino/frontend/exception.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/equal.hpp" +#include "openvino/op/floor.hpp" +#include "openvino/op/floor_mod.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/greater_eq.hpp" +#include "openvino/op/less.hpp" +#include "openvino/op/less_eq.hpp" +#include "openvino/op/log.hpp" +#include "openvino/op/logical_not.hpp" +#include "openvino/op/logical_or.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/maximum.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/negative.hpp" +#include "openvino/op/pad.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils/split.hpp" + +using namespace ov::op; +using ov::Shape; + +namespace ov { +namespace frontend { +namespace onnx { +namespace com_microsoft { +namespace detail { +namespace { + +std::shared_ptr get_present_state(const std::shared_ptr& K, + const std::shared_ptr& V, + const ov::OutputVector& op_inputs); +std::shared_ptr rotaryEmbedding(std::shared_ptr input, + std::shared_ptr past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved); +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); +} // namespace +} // namespace detail + +namespace opset_1 { +ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { + const auto do_rotary = node.get_attribute_value("do_rotary"); + const auto num_heads = node.get_attribute_value("num_heads"); + const auto kv_num_heads = node.get_attribute_value("kv_num_heads"); + const auto scale = node.get_attribute_value("scale", 0.0f); + const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved"); + // TODO: add softcap support + + auto nodes = node.get_ov_inputs(); + const auto node_shape = std::make_shared(nodes[0]); + const auto batch_size = detail::get_dimensions(node_shape, {0}); + const auto current_seqlen_size = detail::get_dimensions(node_shape, {1}); + const auto hidden_size = 
detail::get_dimensions(node_shape, {2}); + const auto total_num_heads_node = + v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); + auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); + + // Q K V (batch_size, sequence_len, num_heads, head_size) + ov::Output oQ, oK, oV; + int index = 0; + oQ = nodes[index++]; + oK = nodes[index++]; + oV = nodes[index++]; + if (ov::op::util::is_null(oK)) { + // Handle the packed QKV + auto packed_qkv_shape = std::make_shared( + ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, + 0); + auto inputs_qkv = std::make_shared(oQ, packed_qkv_shape, false); + // split the node into 3 even parts Q, K, V with shape (batch_size, sequence_len, num_head, head_size) + auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 2); + oQ = split[0]; + oK = split[1]; + oV = split[2]; + } + + std::shared_ptr Q, K, V; + + const auto& past_key = nodes[index++].get_node_shared_ptr(); + const auto& past_value = nodes[index++].get_node_shared_ptr(); + const auto& seqlens_k = nodes[index++].get_node_shared_ptr(); + const auto& total_sequence_length = nodes[index++]; // unused, it's not always equal (seqlens_k + 1) + const auto& cos_cache = nodes[index++].get_node_shared_ptr(); + const auto& sin_cache = nodes[index++].get_node_shared_ptr(); + + // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + Q = std::make_shared(oQ, perm); + K = std::make_shared(oK, perm); + V = std::make_shared(oV, perm); + + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + auto seqlens_elemi64 = std::make_shared(seqlens_k, ov::element::i64); + auto real_seqlens = std::make_shared(seqlens_elemi64, one); + + // Only consider batch is 1 + auto seqlens_1d = std::make_shared(real_seqlens, one, false); + auto past_sequence_length = std::make_shared(seqlens_1d, current_seqlen_size); + + if (do_rotary) { + Q = detail::rotaryEmbedding(Q, + past_sequence_length, + seqlens_1d, + cos_cache, + sin_cache, + head_size_node, + rotary_interleaved); + K = detail::rotaryEmbedding(K, + past_sequence_length, + seqlens_1d, + cos_cache, + sin_cache, + head_size_node, + rotary_interleaved); + } + // present = concat(K, V) if 'past' input is unavailable + // or + // present = concat(past, K, V) + auto construct_kv_cache = [&](const std::shared_ptr& past, const std::shared_ptr& current) { + auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); + auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); + return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); + }; + + K = construct_kv_cache(past_key, K); + V = construct_kv_cache(past_value, V); + auto present_k = K; + auto present_v = V; + + std::shared_ptr alpha; + if (scale == 0.0f) { + alpha = std::make_shared(head_size_node); + } else { + alpha = v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f / scale}); + } + const size_t kv_num_heads_factor = num_heads / kv_num_heads; + if (kv_num_heads_factor > 1) { + const auto kv_shape = std::make_shared(K); + // (batch_size, num_heads, sequence_len, head_size) + const auto 
kv_shape_prev_2 = detail::get_dimensions(kv_shape, {0, 1}); + const auto kv_shape_last_2 = detail::get_dimensions(kv_shape, {2, 3}); + auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); + K = std::make_shared(K, new_kv_shape, false); + V = std::make_shared(V, new_kv_shape, false); + K = std::make_shared(ov::NodeVector(kv_num_heads_factor, K), 2); + V = std::make_shared(ov::NodeVector(kv_num_heads_factor, V), 2); + auto q_shape = std::make_shared(Q); + // (batch_size, num_heads, sequence_len, head_size) + const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); + auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); + K = std::make_shared(K, extended_kv_shape, false); + V = std::make_shared(V, extended_kv_shape, false); + } + // compute softmax((Q x K') / sqrt(head_size)) + std::shared_ptr softmax_input = std::make_shared(Q, K, false, true); + softmax_input = std::make_shared(softmax_input, alpha); + + // need to apply low-triangle mask to attention score. + auto past_seq_len_scalar = std::make_shared(past_sequence_length, one_without_shape, false); + auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); + std::shared_ptr mask_per_line_node = + std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), + seqlens_1d_scalar, + one_without_shape, + ov::element::i64); + auto mask_shape = std::make_shared(ov::NodeVector{one, one, one, seqlens_1d}, 0); + mask_per_line_node = std::make_shared(mask_per_line_node, mask_shape, false); + auto pad_end_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, seqlens_1d}, 0); + auto paded_mask = std::make_shared(mask_per_line_node, pad_end_shape); + std::shared_ptr compare_mask = + std::make_shared(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64); + auto compare_range_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, one}, 0); + compare_mask = std::make_shared(compare_mask, compare_range_shape, false); + auto lower_triangular_mask = std::make_shared(paded_mask, compare_mask); + auto higher_triangular_mask = std::make_shared(paded_mask, compare_mask); + auto negtive_const = v0::Constant::create(ov::element::f32, ov::Shape{}, {-1e20f}); + + auto convert_mask = std::make_shared(higher_triangular_mask, ov::element::f32); + auto input_offset_data = std::make_shared(convert_mask, negtive_const); + + convert_mask = std::make_shared(lower_triangular_mask, ov::element::f32); + auto softmax_input_masked = std::make_shared(softmax_input, convert_mask); + std::shared_ptr softmax_input_added = std::make_shared(softmax_input_masked, input_offset_data); + // softmax((Q x K' + mask) / sqrt(head_size)) + const auto softmax = std::make_shared(softmax_input_added, 3); + + // softmax((Q x K' + mask) / sqrt(head_size)) x V + std::shared_ptr output = std::make_shared(softmax, V); + + // transpose the result from (batch_size, num_heads, sequence_length, head_size) + // to (batch_size, sequence_length, num_heads, head_size) + output = std::make_shared(output, perm); + auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); + // reshape the result from (batch_size, sequence_length, num_heads, head_size) + // to (batch_size, sequence_length, num_heads * head_size) + output = std::make_shared(output, dim_merge_shape, true); + + return {output, present_k, present_v}; +} + +ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), 
com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN); + +} // namespace opset_1 + +namespace detail { +namespace { + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { + static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { + return get_dimensions(std::make_shared(node), dims); +} + +std::shared_ptr rotaryEmbedding(std::shared_ptr input, + std::shared_ptr past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved) { + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + + auto cache_shape = std::make_shared(cos_cache); + auto cache_last_dim = get_dimensions(cache_shape, {-1}); + auto cache_1st_dim = get_dimensions(cache_shape, {0}); + + // TODO: check the shape + auto input_shape = std::make_shared(input); + + auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); + // auto dim_head_size = get_dimensions(input_shape, {3}); + // half_last_dim is same as cos_cache + std::shared_ptr half_last_dim = cache_last_dim; + + auto real_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, dim_head_size}, 0); + auto slice_cache_dim_shape = seqlen_k; + + // auto end_lens = std::make_shared(half_last_dim, one); + // auto masks = std::make_shared(one, + // zero, + // end_lens, + // op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}), + // ov::op::PadMode::CONSTANT); + auto masks = std::make_shared(one, half_last_dim); + + if (interleaved) { + auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); + auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); + auto reshaped_input = std::make_shared(input, split_input_shape, false); + auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_one); + auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_one); + + auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -1); + + auto mask_shape = std::make_shared(ov::NodeVector{half_last_dim, one}, 0); + auto reshaped_mask = std::make_shared(masks, mask_shape, false); + auto negtive_mask = std::make_shared(reshaped_mask); + auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -1); + auto real_mask = std::make_shared(concat_mask, dim_head_size, false); + auto mask_f32 = std::make_shared(real_mask, ov::element::f32); + + auto real_input0 = std::make_shared(reshaped_input, input_shape, false); + auto real_input1 = std::make_shared(second_input, input_shape, false); + + auto new_cache_shape = std::make_shared(ov::NodeVector{cache_shape, two}, 0); + auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_shape, one}, 0); + auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); + auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); + auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); + auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); + auto real_cos_input = 
std::make_shared(cos_cache_broadcasted, real_cache_shape, false); + auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); + auto sliced_cos_input = + std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto sliced_sin_input = + std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto add_input0 = std::make_shared(real_input0, sliced_cos_input); + auto add_input1 = std::make_shared(real_input1, sliced_sin_input); + auto multi_input1 = std::make_shared(add_input1, mask_f32); + auto result = std::make_shared(add_input0, multi_input1); + return result; + } else { + auto negtive_two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-2}); + auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, two, half_last_dim}, 0); + auto reshaped_input = std::make_shared(input, split_input_shape, false); + auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_two); + auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_two); + + auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -2); + + auto mask_shape = std::make_shared(ov::NodeVector{one, half_last_dim}, 0); + auto reshaped_mask = std::make_shared(masks, mask_shape, false); + auto negtive_mask = std::make_shared(reshaped_mask); + auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -2); + auto real_mask = std::make_shared(concat_mask, dim_head_size, false); + auto mask_f32 = std::make_shared(real_mask, ov::element::f32); + + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 1, 2, 4, 3}); + auto input0 = reshaped_input; // std::make_shared(reshaped_input, perm); + auto input1 = second_input; // std::make_shared(second_input, perm); + auto real_input0 = std::make_shared(input0, input_shape, false); + auto real_input1 = std::make_shared(input1, input_shape, false); + + auto new_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, two, cache_last_dim}, 0); + auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, one, cache_last_dim}, 0); + auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); + auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); + auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); + auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); + auto real_cos_input = std::make_shared(cos_cache_broadcasted, real_cache_shape, false); + auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); + // TODO: change zero to sequence_K + auto sliced_cos_input = + std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto sliced_sin_input = + std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); + auto add_input0 = std::make_shared(real_input0, sliced_cos_input); + auto add_input1 = std::make_shared(real_input1, sliced_sin_input); + auto multi_input1 = std::make_shared(add_input1, mask_f32); + auto result = std::make_shared(add_input0, multi_input1); + return result; + } +} +} // namespace +} // namespace detail +} // namespace com_microsoft +} // namespace onnx +} // namespace frontend +} // namespace ov From 27d16729f306d9e83ad8d5f1fb1b7b80eeae339b Mon Sep 17 00:00:00 2001 From: LiangGao Date: Tue, 31 Dec 2024 15:41:52 +0800 Subject: [PATCH 2/7] Update the transformation code --- 
.../com.microsoft/group_query_attention.cpp | 138 +++++++----------- 1 file changed, 54 insertions(+), 84 deletions(-) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 953d5d585697c9..b4e85812afffcf 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -38,6 +38,7 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" +#include "openvino/op/convert_like.hpp" #include "utils/split.hpp" using namespace ov::op; @@ -53,8 +54,8 @@ namespace { std::shared_ptr get_present_state(const std::shared_ptr& K, const std::shared_ptr& V, const ov::OutputVector& op_inputs); -std::shared_ptr rotaryEmbedding(std::shared_ptr input, - std::shared_ptr past_seqlen, +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, std::shared_ptr seqlen_k, std::shared_ptr cos_cache, std::shared_ptr sin_cache, @@ -82,40 +83,41 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); + // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + // Q K V (batch_size, sequence_len, num_heads, head_size) - ov::Output oQ, oK, oV; + ov::Output Q, K, V; int index = 0; - oQ = nodes[index++]; - oK = nodes[index++]; - oV = nodes[index++]; - if (ov::op::util::is_null(oK)) { + Q = nodes[index++]; + K = nodes[index++]; + V = nodes[index++]; + if (ov::op::util::is_null(K)) { // Handle the packed QKV auto packed_qkv_shape = std::make_shared( ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, 0); - auto inputs_qkv = std::make_shared(oQ, packed_qkv_shape, false); - // split the node into 3 even parts Q, K, V with shape (batch_size, sequence_len, num_head, head_size) - auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 2); - oQ = split[0]; - oK = split[1]; - oV = split[2]; + auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); + // (batch_size, sequence_len, num_head, head_size) + inputs_qkv = std::make_shared(inputs_qkv, perm); + // split the node into 3 even parts Q, K, V with shape (batch_size, num_head, sequence_len, head_size) + auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); + Q = split[0]; + K = split[1]; + V = split[2]; + } else { + Q = std::make_shared(Q, perm); + K = std::make_shared(K, perm); + V = std::make_shared(V, perm); } - std::shared_ptr Q, K, V; - - const auto& past_key = nodes[index++].get_node_shared_ptr(); - const auto& past_value = nodes[index++].get_node_shared_ptr(); + const auto& past_key = nodes[index++]; + const auto& past_value = nodes[index++]; const auto& seqlens_k = nodes[index++].get_node_shared_ptr(); const auto& total_sequence_length = nodes[index++]; // unused, it's not always equal (seqlens_k + 1) const auto& cos_cache = nodes[index++].get_node_shared_ptr(); const auto& sin_cache = nodes[index++].get_node_shared_ptr(); - // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) - auto perm = v0::Constant::create(ov::element::i64, 
ov::Shape{4}, {0, 2, 1, 3}); - Q = std::make_shared(oQ, perm); - K = std::make_shared(oK, perm); - V = std::make_shared(oV, perm); - auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); @@ -146,7 +148,7 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { // present = concat(K, V) if 'past' input is unavailable // or // present = concat(past, K, V) - auto construct_kv_cache = [&](const std::shared_ptr& past, const std::shared_ptr& current) { + auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); @@ -172,8 +174,8 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); K = std::make_shared(K, new_kv_shape, false); V = std::make_shared(V, new_kv_shape, false); - K = std::make_shared(ov::NodeVector(kv_num_heads_factor, K), 2); - V = std::make_shared(ov::NodeVector(kv_num_heads_factor, V), 2); + K = std::make_shared(ov::OutputVector(kv_num_heads_factor, K), 2); + V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); auto q_shape = std::make_shared(Q); // (batch_size, num_heads, sequence_len, head_size) const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); @@ -193,24 +195,22 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { seqlens_1d_scalar, one_without_shape, ov::element::i64); - auto mask_shape = std::make_shared(ov::NodeVector{one, one, one, seqlens_1d}, 0); - mask_per_line_node = std::make_shared(mask_per_line_node, mask_shape, false); - auto pad_end_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, seqlens_1d}, 0); - auto paded_mask = std::make_shared(mask_per_line_node, pad_end_shape); - std::shared_ptr compare_mask = + mask_per_line_node = std::make_shared(mask_per_line_node, zero); + auto minus_inf = v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits::infinity()}); + auto mask_shape = std::make_shared(ov::NodeVector{current_seqlen_size, seqlens_1d}, 0); + auto compare_mask = std::make_shared(mask_per_line_node, mask_shape); + + std::shared_ptr vertical_range = std::make_shared(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64); - auto compare_range_shape = std::make_shared(ov::NodeVector{one, one, current_seqlen_size, one}, 0); - compare_mask = std::make_shared(compare_mask, compare_range_shape, false); - auto lower_triangular_mask = std::make_shared(paded_mask, compare_mask); - auto higher_triangular_mask = std::make_shared(paded_mask, compare_mask); - auto negtive_const = v0::Constant::create(ov::element::f32, ov::Shape{}, {-1e20f}); - - auto convert_mask = std::make_shared(higher_triangular_mask, ov::element::f32); - auto input_offset_data = std::make_shared(convert_mask, negtive_const); - - convert_mask = std::make_shared(lower_triangular_mask, ov::element::f32); - auto softmax_input_masked = std::make_shared(softmax_input, convert_mask); - std::shared_ptr softmax_input_added = std::make_shared(softmax_input_masked, input_offset_data); + vertical_range = std::make_shared(vertical_range, one); 
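+ // Illustrative example (assumed values, not taken from any model): with past_len = 2 and a total length of 4, compare_mask broadcasts [0, 1, 2, 3] across the two query rows while vertical_range holds [[2], [3]]; the greater-than comparison below then marks [F, F, F, T] and [F, F, F, F] - exactly the future key positions that the Select fills with -inf before the softmax.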
+ + auto triu = std::make_shared(compare_mask, vertical_range); + auto typed_zero = std::make_shared(zero, softmax_input); + auto typed_minus_inf = std::make_shared(minus_inf, softmax_input); + auto minus_inf_mask = std::make_shared(typed_minus_inf, mask_shape); + auto atten_mask = std::make_shared(triu, minus_inf_mask, typed_zero); + + std::shared_ptr softmax_input_added = std::make_shared(softmax_input, atten_mask); // softmax((Q x K' + mask) / sqrt(head_size)) const auto softmax = std::make_shared(softmax_input_added, 3); @@ -245,8 +245,8 @@ std::shared_ptr get_dimensions(const std::shared_ptr& node, return get_dimensions(std::make_shared(node), dims); } -std::shared_ptr rotaryEmbedding(std::shared_ptr input, - std::shared_ptr past_seqlen, +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, std::shared_ptr seqlen_k, std::shared_ptr cos_cache, std::shared_ptr sin_cache, @@ -316,45 +316,15 @@ std::shared_ptr rotaryEmbedding(std::shared_ptr input, auto result = std::make_shared(add_input0, multi_input1); return result; } else { - auto negtive_two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-2}); - auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, two, half_last_dim}, 0); - auto reshaped_input = std::make_shared(input, split_input_shape, false); - auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_two); - auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_two); - - auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -2); - - auto mask_shape = std::make_shared(ov::NodeVector{one, half_last_dim}, 0); - auto reshaped_mask = std::make_shared(masks, mask_shape, false); - auto negtive_mask = std::make_shared(reshaped_mask); - auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -2); - auto real_mask = std::make_shared(concat_mask, dim_head_size, false); - auto mask_f32 = std::make_shared(real_mask, ov::element::f32); - - auto perm = v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 1, 2, 4, 3}); - auto input0 = reshaped_input; // std::make_shared(reshaped_input, perm); - auto input1 = second_input; // std::make_shared(second_input, perm); - auto real_input0 = std::make_shared(input0, input_shape, false); - auto real_input1 = std::make_shared(input1, input_shape, false); - - auto new_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, two, cache_last_dim}, 0); - auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, one, cache_last_dim}, 0); - auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); - auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); - auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); - auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); - auto real_cos_input = std::make_shared(cos_cache_broadcasted, real_cache_shape, false); - auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); - // TODO: change zero to sequence_K - auto sliced_cos_input = - std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto sliced_sin_input = - std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto add_input0 = std::make_shared(real_input0, sliced_cos_input); - auto add_input1 = std::make_shared(real_input1, sliced_sin_input); - auto multi_input1 = std::make_shared(add_input1, mask_f32); - auto result = 
std::make_shared(add_input0, multi_input1); - return result; + auto cos = + std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto sin = + std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto in_split = ov::op::util::make_split(input, 2, -1); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), std::make_shared(in_split[1], cos)); + + return std::make_shared(ov::NodeVector{res_0, res_1}, -1); } } } // namespace From 8c002b0894b53d9e37e55366c1f53518cf02cfe4 Mon Sep 17 00:00:00 2001 From: LiangGao Date: Wed, 8 Jan 2025 15:49:03 +0800 Subject: [PATCH 3/7] Use scaled_dot_product_attention to improve the performance --- .../com.microsoft/group_query_attention.cpp | 143 ++++++------------ 1 file changed, 50 insertions(+), 93 deletions(-) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index b4e85812afffcf..cc9dde3d256e0d 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -38,6 +38,7 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" #include "openvino/op/convert_like.hpp" #include "utils/split.hpp" @@ -159,12 +160,6 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { auto present_k = K; auto present_v = V; - std::shared_ptr alpha; - if (scale == 0.0f) { - alpha = std::make_shared(head_size_node); - } else { - alpha = v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f / scale}); - } const size_t kv_num_heads_factor = num_heads / kv_num_heads; if (kv_num_heads_factor > 1) { const auto kv_shape = std::make_shared(K); @@ -183,47 +178,43 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { K = std::make_shared(K, extended_kv_shape, false); V = std::make_shared(V, extended_kv_shape, false); } - // compute softmax((Q x K') / sqrt(head_size)) - std::shared_ptr softmax_input = std::make_shared(Q, K, false, true); - softmax_input = std::make_shared(softmax_input, alpha); // need to apply low-triangle mask to attention score. 
- auto past_seq_len_scalar = std::make_shared(past_sequence_length, one_without_shape, false); - auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); + // two steps, construct the total_sequence x total_sequence triangle, then slice the current length + auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); // 12 or 13 std::shared_ptr mask_per_line_node = std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), seqlens_1d_scalar, one_without_shape, - ov::element::i64); - mask_per_line_node = std::make_shared(mask_per_line_node, zero); - auto minus_inf = v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits::infinity()}); - auto mask_shape = std::make_shared(ov::NodeVector{current_seqlen_size, seqlens_1d}, 0); - auto compare_mask = std::make_shared(mask_per_line_node, mask_shape); - - std::shared_ptr vertical_range = - std::make_shared(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64); - vertical_range = std::make_shared(vertical_range, one); - - auto triu = std::make_shared(compare_mask, vertical_range); - auto typed_zero = std::make_shared(zero, softmax_input); - auto typed_minus_inf = std::make_shared(minus_inf, softmax_input); - auto minus_inf_mask = std::make_shared(typed_minus_inf, mask_shape); - auto atten_mask = std::make_shared(triu, minus_inf_mask, typed_zero); - - std::shared_ptr softmax_input_added = std::make_shared(softmax_input, atten_mask); - // softmax((Q x K' + mask) / sqrt(head_size)) - const auto softmax = std::make_shared(softmax_input_added, 3); - - // softmax((Q x K' + mask) / sqrt(head_size)) x V - std::shared_ptr output = std::make_shared(softmax, V); + ov::element::i64); // [0,1,2,...,] + auto hori_range = std::make_shared(mask_per_line_node, zero); // 1x12 or 1x13 + auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 + auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 + auto typed_zero = v0::Constant::create(ov::element::f32, ov::Shape{}, {0}); + auto minus_inf = v0::Constant::create(ov::element::f32, ov::Shape{}, {-std::numeric_limits::infinity()}); + auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 + auto atten_mask_sliced = std::make_shared(atten_mask, + past_sequence_length, + seqlens_1d, + one, + zero); // slice to current query seqlen, 12x12 or 1x13 + + // compute softmax((Q x K') / sqrt(head_size)) x V + std::shared_ptr qga_output; + if (scale != 0.0f) { + auto scale_node = v0::Constant::create(ov::element::f32, Shape{}, {scale}); + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, scale_node, false); + } else { + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, false); + } // transpose the result from (batch_size, num_heads, sequence_length, head_size) // to (batch_size, sequence_length, num_heads, head_size) - output = std::make_shared(output, perm); + auto qga_output_transposed = std::make_shared(qga_output, perm); auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); // reshape the result from (batch_size, sequence_length, num_heads, head_size) // to (batch_size, sequence_length, num_heads * head_size) - output = std::make_shared(output, dim_merge_shape, true); + auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true); return {output, present_k, present_v}; } @@ -254,75 +245,41 @@ std::shared_ptr rotaryEmbedding(ov::Output input, bool interleaved) { auto zero = 
v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); - auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - auto cache_shape = std::make_shared(cos_cache); - auto cache_last_dim = get_dimensions(cache_shape, {-1}); - auto cache_1st_dim = get_dimensions(cache_shape, {0}); + auto slice_cache_dim_shape = seqlen_k; - // TODO: check the shape - auto input_shape = std::make_shared(input); + auto cos = std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto sin = std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); - auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); - // auto dim_head_size = get_dimensions(input_shape, {3}); - // half_last_dim is same as cos_cache - std::shared_ptr half_last_dim = cache_last_dim; + if (interleaved) { + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - auto real_cache_shape = std::make_shared(ov::NodeVector{cache_1st_dim, dim_head_size}, 0); - auto slice_cache_dim_shape = seqlen_k; + auto cache_shape = std::make_shared(cos_cache); + auto cache_last_dim = get_dimensions(cos_cache, {-1}); - // auto end_lens = std::make_shared(half_last_dim, one); - // auto masks = std::make_shared(one, - // zero, - // end_lens, - // op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}), - // ov::op::PadMode::CONSTANT); - auto masks = std::make_shared(one, half_last_dim); + auto input_shape = std::make_shared(input); + + auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); + std::shared_ptr half_last_dim = cache_last_dim; - if (interleaved) { auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); auto reshaped_input = std::make_shared(input, split_input_shape, false); - auto first_half = std::make_shared(reshaped_input, zero, one, one, negtive_one); - auto second_half = std::make_shared(reshaped_input, one, two, one, negtive_one); - - auto second_input = std::make_shared(ov::NodeVector{second_half, first_half}, -1); - - auto mask_shape = std::make_shared(ov::NodeVector{half_last_dim, one}, 0); - auto reshaped_mask = std::make_shared(masks, mask_shape, false); - auto negtive_mask = std::make_shared(reshaped_mask); - auto concat_mask = std::make_shared(ov::NodeVector{negtive_mask, reshaped_mask}, -1); - auto real_mask = std::make_shared(concat_mask, dim_head_size, false); - auto mask_f32 = std::make_shared(real_mask, ov::element::f32); - - auto real_input0 = std::make_shared(reshaped_input, input_shape, false); - auto real_input1 = std::make_shared(second_input, input_shape, false); - - auto new_cache_shape = std::make_shared(ov::NodeVector{cache_shape, two}, 0); - auto temp_cache_shape = std::make_shared(ov::NodeVector{cache_shape, one}, 0); - auto cos_cache_reshape = std::make_shared(cos_cache, temp_cache_shape, false); - auto sin_cache_reshape = std::make_shared(sin_cache, temp_cache_shape, false); - auto cos_cache_broadcasted = std::make_shared(cos_cache_reshape, new_cache_shape); - auto sin_cache_broadcasted = std::make_shared(sin_cache_reshape, new_cache_shape); - auto real_cos_input = std::make_shared(cos_cache_broadcasted, real_cache_shape, false); - auto real_sin_input = std::make_shared(sin_cache_broadcasted, real_cache_shape, false); - auto sliced_cos_input = - std::make_shared(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto sliced_sin_input = - 
std::make_shared(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero); - auto add_input0 = std::make_shared(real_input0, sliced_cos_input); - auto add_input1 = std::make_shared(real_input1, sliced_sin_input); - auto multi_input1 = std::make_shared(add_input1, mask_f32); - auto result = std::make_shared(add_input0, multi_input1); - return result; + + auto in_split = ov::op::util::make_split(reshaped_input, 2, -1); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), + std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), + std::make_shared(in_split[1], cos)); + + auto concat_ret = std::make_shared(ov::NodeVector{res_0, res_1}, -1); + return std::make_shared(concat_ret, input_shape, false); } else { - auto cos = - std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); - auto sin = - std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); auto in_split = ov::op::util::make_split(input, 2, -1); - auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), std::make_shared(in_split[1], sin)); - auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), std::make_shared(in_split[1], cos)); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), + std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), + std::make_shared(in_split[1], cos)); return std::make_shared(ov::NodeVector{res_0, res_1}, -1); } From 651be6064d4389374d61266380d9e18a02f02809 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 16 Jan 2025 13:15:18 +0800 Subject: [PATCH 4/7] Add OV GQA op and decomposition pass Fix interleave logic in decomposition Add ONNX frontend tests --- .../group_query_attention_decomposition.hpp | 24 + .../common_optimizations.cpp | 2 + .../group_query_attention_decomposition.cpp | 322 ++++++++++++++ .../openvino/op/group_query_attention.hpp | 54 +++ src/core/include/openvino/op/null.hpp | 47 ++ src/core/include/openvino/op/ops.hpp | 2 + .../include/openvino/opsets/opset15_tbl.hpp | 2 + src/core/src/op/group_query_attention.cpp | 92 ++++ src/core/tests/opset.cpp | 2 +- .../com.microsoft/group_query_attention.cpp | 277 +----------- .../gqa_past_0_input_1_rotary.prototxt | 247 +++++++++++ ...past_0_input_1_rotary_interleaved.prototxt | 247 +++++++++++ .../gqa_past_1_input_1_rotary.prototxt | 244 +++++++++++ ...past_1_input_1_rotary_interleaved.prototxt | 244 +++++++++++ .../tests/onnx_import_com_microsoft.in.cpp | 412 ++++++++++++++++++ 15 files changed, 1957 insertions(+), 261 deletions(-) create mode 100644 src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp create mode 100644 src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp create mode 100644 src/core/include/openvino/op/group_query_attention.hpp create mode 100644 src/core/include/openvino/op/null.hpp create mode 100644 src/core/src/op/group_query_attention.cpp create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt create mode 100644 src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt diff --git 
a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp new file mode 100644 index 00000000000000..51c21f4808c61e --- /dev/null +++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/group_query_attention.hpp" +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API GroupQueryAttentionDecomposition; + +} // namespace pass +} // namespace ov + +class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("GroupQueryAttentionDecomposition", "0"); + GroupQueryAttentionDecomposition(); + ov::OutputVector decompose(std::shared_ptr node); +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 87813bae65538d..66651b3907f344 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -108,6 +108,7 @@ #include "transformations/op_conversions/eye_decomposition.hpp" #include "transformations/op_conversions/gelu7_downgrade.hpp" #include "transformations/op_conversions/group_normalization_decomposition.hpp" +#include "transformations/op_conversions/group_query_attention_decomposition.hpp" #include "transformations/op_conversions/hsigmoid_decomposition.hpp" #include "transformations/op_conversions/hswish_decomposition.hpp" #include "transformations/op_conversions/log_softmax_decomposition.hpp" @@ -156,6 +157,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr(); + ADD_MATCHER(decomp, GroupQueryAttentionDecomposition) ADD_MATCHER(decomp, ScaledDotProductAttentionDecomposition) ADD_MATCHER(decomp, Gelu7Downgrade) ADD_MATCHER(decomp, BidirectionalSequenceDecomposition) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp new file mode 100644 index 00000000000000..1441e3e67a633d --- /dev/null +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -0,0 +1,322 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/group_query_attention_decomposition.hpp" + +#include + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/null.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include 
"openvino/op/split.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/op/variadic_split.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved); +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims); +ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); +ov::OutputVector make_split(const ov::Output& value, const std::vector& split_lengths, int64_t axis); +std::shared_ptr create_minus_inf(const ov::element::Type& T); + +ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { + MATCHER_SCOPE(GroupQeuryAttentionDecomposition); + auto pattern_node = ov::pass::pattern::wrap_type(); + + matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto node = + ov::as_type_ptr(pattern_to_output.at(pattern_node).get_node_shared_ptr()); + + if (node == nullptr || transformation_callback(node)) { + return false; + } + + auto new_output_node = decompose(node); + ov::replace_node(node, new_output_node); + return true; + }; + + auto m = std::make_shared(pattern_node, matcher_name); + register_matcher(m, callback); +} + +ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( + std::shared_ptr node) { + using namespace ov::op; + + const auto num_heads = node->get_num_heads(); + const auto kv_num_heads = node->get_kv_num_heads(); + const auto scale = node->get_scale(); + const auto do_rotary = node->get_do_rotary(); + const auto rotary_interleaved = node->get_rotary_interleaved(); + // TODO: add softcap support + + auto Q = node->input_value(0); + auto K = node->input_value(1); + auto V = node->input_value(2); + auto past_key = node->input_value(3); + auto past_value = node->input_value(4); + auto seqlens_k = node->input_value(5); + auto total_sequence_length = node->input_value(6); // unused, it's not always equal (seqlens_k + 1) + auto cos_cache = node->input_value(7); + auto sin_cache = node->input_value(8); + + // The length of all tokens (past + current) is `seqlens_k` + 1, current = Q.shape[2], past = `seqlens_k` + 1 - current + + const auto T = Q.get_element_type(); + const auto node_shape = std::make_shared(Q); + const auto batch_size = get_dimensions(node_shape, {0}); + const auto current_seqlen_size = get_dimensions(node_shape, {1}); + const auto hidden_size = get_dimensions(node_shape, {2}); + const auto total_num_heads_node = + v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); + auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); // should be equal to the last dim of past_key + + // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + + if (v15::Null::is_null(K)) { + // Handle the packed QKV + auto packed_qkv_shape = std::make_shared( + ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, + 0); + auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); + // (batch_size, sequence_len, num_head, head_size) + inputs_qkv = std::make_shared(inputs_qkv, perm); + // split the node 
into 3 even parts Q, K, V with shape (batch_size, num_head, sequence_len, head_size) + auto split = make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); + Q = split[0]; + K = split[1]; + V = split[2]; + } else { + Q = std::make_shared(Q, perm); + K = std::make_shared(K, perm); + V = std::make_shared(V, perm); + } + + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + auto seqlens_elemi64 = std::make_shared(seqlens_k, ov::element::i64); + auto real_seqlens = std::make_shared(seqlens_elemi64, one); + + // Only consider batch is 1 + auto seqlens_1d = std::make_shared(real_seqlens, one, false); + auto past_sequence_length = std::make_shared(seqlens_1d, current_seqlen_size); + + if (do_rotary) { + Q = rotaryEmbedding(Q, + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); + K = rotaryEmbedding(K, + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); + } + // present = concat(K, V) if 'past' input is unavailable + // or + // present = concat(past, K, V) + auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { + auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); + auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); + return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); + }; + + K = construct_kv_cache(past_key, K); + V = construct_kv_cache(past_value, V); + auto present_k = K; + auto present_v = V; + + const size_t kv_num_heads_factor = num_heads / kv_num_heads; + if (kv_num_heads_factor > 1) { + const auto kv_shape = std::make_shared(K); + // (batch_size, num_heads, sequence_len, head_size) + const auto kv_shape_prev_2 = get_dimensions(kv_shape, {0, 1}); + const auto kv_shape_last_2 = get_dimensions(kv_shape, {2, 3}); + auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); + K = std::make_shared(K, new_kv_shape, false); + V = std::make_shared(V, new_kv_shape, false); + K = std::make_shared(ov::OutputVector(kv_num_heads_factor, K), 2); + V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); + auto q_shape = std::make_shared(Q); + // (batch_size, num_heads, sequence_len, head_size) + const auto q_shape_prev_2 = get_dimensions(q_shape, {0, 1}); + auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); + K = std::make_shared(K, extended_kv_shape, false); + V = std::make_shared(V, extended_kv_shape, false); + } + + // need to apply low-triangle mask to attention score. 
+ // two steps, construct the total_sequence x total_sequence triangle, then slice the current length + auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); // 12 or 13 + std::shared_ptr mask_per_line_node = + std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), + seqlens_1d_scalar, + one_without_shape, + ov::element::i64); // [0,1,2,...,] + auto hori_range = std::make_shared(mask_per_line_node, zero); // 1x12 or 1x13 + auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 + auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 + auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0}); + auto minus_inf = create_minus_inf(T); + auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 + auto atten_mask_sliced = std::make_shared(atten_mask, + past_sequence_length, + seqlens_1d, + one, + zero); // slice to current query seqlen, 12x12 or 1x13 + + // compute softmax((Q x K') / sqrt(head_size)) x V + std::shared_ptr qga_output; + if (scale != 0.0f) { + auto scale_node = v0::Constant::create(T, Shape{}, {scale}); + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, scale_node, false); + } else { + qga_output = std::make_shared(Q, K, V, atten_mask_sliced, false); + } + + // transpose the result from (batch_size, num_heads, sequence_length, head_size) + // to (batch_size, sequence_length, num_heads, head_size) + auto qga_output_transposed = std::make_shared(qga_output, perm); + auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); + // reshape the result from (batch_size, sequence_length, num_heads, head_size) + // to (batch_size, sequence_length, num_heads * head_size) + auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true)->output(0); + + return {output, present_k, present_v}; +} + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims) { + using namespace ov::op; + static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { + return get_dimensions(std::make_shared(node), dims); +} + +std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved) { + using namespace ov::op; + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + + auto slice_cache_dim_shape = seqlen_k; + + auto cos = std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto sin = std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); + + if (interleaved) { + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + + auto cache_shape = std::make_shared(cos_cache); + auto cache_last_dim = get_dimensions(cos_cache, {-1}); + + auto input_shape = std::make_shared(input); + + auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); + std::shared_ptr half_last_dim = cache_last_dim; + + auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); + auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); + auto 
reshaped_input = std::make_shared(input, split_input_shape, false); + + auto in_split = make_split(reshaped_input, 2, -1); + split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim}, 0); + auto in_split_0 = std::make_shared(in_split[0], split_input_shape, false); + auto in_split_1 = std::make_shared(in_split[1], split_input_shape, false); + + auto res_0 = std::make_shared(std::make_shared(in_split_0, cos), + std::make_shared(in_split_1, sin)); + auto res_1 = std::make_shared(std::make_shared(in_split_0, sin), + std::make_shared(in_split_1, cos)); + + split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, one}, 0); + auto res_0_5d = std::make_shared(res_0, split_input_shape, false); + auto res_1_5d = std::make_shared(res_1, split_input_shape, false); + + auto concat_ret = std::make_shared(ov::NodeVector{res_0_5d, res_1_5d}, -1); + return std::make_shared(concat_ret, input_shape, false); + } else { + auto in_split = make_split(input, 2, -1); + auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), + std::make_shared(in_split[1], sin)); + auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), + std::make_shared(in_split[1], cos)); + + return std::make_shared(ov::NodeVector{res_0, res_1}, -1); + } +} + +// The make_split functions are copy-pasted from the ONNX FE. TODO: move them to one place +ov::OutputVector make_split(const ov::Output& value, + const std::vector& split_lengths, + int64_t axis) { + using namespace ov::op; + const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}); + const auto split_lengths_node = + v0::Constant::create(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths); + const auto variadic_split = std::make_shared(value, axis_node, split_lengths_node); + + return variadic_split->outputs(); +} + +ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis) { + using namespace ov::op; + const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}); + const auto split = std::make_shared(value, axis_node, num_splits); + + return split->outputs(); +} + +std::shared_ptr create_minus_inf(const ov::element::Type& T) { + // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp + if (T == ov::element::f32) { + return ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()}); + } else if (T == ov::element::f16) { + return ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()}); + } else { + OPENVINO_THROW("GroupQueryAttention only supports f32 and f16"); + } +} diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp new file mode 100644 index 00000000000000..25337e0faf9be5 --- /dev/null +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { +namespace v15 { + +// This is an experimental operation that is implemented in the plugins. 
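+// Inputs, in the order consumed by the decomposition pass above: 0: Q (a packed QKV tensor when K and V are Null), 1: K, 2: V, 3: past_key, 4: past_value, 5: seqlens_k, 6: total_sequence_length (unused), 7: cos_cache, 8: sin_cache. +// Outputs: attention output, present key cache, present value cache.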
+class OPENVINO_API GroupQueryAttention : public Op { +public: + OPENVINO_OP("GroupQueryAttention", "opset15", op::Op); + + GroupQueryAttention() = default; + GroupQueryAttention(const ov::OutputVector& args, + unsigned int num_heads, + unsigned int kv_num_heads, + float scale, + bool do_rotary, + bool rotary_interleaved); + void validate_and_infer_types() override; + bool visit_attributes(AttributeVisitor& visitor) override; + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + + unsigned int get_num_heads() const { + return m_num_heads; + } + unsigned int get_kv_num_heads() const { + return m_kv_num_heads; + } + float get_scale() const { + return m_scale; + } + bool get_do_rotary() const { + return m_do_rotary; + } + bool get_rotary_interleaved() const { + return m_rotary_interleaved; + } + +private: + unsigned int m_num_heads; + unsigned int m_kv_num_heads; + float m_scale = 0; + bool m_do_rotary = false; + bool m_rotary_interleaved = false; +}; + +} // namespace v15 +} // namespace op +} // namespace ov diff --git a/src/core/include/openvino/op/null.hpp b/src/core/include/openvino/op/null.hpp new file mode 100644 index 00000000000000..346ff3d77c532b --- /dev/null +++ b/src/core/include/openvino/op/null.hpp @@ -0,0 +1,47 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { +namespace v15 { + +/// \brief Represents a missing optional input or output of an ONNX node +/// +/// Some ONNX operators have inputs or outputs that are marked as optional, +/// which means that a referring node MAY forgo providing values for such inputs +/// or computing these outputs. +/// An empty string is used in place of a name of such input or output. 
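+/// For example, the com.microsoft GroupQueryAttention node passes empty strings for
+/// its key and value inputs when QKV is packed into the query tensor; the ONNX
+/// frontend maps each such missing input to a Null node.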
+/// +/// More: +/// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs +class OPENVINO_API Null : public Op { +public: + OPENVINO_OP("Null", "opset15", op::Op); + Null() { + set_output_size(1); + } + + static bool is_null(const ov::Node* node) { + return ov::as_type(node) != nullptr; + } + + static bool is_null(const std::shared_ptr& node) { + return is_null(node.get()); + } + + static bool is_null(const Output& output) { + return is_null(output.get_node()); + } + + virtual std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override { + return std::make_shared(); + } +}; +} // namespace v15 +} // namespace op +} // namespace ov diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp index 1df29258cb8928..f87e14f65e9bf9 100644 --- a/src/core/include/openvino/op/ops.hpp +++ b/src/core/include/openvino/op/ops.hpp @@ -168,6 +168,8 @@ #include "openvino/op/roll.hpp" #include "openvino/op/round.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/null.hpp" +#include "openvino/op/group_query_attention.hpp" #include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/scatter_nd_update.hpp" #include "openvino/op/scatter_update.hpp" diff --git a/src/core/include/openvino/opsets/opset15_tbl.hpp b/src/core/include/openvino/opsets/opset15_tbl.hpp index a9e8d2a8dcc840..8dd7c0d7b830cc 100644 --- a/src/core/include/openvino/opsets/opset15_tbl.hpp +++ b/src/core/include/openvino/opsets/opset15_tbl.hpp @@ -234,3 +234,5 @@ _OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15) _OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15) _OPENVINO_OP_REG(SliceScatter, ov::op::v15) _OPENVINO_OP_REG(SearchSorted, ov::op::v15) +_OPENVINO_OP_REG(GroupQueryAttention, ov::op::v15) +_OPENVINO_OP_REG(Null, ov::op::v15) diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp new file mode 100644 index 00000000000000..475110e66bf5e3 --- /dev/null +++ b/src/core/src/op/group_query_attention.cpp @@ -0,0 +1,92 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/group_query_attention.hpp" + +#include "itt.hpp" +#include "openvino/op/null.hpp" + +using namespace std; +namespace ov { +namespace op { +namespace v15 { + +GroupQueryAttention::GroupQueryAttention(const OutputVector& args, + unsigned int num_heads, + unsigned int kv_num_heads, + float scale, + bool do_rotary, + bool rotary_interleaved) + : Op(args), + m_num_heads(num_heads), + m_kv_num_heads(kv_num_heads), + m_scale(scale), + m_do_rotary(do_rotary), + m_rotary_interleaved(rotary_interleaved) { + constructor_validate_and_infer_types(); +} + +int64_t get_head_size(const PartialShape& input_shape, int num_heads, int kv_num_heads) { + return input_shape[2].get_length() / (num_heads + kv_num_heads * 2); +} + +std::vector get_qkv_sizes(const PartialShape& input_shape, int num_heads, int kv_num_heads) { + int64_t per_head_size = get_head_size(input_shape, num_heads, kv_num_heads); + const std::vector qkv_sizes = {num_heads * per_head_size, + kv_num_heads * per_head_size, + kv_num_heads * per_head_size}; + return qkv_sizes; +} + +void GroupQueryAttention::validate_and_infer_types() { + OV_OP_SCOPE(v15_GroupQueryAttention_validate_and_infer_types); + PartialShape input_shape = get_input_partial_shape(0); + Dimension batch_size = input_shape[0]; + Dimension sequence_len = input_shape[1]; + Dimension head_size; + if (Null::is_null(input_value(1)) && 
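+    // Worked example with the numbers used by the gqa_* test models: packed QKV with
+    // hidden_size = 64, num_heads = 2, kv_num_heads = 1 gives
+    // head_size = 64 / (2 + 1 * 2) = 16.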
Null::is_null(input_value(2))) {
+        head_size = get_head_size(input_shape, m_num_heads, m_kv_num_heads);
+    } else {
+        head_size = input_shape[2].get_length() / m_num_heads;
+    }
+    Dimension output_kv_len;
+    PartialShape kv_past_shape = get_input_partial_shape(3);
+    // FIXME: https://github.com/openvinotoolkit/openvino/pull/27648
+    if (kv_past_shape[2].is_static()) {
+        output_kv_len = kv_past_shape[2] + sequence_len;
+    } else {
+        output_kv_len = ov::Dimension();
+    }
+    auto element_type = get_input_element_type(0);
+    NODE_VALIDATION_CHECK(this,
+                          element_type == element::f32 || element_type == element::f16,
+                          "GroupQueryAttention only supports f32 and f16");
+    set_output_type(0, element_type, PartialShape{batch_size, sequence_len, head_size * m_num_heads});
+    set_output_type(1, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size});
+    set_output_type(2, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size});
+}
+
+bool GroupQueryAttention::visit_attributes(AttributeVisitor& visitor) {
+    OV_OP_SCOPE(v15_GroupQueryAttention_visit_attributes);
+    visitor.on_attribute("do_rotary", m_do_rotary);
+    visitor.on_attribute("kv_num_heads", m_kv_num_heads);
+    visitor.on_attribute("num_heads", m_num_heads);
+    visitor.on_attribute("rotary_interleaved", m_rotary_interleaved);
+    visitor.on_attribute("scale", m_scale);
+    return true;
+}
+
+std::shared_ptr<Node> GroupQueryAttention::clone_with_new_inputs(const ov::OutputVector& new_args) const {
+    OV_OP_SCOPE(v15_GroupQueryAttention_clone_with_new_inputs);
+    return std::make_shared<GroupQueryAttention>(new_args,
+                                                 m_num_heads,
+                                                 m_kv_num_heads,
+                                                 m_scale,
+                                                 m_do_rotary,
+                                                 m_rotary_interleaved);
+}
+
+} // namespace v15
+} // namespace op
+} // namespace ov
diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp
index dc24a832560464..5f1f60b733c915 100644
--- a/src/core/tests/opset.cpp
+++ b/src/core/tests/opset.cpp
@@ -76,7 +76,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
                                          OpsetTestParams{ov::get_opset12, 178},
                                          OpsetTestParams{ov::get_opset13, 186},
                                          OpsetTestParams{ov::get_opset14, 188},
-                                         OpsetTestParams{ov::get_opset15, 199},
+                                         OpsetTestParams{ov::get_opset15, 201},
                                          OpsetTestParams{ov::get_opset16, 5}),
                         OpsetTestNameGenerator{});
diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index cc9dde3d256e0d..692d2529caf7e5 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -2,45 +2,12 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
+#include "openvino/op/group_query_attention.hpp"
+#include "openvino/op/null.hpp"
+
 #include "core/null_node.hpp"
 #include "core/operator_set.hpp"
 #include "openvino/frontend/exception.hpp"
-#include "openvino/op/add.hpp"
-#include "openvino/op/broadcast.hpp"
-#include "openvino/op/concat.hpp"
-#include "openvino/op/constant.hpp"
-#include "openvino/op/convert.hpp"
-#include "openvino/op/divide.hpp"
-#include "openvino/op/equal.hpp"
-#include "openvino/op/floor.hpp"
-#include "openvino/op/floor_mod.hpp"
-#include "openvino/op/gather.hpp"
-#include "openvino/op/greater.hpp"
-#include "openvino/op/greater_eq.hpp"
-#include "openvino/op/less.hpp"
-#include "openvino/op/less_eq.hpp"
-#include "openvino/op/log.hpp"
-#include "openvino/op/logical_not.hpp"
-#include "openvino/op/logical_or.hpp"
-#include "openvino/op/matmul.hpp"
-#include "openvino/op/maximum.hpp"
-#include 
"openvino/op/multiply.hpp" -#include "openvino/op/negative.hpp" -#include "openvino/op/pad.hpp" -#include "openvino/op/range.hpp" -#include "openvino/op/reshape.hpp" -#include "openvino/op/select.hpp" -#include "openvino/op/shape_of.hpp" -#include "openvino/op/slice.hpp" -#include "openvino/op/softmax.hpp" -#include "openvino/op/sqrt.hpp" -#include "openvino/op/squeeze.hpp" -#include "openvino/op/subtract.hpp" -#include "openvino/op/transpose.hpp" -#include "openvino/op/unsqueeze.hpp" -#include "openvino/op/scaled_dot_product_attention.hpp" -#include "openvino/op/convert_like.hpp" -#include "utils/split.hpp" using namespace ov::op; using ov::Shape; @@ -49,243 +16,33 @@ namespace ov { namespace frontend { namespace onnx { namespace com_microsoft { -namespace detail { -namespace { - -std::shared_ptr get_present_state(const std::shared_ptr& K, - const std::shared_ptr& V, - const ov::OutputVector& op_inputs); -std::shared_ptr rotaryEmbedding(ov::Output input, - ov::Output past_seqlen, - std::shared_ptr seqlen_k, - std::shared_ptr cos_cache, - std::shared_ptr sin_cache, - std::shared_ptr dim_head_size, - bool interleaved); -std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); -} // namespace -} // namespace detail namespace opset_1 { ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { - const auto do_rotary = node.get_attribute_value("do_rotary"); + const auto onnx_op_inputs = node.get_ov_inputs(); const auto num_heads = node.get_attribute_value("num_heads"); const auto kv_num_heads = node.get_attribute_value("kv_num_heads"); const auto scale = node.get_attribute_value("scale", 0.0f); - const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved"); - // TODO: add softcap support - - auto nodes = node.get_ov_inputs(); - const auto node_shape = std::make_shared(nodes[0]); - const auto batch_size = detail::get_dimensions(node_shape, {0}); - const auto current_seqlen_size = detail::get_dimensions(node_shape, {1}); - const auto hidden_size = detail::get_dimensions(node_shape, {2}); - const auto total_num_heads_node = - v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); - auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); - - // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) - auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); - - // Q K V (batch_size, sequence_len, num_heads, head_size) - ov::Output Q, K, V; - int index = 0; - Q = nodes[index++]; - K = nodes[index++]; - V = nodes[index++]; - if (ov::op::util::is_null(K)) { - // Handle the packed QKV - auto packed_qkv_shape = std::make_shared( - ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node}, - 0); - auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); - // (batch_size, sequence_len, num_head, head_size) - inputs_qkv = std::make_shared(inputs_qkv, perm); - // split the node into 3 even parts Q, K, V with shape (batch_size, num_head, sequence_len, head_size) - auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); - Q = split[0]; - K = split[1]; - V = split[2]; - } else { - Q = std::make_shared(Q, perm); - K = std::make_shared(K, perm); - V = std::make_shared(V, perm); - } - - const auto& past_key = nodes[index++]; - const auto& past_value = nodes[index++]; - const auto& seqlens_k = nodes[index++].get_node_shared_ptr(); - const auto& 
total_sequence_length = nodes[index++]; // unused, it's not always equal (seqlens_k + 1) - const auto& cos_cache = nodes[index++].get_node_shared_ptr(); - const auto& sin_cache = nodes[index++].get_node_shared_ptr(); - - auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); - auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); - auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); - auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - auto seqlens_elemi64 = std::make_shared(seqlens_k, ov::element::i64); - auto real_seqlens = std::make_shared(seqlens_elemi64, one); - - // Only consider batch is 1 - auto seqlens_1d = std::make_shared(real_seqlens, one, false); - auto past_sequence_length = std::make_shared(seqlens_1d, current_seqlen_size); + const auto do_rotary = node.get_attribute_value("do_rotary", 0); + const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved", 0.0f); - if (do_rotary) { - Q = detail::rotaryEmbedding(Q, - past_sequence_length, - seqlens_1d, - cos_cache, - sin_cache, - head_size_node, - rotary_interleaved); - K = detail::rotaryEmbedding(K, - past_sequence_length, - seqlens_1d, - cos_cache, - sin_cache, - head_size_node, - rotary_interleaved); + OutputVector ov_op_inputs; + ov_op_inputs.reserve(onnx_op_inputs.size()); + for (const auto& input : onnx_op_inputs) { + ov_op_inputs.push_back(ov::op::util::is_null(input) ? std::make_shared() : input); } - // present = concat(K, V) if 'past' input is unavailable - // or - // present = concat(past, K, V) - auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { - auto past_datas = std::make_shared(past, zero, past_sequence_length, one, two); - auto curr_datas = std::make_shared(current, zero, current_seqlen_size, one, two); - return std::make_shared(ov::NodeVector{past_datas, curr_datas}, 2); - }; - - K = construct_kv_cache(past_key, K); - V = construct_kv_cache(past_value, V); - auto present_k = K; - auto present_v = V; - - const size_t kv_num_heads_factor = num_heads / kv_num_heads; - if (kv_num_heads_factor > 1) { - const auto kv_shape = std::make_shared(K); - // (batch_size, num_heads, sequence_len, head_size) - const auto kv_shape_prev_2 = detail::get_dimensions(kv_shape, {0, 1}); - const auto kv_shape_last_2 = detail::get_dimensions(kv_shape, {2, 3}); - auto new_kv_shape = std::make_shared(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); - K = std::make_shared(K, new_kv_shape, false); - V = std::make_shared(V, new_kv_shape, false); - K = std::make_shared(ov::OutputVector(kv_num_heads_factor, K), 2); - V = std::make_shared(ov::OutputVector(kv_num_heads_factor, V), 2); - auto q_shape = std::make_shared(Q); - // (batch_size, num_heads, sequence_len, head_size) - const auto q_shape_prev_2 = detail::get_dimensions(q_shape, {0, 1}); - auto extended_kv_shape = std::make_shared(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); - K = std::make_shared(K, extended_kv_shape, false); - V = std::make_shared(V, extended_kv_shape, false); - } - - // need to apply low-triangle mask to attention score. 
- // two steps, construct the total_sequence x total_sequence triangle, then slice the current length - auto seqlens_1d_scalar = std::make_shared(seqlens_1d, one_without_shape, false); // 12 or 13 - std::shared_ptr mask_per_line_node = - std::make_shared(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}), - seqlens_1d_scalar, - one_without_shape, - ov::element::i64); // [0,1,2,...,] - auto hori_range = std::make_shared(mask_per_line_node, zero); // 1x12 or 1x13 - auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 - auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 - auto typed_zero = v0::Constant::create(ov::element::f32, ov::Shape{}, {0}); - auto minus_inf = v0::Constant::create(ov::element::f32, ov::Shape{}, {-std::numeric_limits::infinity()}); - auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 - auto atten_mask_sliced = std::make_shared(atten_mask, - past_sequence_length, - seqlens_1d, - one, - zero); // slice to current query seqlen, 12x12 or 1x13 - - // compute softmax((Q x K') / sqrt(head_size)) x V - std::shared_ptr qga_output; - if (scale != 0.0f) { - auto scale_node = v0::Constant::create(ov::element::f32, Shape{}, {scale}); - qga_output = std::make_shared(Q, K, V, atten_mask_sliced, scale_node, false); - } else { - qga_output = std::make_shared(Q, K, V, atten_mask_sliced, false); - } - - // transpose the result from (batch_size, num_heads, sequence_length, head_size) - // to (batch_size, sequence_length, num_heads, head_size) - auto qga_output_transposed = std::make_shared(qga_output, perm); - auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); - // reshape the result from (batch_size, sequence_length, num_heads, head_size) - // to (batch_size, sequence_length, num_heads * head_size) - auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true); - - return {output, present_k, present_v}; + return std::make_shared(ov_op_inputs, + num_heads, + kv_num_heads, + scale, + do_rotary, + rotary_interleaved) + ->outputs(); } ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN); } // namespace opset_1 - -namespace detail { -namespace { - -std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { - static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); - const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); - return std::make_shared(shape, dims_const, zero); -} - -std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { - return get_dimensions(std::make_shared(node), dims); -} - -std::shared_ptr rotaryEmbedding(ov::Output input, - ov::Output past_seqlen, - std::shared_ptr seqlen_k, - std::shared_ptr cos_cache, - std::shared_ptr sin_cache, - std::shared_ptr dim_head_size, - bool interleaved) { - auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); - auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); - - auto slice_cache_dim_shape = seqlen_k; - - auto cos = std::make_shared(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); - auto sin = std::make_shared(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); - - if (interleaved) { - auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); - - auto cache_shape = std::make_shared(cos_cache); - auto cache_last_dim = get_dimensions(cos_cache, {-1}); - - auto 
input_shape = std::make_shared(input); - - auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); - std::shared_ptr half_last_dim = cache_last_dim; - - auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); - auto split_input_shape = std::make_shared(ov::NodeVector{dim_bns, half_last_dim, two}, 0); - auto reshaped_input = std::make_shared(input, split_input_shape, false); - - auto in_split = ov::op::util::make_split(reshaped_input, 2, -1); - auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), - std::make_shared(in_split[1], sin)); - auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), - std::make_shared(in_split[1], cos)); - - auto concat_ret = std::make_shared(ov::NodeVector{res_0, res_1}, -1); - return std::make_shared(concat_ret, input_shape, false); - } else { - auto in_split = ov::op::util::make_split(input, 2, -1); - auto res_0 = std::make_shared(std::make_shared(in_split[0], cos), - std::make_shared(in_split[1], sin)); - auto res_1 = std::make_shared(std::make_shared(in_split[0], sin), - std::make_shared(in_split[1], cos)); - - return std::make_shared(ov::NodeVector{res_0, res_1}, -1); - } -} -} // namespace -} // namespace detail } // namespace com_microsoft } // namespace onnx } // namespace frontend diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt new file mode 100644 index 00000000000000..a7dacf0dc94ebc --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt @@ -0,0 +1,247 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 0 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + 
dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} \ No newline at end of file diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt new file mode 100644 index 00000000000000..d1400ad344e717 --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt @@ -0,0 +1,247 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: 
"present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} \ No newline at end of file diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt new file mode 100644 index 00000000000000..f5ec39c9b0bd8e --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt @@ -0,0 +1,244 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 0 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + domain: "" + version: 21 +} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt new file mode 100644 index 00000000000000..b61cf39552efc7 --- 
/dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt @@ -0,0 +1,244 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + domain: "" + version: 21 +} diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index 47a336f1749417..37cd5449484543 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1682,3 +1682,415 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_qlinear_mul) { test_case.add_expected_output(Shape{2, 2}, expected_output); test_case.run(); } + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { + const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, 
-1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = {}; + std::vector past_value = {}; + std::vector seqlens_k = {0}; + std::vector total_sequence_length = {1}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + }; + + std::vector expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + + std::vector expected_present_key = {1.2561098, + 1.0199738, + -0.05948371, + -0.16574995, + 2.5059946, + -1.738188, + -0.03158256, + -0.35975295, + 1.0918287, + -0.90313876, + -0.4790303, + 0.67029977, + -0.87039495, + 0.7783688, + -0.81333745, + 0.89886224}; + + std::vector expected_present_value = {-0.2188, + -2.4351, + -0.0729, + -0.034, + 0.9625, + 0.3492, + -0.9215, + -0.0562, + -0.6227, + -0.4637, + 1.9218, + -0.4025, + 0.1239, + 1.1648, + 0.9234, + 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) { + const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary_interleaved.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = {}; + std::vector past_value = {}; + std::vector seqlens_k = {0}; + std::vector total_sequence_length = {1}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + }; + + std::vector expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 
1.1648, 0.9234, 1.3873}; + + std::vector expected_present_key = {2.118801, + -0.2640816, + -0.5926066, + -0.19455537, + 0.9903903, + 2.954185, + -0.35343042, + -0.07457897, + -0.25603274, + -0.03627284, + 0.56591415, + 0.02181074, + -0.1586003, + 0.96567893, + -0.8591481, + 0.85514885}; + + std::vector expected_present_value = {-0.2188, + -2.4351, + -0.0729, + -0.034, + 0.9625, + 0.3492, + -0.9215, + -0.0562, + -0.6227, + -0.4637, + 1.9218, + -0.4025, + 0.1239, + 1.1648, + 0.9234, + 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { + const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = { + -0.6136, + 0.0316, + -0.4927, + 0.2484, + 0.4397, + 0.1124, + 0.6408, + 0.4412, + -0.1023, + 0.7924, + -0.2897, + 0.0525, + 0.5229, + 2.3022, + -1.4689, + -1.5867, + }; + std::vector past_value = { + -0.5692, + 0.9200, + 1.1108, + 1.2899, + -1.4782, + 2.5672, + -0.4731, + 0.3356, + -1.6293, + -0.5497, + -0.4798, + -0.4997, + -1.0670, + 1.1149, + -0.1407, + 0.8058, + }; + std::vector seqlens_k = {1}; + std::vector total_sequence_length = {2}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + -0.9635, + -0.8046, + 0.4139, + 0.9863, + 0.4117, + 0.9874, + -0.9743, + 0.9494, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + 0.2677, + -0.5938, + -0.9103, + -0.1650, + -0.9113, + -0.1583, + 0.2253, + 0.3140, + }; + + std::vector expected_output = { + -0.53934956, 0.6341806, 1.0099611, 1.1771176, -1.270278, 2.3782496, -0.511299, 0.30222273, + -1.5435482, -0.5423737, -0.27520883, -0.4914196, -0.96554786, 1.1191509, -0.05004983, 0.85533774, + -0.49356747, 0.19581467, 0.8553029, 1.0041412, -0.9513843, 2.088453, -0.5698854, 0.25103146, + -1.4120293, -0.5311372, 0.03857604, -0.47871974, -0.8099488, 1.1256707, 0.08898184, 0.93131447}; + + std::vector expected_present_key = { + -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412, + -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867, + -1.6519198, 1.1400802, 0.45031136, 0.5877534, -0.65952265, -1.8121169, 0.04630837, 0.5568472, + 0.20271924, 0.7458131, -0.17379119, 0.3623912, 2.5696063, -0.58594, -0.8126341, -0.7919839}; + + std::vector expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, 
-1.4782, 2.5672, -0.4731, 0.3356, + -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) { + const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary_interleaved.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = { + -0.6136, + 0.0316, + -0.4927, + 0.2484, + 0.4397, + 0.1124, + 0.6408, + 0.4412, + -0.1023, + 0.7924, + -0.2897, + 0.0525, + 0.5229, + 2.3022, + -1.4689, + -1.5867, + }; + std::vector past_value = { + -0.5692, + 0.9200, + 1.1108, + 1.2899, + -1.4782, + 2.5672, + -0.4731, + 0.3356, + -1.6293, + -0.5497, + -0.4798, + -0.4997, + -1.0670, + 1.1149, + -0.1407, + 0.8058, + }; + std::vector seqlens_k = {1}; + std::vector total_sequence_length = {2}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + -0.9635, + -0.8046, + 0.4139, + 0.9863, + 0.4117, + 0.9874, + -0.9743, + 0.9494, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 0.1605, + -0.7788, + 0.3174, + -0.0872, + 0.2677, + -0.5938, + -0.9103, + -0.1650, + -0.9113, + -0.1583, + 0.2253, + 0.3140, + }; + + std::vector expected_output = { + -0.33396345, -1.332403, 0.31613833, 0.40111685, 0.16033238, 1.0781744, -0.7741276, 0.07257013, + -0.9535321, -0.491965, 1.1324831, -0.43444604, -0.2675047, 1.1483997, 0.57366973, 1.1961825, + -0.24709277, -2.164195, 0.02267693, 0.07289726, 0.7654276, 0.5282906, -0.8852943, -0.02456442, + -0.7039771, -0.47064403, 1.7278847, -0.41034833, 0.02774171, 1.1607709, 0.83748007, 1.3403473}; + + std::vector expected_present_key = { + -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412, + -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867, + -1.2216992, 1.7511603, 0.03145146, -0.62293506, -2.625969, 1.6767058, -0.17887366, 0.313817, + 0.1717277, -0.19334024, 0.4056727, 0.39516917, -0.25018305, 0.9460988, 1.0327814, -0.6345757}; + + std::vector expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356, + -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058, + -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562, + -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 
0.9234, 1.3873}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(query); + test_case.add_input(past_key); + test_case.add_input(past_value); + test_case.add_input(seqlens_k); + test_case.add_input(total_sequence_length); + test_case.add_input(cos_cache); + test_case.add_input(sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} From df91594c59fc60362af052dfe27c6802899a7fab Mon Sep 17 00:00:00 2001 From: Zijun Yu Date: Fri, 7 Feb 2025 09:51:51 +0800 Subject: [PATCH 5/7] Apply suggestions from code review Co-authored-by: Tomasz Jankowski --- .../op_conversions/group_query_attention_decomposition.hpp | 2 +- src/core/include/openvino/op/group_query_attention.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp index 51c21f4808c61e..3cad7ab229b110 100644 --- a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp +++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp @@ -18,7 +18,7 @@ class TRANSFORMATIONS_API GroupQueryAttentionDecomposition; class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("GroupQueryAttentionDecomposition", "0"); + OPENVINO_MATCHER_PASS_RTTI("GroupQueryAttentionDecomposition"); GroupQueryAttentionDecomposition(); ov::OutputVector decompose(std::shared_ptr node); }; diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp index 25337e0faf9be5..744dac6aebd82a 100644 --- a/src/core/include/openvino/op/group_query_attention.hpp +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -12,7 +12,7 @@ namespace v15 { // This is an experimental operation that is implemented in the plugins. 
class OPENVINO_API GroupQueryAttention : public Op { public: - OPENVINO_OP("GroupQueryAttention", "opset15", op::Op); + OPENVINO_OP("GroupQueryAttention", "opset15"); GroupQueryAttention() = default; GroupQueryAttention(const ov::OutputVector& args, From 080fdd443e8358a9a9e4b42d2b91f080a7ed5e50 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 7 Feb 2025 11:06:42 +0800 Subject: [PATCH 6/7] Remove redundant type check --- .../group_query_attention_decomposition.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 1441e3e67a633d..5f1476e4bceae8 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -41,7 +41,6 @@ std::shared_ptr get_dimensions(const std::shared_ptr& dims); ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); ov::OutputVector make_split(const ov::Output& value, const std::vector& split_lengths, int64_t axis); -std::shared_ptr create_minus_inf(const ov::element::Type& T); ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { MATCHER_SCOPE(GroupQeuryAttentionDecomposition); @@ -191,7 +190,12 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto vert_range = std::make_shared(mask_per_line_node, one); // 12x1 or 13x1 auto triu = std::make_shared(hori_range, vert_range); // 12x12 or 13x13 auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0}); - auto minus_inf = create_minus_inf(T); + // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp + std::shared_ptr minus_inf = nullptr; + if (T == ov::element::f32) + minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()}); + else if (T == ov::element::f16) + minus_inf = ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()}); auto atten_mask = std::make_shared(triu, minus_inf, typed_zero); // 12x12 or 13x13 auto atten_mask_sliced = std::make_shared(atten_mask, past_sequence_length, @@ -309,14 +313,3 @@ ov::OutputVector make_split(const ov::Output& value, int64_t num_split return split->outputs(); } - -std::shared_ptr create_minus_inf(const ov::element::Type& T) { - // cf. 
make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp - if (T == ov::element::f32) { - return ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()}); - } else if (T == ov::element::f16) { - return ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()}); - } else { - OPENVINO_THROW("GroupQueryAttention only supports f32 and f16"); - } -} From 7977828d2dbd5add7ebbe2c80e1aa4cd7d0b899e Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 7 Feb 2025 17:14:35 +0800 Subject: [PATCH 7/7] set input total_sequence_length to Null --- .../group_query_attention_decomposition.cpp | 1 - .../include/openvino/op/group_query_attention.hpp | 12 ++++++------ src/core/src/op/group_query_attention.cpp | 4 ++-- .../src/op/com.microsoft/group_query_attention.cpp | 2 ++ 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index 5f1476e4bceae8..4dae56ef8b348b 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -81,7 +81,6 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto past_key = node->input_value(3); auto past_value = node->input_value(4); auto seqlens_k = node->input_value(5); - auto total_sequence_length = node->input_value(6); // unused, it's not always equal (seqlens_k + 1) auto cos_cache = node->input_value(7); auto sin_cache = node->input_value(8); diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp index 744dac6aebd82a..1efdfa53a07e3b 100644 --- a/src/core/include/openvino/op/group_query_attention.hpp +++ b/src/core/include/openvino/op/group_query_attention.hpp @@ -16,8 +16,8 @@ class OPENVINO_API GroupQueryAttention : public Op { GroupQueryAttention() = default; GroupQueryAttention(const ov::OutputVector& args, - unsigned int num_heads, - unsigned int kv_num_heads, + int64_t num_heads, + int64_t kv_num_heads, float scale, bool do_rotary, bool rotary_interleaved); @@ -25,10 +25,10 @@ class OPENVINO_API GroupQueryAttention : public Op { bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; - unsigned int get_num_heads() const { + int64_t get_num_heads() const { return m_num_heads; } - unsigned int get_kv_num_heads() const { + int64_t get_kv_num_heads() const { return m_kv_num_heads; } float get_scale() const { @@ -42,8 +42,8 @@ class OPENVINO_API GroupQueryAttention : public Op { } private: - unsigned int m_num_heads; - unsigned int m_kv_num_heads; + int64_t m_num_heads; + int64_t m_kv_num_heads; float m_scale = 0; bool m_do_rotary = false; bool m_rotary_interleaved = false; diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp index 475110e66bf5e3..448ee2471b581a 100644 --- a/src/core/src/op/group_query_attention.cpp +++ b/src/core/src/op/group_query_attention.cpp @@ -13,8 +13,8 @@ namespace op { namespace v15 { GroupQueryAttention::GroupQueryAttention(const OutputVector& args, - unsigned int num_heads, - unsigned int kv_num_heads, + int64_t num_heads, + int64_t kv_num_heads, float scale, bool do_rotary, bool 
rotary_interleaved) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp index 692d2529caf7e5..696360e7475201 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -31,6 +31,8 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { for (const auto& input : onnx_op_inputs) { ov_op_inputs.push_back(ov::op::util::is_null(input) ? std::make_shared() : input); } + // total_sequence_length is not used currently in OV GQA + ov_op_inputs[6] = std::make_shared(); return std::make_shared(ov_op_inputs, num_heads, kv_num_heads,
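+    // Note: ov_op_inputs[6] (total_sequence_length) can be replaced with Null here
+    // because the decomposition derives the total length from seqlens_k instead;
+    // e.g. in the tests above, seqlens_k = {1} implies a total sequence length of 2
+    // (seqlens_k + 1, i.e. past length 1 plus the single new token).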