diff --git a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp
new file mode 100644
index 00000000000000..51c21f4808c61e
--- /dev/null
+++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/op/group_query_attention.hpp"
+#include "openvino/pass/matcher_pass.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+namespace pass {
+
+class TRANSFORMATIONS_API GroupQueryAttentionDecomposition;
+
+}  // namespace pass
+}  // namespace ov
+
+class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("GroupQueryAttentionDecomposition", "0");
+    GroupQueryAttentionDecomposition();
+    ov::OutputVector decompose(std::shared_ptr<ov::op::v15::GroupQueryAttention> node);
+};
diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp
index 87813bae65538d..66651b3907f344 100644
--- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp
@@ -108,6 +108,7 @@
 #include "transformations/op_conversions/eye_decomposition.hpp"
 #include "transformations/op_conversions/gelu7_downgrade.hpp"
 #include "transformations/op_conversions/group_normalization_decomposition.hpp"
+#include "transformations/op_conversions/group_query_attention_decomposition.hpp"
 #include "transformations/op_conversions/hsigmoid_decomposition.hpp"
 #include "transformations/op_conversions/hswish_decomposition.hpp"
 #include "transformations/op_conversions/log_softmax_decomposition.hpp"
@@ -156,6 +157,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model>& f) {
     auto decomp = manager.register_pass<GraphRewrite>();
+    ADD_MATCHER(decomp, GroupQueryAttentionDecomposition)
     ADD_MATCHER(decomp, ScaledDotProductAttentionDecomposition)
     ADD_MATCHER(decomp, Gelu7Downgrade)
     ADD_MATCHER(decomp, BidirectionalSequenceDecomposition)
diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp
new file mode 100644
index 00000000000000..1441e3e67a633d
--- /dev/null
+++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp
@@ -0,0 +1,322 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/op_conversions/group_query_attention_decomposition.hpp"
+
+#include <memory>
+
+#include "itt.hpp"
+#include "openvino/core/rt_info.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/greater.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/null.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/scaled_dot_product_attention.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+#include "openvino/op/split.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/op/unsqueeze.hpp"
+#include "openvino/op/variadic_split.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+
+std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
+                                          ov::Output<ov::Node> past_seqlen,
+                                          std::shared_ptr<ov::Node> seqlen_k,
+                                          std::shared_ptr<ov::Node> cos_cache,
+                                          std::shared_ptr<ov::Node> sin_cache,
+                                          std::shared_ptr<ov::Node> dim_head_size,
+                                          bool interleaved);
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
+                                         const std::vector<int>& dims);
+ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_splits, int64_t axis);
+ov::OutputVector make_split(const ov::Output<ov::Node>& value, const std::vector<int64_t>& split_lengths, int64_t axis);
+std::shared_ptr<ov::Node> create_minus_inf(const ov::element::Type& T);
+
+ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() {
+    MATCHER_SCOPE(GroupQueryAttentionDecomposition);
+    auto pattern_node = ov::pass::pattern::wrap_type<ov::op::v15::GroupQueryAttention>();
+
+    matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
+        auto& pattern_to_output = m.get_pattern_value_map();
+        auto node =
+            ov::as_type_ptr<ov::op::v15::GroupQueryAttention>(pattern_to_output.at(pattern_node).get_node_shared_ptr());
+
+        if (node == nullptr || transformation_callback(node)) {
+            return false;
+        }
+
+        auto new_output_node = decompose(node);
+        ov::replace_node(node, new_output_node);
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node, matcher_name);
+    register_matcher(m, callback);
+}
+
+ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
+    std::shared_ptr<ov::op::v15::GroupQueryAttention> node) {
+    using namespace ov::op;
+
+    const auto num_heads = node->get_num_heads();
+    const auto kv_num_heads = node->get_kv_num_heads();
+    const auto scale = node->get_scale();
+    const auto do_rotary = node->get_do_rotary();
+    const auto rotary_interleaved = node->get_rotary_interleaved();
+    // TODO: add softcap support
+
+    auto Q = node->input_value(0);
+    auto K = node->input_value(1);
+    auto V = node->input_value(2);
+    auto past_key = node->input_value(3);
+    auto past_value = node->input_value(4);
+    auto seqlens_k = node->input_value(5);
+    auto total_sequence_length = node->input_value(6);  // unused; it is not always equal to (seqlens_k + 1)
+    auto cos_cache = node->input_value(7);
+    auto sin_cache = node->input_value(8);
+
+    // The total token length (past + current) is `seqlens_k` + 1; current = Q.shape[1], past = `seqlens_k` + 1 - current
+
+    const auto T = Q.get_element_type();
+    const auto node_shape = std::make_shared<v3::ShapeOf>(Q);
+    const auto batch_size = get_dimensions(node_shape, {0});
+    const auto current_seqlen_size = get_dimensions(node_shape, {1});
+    const auto hidden_size = get_dimensions(node_shape, {2});
+    const auto total_num_heads_node =
+        v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads});
+    // should be equal to the last dimension of past_key
+    auto head_size_node = std::make_shared<v1::Divide>(hidden_size, total_num_heads_node);
+
+    // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size)
+    auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3});
+
+    if (v15::Null::is_null(K)) {
+        // Handle the packed QKV input
+        auto packed_qkv_shape = std::make_shared<v0::Concat>(
+            ov::NodeVector{batch_size, current_seqlen_size, total_num_heads_node, head_size_node},
+            0);
+        auto inputs_qkv = std::make_shared<v1::Reshape>(Q, packed_qkv_shape, false)->output(0);
+        // (batch_size, sequence_len, num_head, head_size)
+        inputs_qkv = std::make_shared<v1::Transpose>(inputs_qkv, perm);
+        // split the node into Q, K and V along the head dimension, each part with shape
+        // (batch_size, num_head, sequence_len, head_size)
+        auto split = make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1);
+        Q = split[0];
+        K = split[1];
+        V = split[2];
+    } else {
+        Q = std::make_shared<v1::Transpose>(Q, perm);
+        K = std::make_shared<v1::Transpose>(K, perm);
+        V = std::make_shared<v1::Transpose>(V, perm);
+    }
+
+    auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
+    auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
+    auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
+    auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
+    auto seqlens_elemi64 = std::make_shared<v0::Convert>(seqlens_k, ov::element::i64);
+    auto real_seqlens = std::make_shared<v1::Add>(seqlens_elemi64, one);
+
+    // Only batch size 1 is considered here
+    auto seqlens_1d = std::make_shared<v1::Reshape>(real_seqlens, one, false);
+    auto past_sequence_length = std::make_shared<v1::Subtract>(seqlens_1d, current_seqlen_size);
+
+    if (do_rotary) {
+        Q = rotaryEmbedding(Q,
+                            past_sequence_length,
+                            seqlens_1d,
+                            cos_cache.get_node_shared_ptr(),
+                            sin_cache.get_node_shared_ptr(),
+                            head_size_node,
+                            rotary_interleaved);
+        K = rotaryEmbedding(K,
+                            past_sequence_length,
+                            seqlens_1d,
+                            cos_cache.get_node_shared_ptr(),
+                            sin_cache.get_node_shared_ptr(),
+                            head_size_node,
+                            rotary_interleaved);
+    }
+
+    // present = concat(K, V) if the 'past' input is unavailable,
+    // otherwise present = concat(past, K, V)
+    auto construct_kv_cache = [&](const ov::Output<ov::Node>& past, const ov::Output<ov::Node>& current) {
+        auto past_datas = std::make_shared<v8::Slice>(past, zero, past_sequence_length, one, two);
+        auto curr_datas = std::make_shared<v8::Slice>(current, zero, current_seqlen_size, one, two);
+        return std::make_shared<v0::Concat>(ov::NodeVector{past_datas, curr_datas}, 2);
+    };
+
+    K = construct_kv_cache(past_key, K);
+    V = construct_kv_cache(past_value, V);
+    auto present_k = K;
+    auto present_v = V;
+
+    const size_t kv_num_heads_factor = num_heads / kv_num_heads;
+    if (kv_num_heads_factor > 1) {
+        const auto kv_shape = std::make_shared<v3::ShapeOf>(K);
+        // (batch_size, num_heads, sequence_len, head_size)
+        const auto kv_shape_prev_2 = get_dimensions(kv_shape, {0, 1});
+        const auto kv_shape_last_2 = get_dimensions(kv_shape, {2, 3});
+        auto new_kv_shape = std::make_shared<v0::Concat>(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0);
+        K = std::make_shared<v1::Reshape>(K, new_kv_shape, false);
+        V = std::make_shared<v1::Reshape>(V, new_kv_shape, false);
+        K = std::make_shared<v0::Concat>(ov::OutputVector(kv_num_heads_factor, K), 2);
+        V = std::make_shared<v0::Concat>(ov::OutputVector(kv_num_heads_factor, V), 2);
+        auto q_shape = std::make_shared<v3::ShapeOf>(Q);
+        // (batch_size, num_heads, sequence_len, head_size)
+        const auto q_shape_prev_2 = get_dimensions(q_shape, {0, 1});
+        auto extended_kv_shape = std::make_shared<v0::Concat>(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0);
+        K = std::make_shared<v1::Reshape>(K, extended_kv_shape, false);
+        V = std::make_shared<v1::Reshape>(V, extended_kv_shape, false);
+    }
+
+    // Apply a lower-triangular (causal) mask to the attention scores.
+    // Two steps: construct a total_sequence x total_sequence triangle, then slice out the current rows
+    auto seqlens_1d_scalar = std::make_shared<v1::Reshape>(seqlens_1d, one_without_shape, false);
+    std::shared_ptr<ov::Node> mask_per_line_node =
+        std::make_shared<v4::Range>(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}),
+                                    seqlens_1d_scalar,
+                                    one_without_shape,
+                                    ov::element::i64);  // [0, 1, 2, ..., total_sequence_length - 1]
+    auto hori_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, zero);  // (1, total_sequence_length)
+    auto vert_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, one);   // (total_sequence_length, 1)
+    auto triu = std::make_shared<v1::Greater>(hori_range, vert_range);  // true strictly above the diagonal
+    auto typed_zero = v0::Constant::create(T, ov::Shape{}, {0});
+    auto minus_inf = create_minus_inf(T);
+    auto atten_mask = std::make_shared<v1::Select>(triu, minus_inf, typed_zero);
+    auto atten_mask_sliced = std::make_shared<v8::Slice>(atten_mask,
+                                                         past_sequence_length,
+                                                         seqlens_1d,
+                                                         one,
+                                                         zero);  // keep only the rows for the current query positions
+
+    // compute softmax((Q x K') / sqrt(head_size)) x V
+    std::shared_ptr<ov::Node> qga_output;
+    if (scale != 0.0f) {
+        auto scale_node = v0::Constant::create(T, Shape{}, {scale});
+        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, scale_node, false);
+    } else {
+        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, false);
+    }
+
+    // transpose the result from (batch_size, num_heads, sequence_length, head_size)
+    // to (batch_size, sequence_length, num_heads, head_size)
+    auto qga_output_transposed = std::make_shared<v1::Transpose>(qga_output, perm);
+    auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1});
+    // reshape the result from (batch_size, sequence_length, num_heads, head_size)
+    // to (batch_size, sequence_length, num_heads * head_size)
+    auto output = std::make_shared<v1::Reshape>(qga_output_transposed, dim_merge_shape, true)->output(0);
+
+    return {output, present_k, present_v};
+}
+
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
+                                         const std::vector<int>& dims) {
+    using namespace ov::op;
+    static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
+    const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
+    return std::make_shared<v8::Gather>(shape, dims_const, zero);
+}
+
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims) {
+    return get_dimensions(std::make_shared<ov::op::v3::ShapeOf>(node), dims);
+}
+
+std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
+                                          ov::Output<ov::Node> past_seqlen,
+                                          std::shared_ptr<ov::Node> seqlen_k,
+                                          std::shared_ptr<ov::Node> cos_cache,
+                                          std::shared_ptr<ov::Node> sin_cache,
+                                          std::shared_ptr<ov::Node> dim_head_size,
+                                          bool interleaved) {
+    using namespace ov::op;
+    auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
+    auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
+
+    auto slice_cache_dim_shape = seqlen_k;
+
+    auto cos = std::make_shared<v8::Slice>(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero);
+    auto sin = std::make_shared<v8::Slice>(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero);
+
+    if (interleaved) {
+        auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
+
+        auto cache_shape = std::make_shared<v3::ShapeOf>(cos_cache);
+        auto cache_last_dim = get_dimensions(cos_cache, {-1});
+
+        auto input_shape = std::make_shared<v3::ShapeOf>(input);
+
+        auto dim_bns = get_dimensions(input_shape, {0, 1, 2});
+        std::shared_ptr<ov::Node> half_last_dim = cache_last_dim;
+
+        auto negative_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
+        auto split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, two}, 0);
+        auto reshaped_input = std::make_shared<v1::Reshape>(input, split_input_shape, false);
+
+        auto in_split = make_split(reshaped_input, 2, -1);
+        split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim}, 0);
+        auto in_split_0 = std::make_shared<v1::Reshape>(in_split[0], split_input_shape, false);
+        auto in_split_1 = std::make_shared<v1::Reshape>(in_split[1], split_input_shape, false);
+
+        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split_0, cos),
+                                                    std::make_shared<v1::Multiply>(in_split_1, sin));
+        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split_0, sin),
+                                               std::make_shared<v1::Multiply>(in_split_1, cos));
+
+        split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, one}, 0);
+        auto res_0_5d = std::make_shared<v1::Reshape>(res_0, split_input_shape, false);
+        auto res_1_5d = std::make_shared<v1::Reshape>(res_1, split_input_shape, false);
+
+        auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0_5d, res_1_5d}, -1);
+        return std::make_shared<v1::Reshape>(concat_ret, input_shape, false);
+    } else {
+        auto in_split = make_split(input, 2, -1);
+        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos),
+                                                    std::make_shared<v1::Multiply>(in_split[1], sin));
+        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin),
+                                               std::make_shared<v1::Multiply>(in_split[1], cos));
+
+        return std::make_shared<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
+    }
+}
+
+// The make_split functions are a copy-paste from the ONNX frontend. TODO: move them to a shared location
+ov::OutputVector make_split(const ov::Output<ov::Node>& value,
+                            const std::vector<int64_t>& split_lengths,
+                            int64_t axis) {
+    using namespace ov::op;
+    const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis});
+    const auto split_lengths_node =
+        v0::Constant::create(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths);
+    const auto variadic_split = std::make_shared<v1::VariadicSplit>(value, axis_node, split_lengths_node);
+
+    return variadic_split->outputs();
+}
+
+ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_splits, int64_t axis) {
+    using namespace ov::op;
+    const auto axis_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {axis});
+    const auto split = std::make_shared<v1::Split>(value, axis_node, num_splits);
+
+    return split->outputs();
+}
+
+std::shared_ptr<ov::Node> create_minus_inf(const ov::element::Type& T) {
+    // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp
+    if (T == ov::element::f32) {
+        return ov::op::v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits<float>::infinity()});
+    } else if (T == ov::element::f16) {
+        return ov::op::v0::Constant::create(T, ov::Shape{}, {std::numeric_limits<ov::float16>::lowest()});
+    } else {
+        OPENVINO_THROW("GroupQueryAttention only supports f32 and f16");
+    }
+}
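The pass is wired into CommonOptimizations above, but it can also be run standalone. A minimal sketch (assuming `model` is an `ov::Model` that contains a GroupQueryAttention node):

    #include "openvino/pass/manager.hpp"
    #include "transformations/op_conversions/group_query_attention_decomposition.hpp"

    // Rewrites every GroupQueryAttention node into opset primitives
    // (Transpose/Slice/Concat/ScaledDotProductAttention), so a plugin without
    // a dedicated GQA kernel can still execute the model.
    void decompose_gqa(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        manager.register_pass<ov::pass::GroupQueryAttentionDecomposition>();
        manager.run_passes(model);
    }
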
diff --git a/src/core/include/openvino/op/group_query_attention.hpp b/src/core/include/openvino/op/group_query_attention.hpp
new file mode 100644
index 00000000000000..25337e0faf9be5
--- /dev/null
+++ b/src/core/include/openvino/op/group_query_attention.hpp
@@ -0,0 +1,54 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include "openvino/op/op.hpp"
+
+namespace ov {
+namespace op {
+namespace v15 {
+
+// This is an experimental operation that is implemented in the plugins.
+class OPENVINO_API GroupQueryAttention : public Op {
+public:
+    OPENVINO_OP("GroupQueryAttention", "opset15", op::Op);
+
+    GroupQueryAttention() = default;
+    GroupQueryAttention(const ov::OutputVector& args,
+                        unsigned int num_heads,
+                        unsigned int kv_num_heads,
+                        float scale,
+                        bool do_rotary,
+                        bool rotary_interleaved);
+    void validate_and_infer_types() override;
+    bool visit_attributes(AttributeVisitor& visitor) override;
+    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
+
+    unsigned int get_num_heads() const {
+        return m_num_heads;
+    }
+    unsigned int get_kv_num_heads() const {
+        return m_kv_num_heads;
+    }
+    float get_scale() const {
+        return m_scale;
+    }
+    bool get_do_rotary() const {
+        return m_do_rotary;
+    }
+    bool get_rotary_interleaved() const {
+        return m_rotary_interleaved;
+    }
+
+private:
+    unsigned int m_num_heads;
+    unsigned int m_kv_num_heads;
+    float m_scale = 0;
+    bool m_do_rotary = false;
+    bool m_rotary_interleaved = false;
+};
+
+}  // namespace v15
+}  // namespace op
+}  // namespace ov
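A sketch of building the new opset15 node directly in C++; the parameter shapes below are assumptions that mirror the test models added later in this diff (hidden size 64, num_heads 2, kv_num_heads 1, so head_size 16), and the Null nodes mark the omitted key/value inputs of the packed-QKV layout:

    using namespace ov;

    auto query = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{1, 1, 64});
    auto past_key = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{1, 1, -1, 16});
    auto past_value = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{1, 1, -1, 16});
    auto seqlens_k = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
    auto total_seqlen = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{1});
    auto cos_cache = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, 8});
    auto sin_cache = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, 8});

    // Q carries the packed QKV data, so K and V are passed as Null placeholders.
    auto gqa = std::make_shared<op::v15::GroupQueryAttention>(
        OutputVector{query,
                     std::make_shared<op::v15::Null>(),
                     std::make_shared<op::v15::Null>(),
                     past_key,
                     past_value,
                     seqlens_k,
                     total_seqlen,
                     cos_cache,
                     sin_cache},
        2,       // num_heads
        1,       // kv_num_heads
        0.0f,    // scale; 0 means the SDPA default 1/sqrt(head_size)
        true,    // do_rotary
        false);  // rotary_interleaved
    // output 0: (1, 1, 32); outputs 1 and 2: present key/value (1, 1, ?, 16)
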
diff --git a/src/core/include/openvino/op/null.hpp b/src/core/include/openvino/op/null.hpp
new file mode 100644
index 00000000000000..346ff3d77c532b
--- /dev/null
+++ b/src/core/include/openvino/op/null.hpp
@@ -0,0 +1,47 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/op/op.hpp"
+
+namespace ov {
+namespace op {
+namespace v15 {
+
+/// \brief Represents a missing optional input or output of an ONNX node
+///
+/// Some ONNX operators have inputs or outputs that are marked as optional,
+/// which means that a referring node MAY forgo providing values for such inputs
+/// or computing these outputs.
+/// An empty string is used in place of a name of such input or output.
+///
+/// More:
+/// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
+class OPENVINO_API Null : public Op {
+public:
+    OPENVINO_OP("Null", "opset15", op::Op);
+    Null() {
+        set_output_size(1);
+    }
+
+    static bool is_null(const ov::Node* node) {
+        return ov::as_type<const Null>(node) != nullptr;
+    }
+
+    static bool is_null(const std::shared_ptr<ov::Node>& node) {
+        return is_null(node.get());
+    }
+
+    static bool is_null(const Output<ov::Node>& output) {
+        return is_null(output.get_node());
+    }
+
+    virtual std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
+        return std::make_shared<Null>();
+    }
+};
+}  // namespace v15
+}  // namespace op
+}  // namespace ov
diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp
index 3605f4215305ff..bb233a471e9d03 100644
--- a/src/core/include/openvino/op/ops.hpp
+++ b/src/core/include/openvino/op/ops.hpp
@@ -167,6 +167,8 @@
 #include "openvino/op/roll.hpp"
 #include "openvino/op/round.hpp"
 #include "openvino/op/scaled_dot_product_attention.hpp"
+#include "openvino/op/null.hpp"
+#include "openvino/op/group_query_attention.hpp"
 #include "openvino/op/scatter_elements_update.hpp"
 #include "openvino/op/scatter_nd_update.hpp"
 #include "openvino/op/scatter_update.hpp"
diff --git a/src/core/include/openvino/opsets/opset15_tbl.hpp b/src/core/include/openvino/opsets/opset15_tbl.hpp
index a9e8d2a8dcc840..8dd7c0d7b830cc 100644
--- a/src/core/include/openvino/opsets/opset15_tbl.hpp
+++ b/src/core/include/openvino/opsets/opset15_tbl.hpp
@@ -234,3 +234,5 @@ _OPENVINO_OP_REG(BitwiseLeftShift, ov::op::v15)
 _OPENVINO_OP_REG(BitwiseRightShift, ov::op::v15)
 _OPENVINO_OP_REG(SliceScatter, ov::op::v15)
 _OPENVINO_OP_REG(SearchSorted, ov::op::v15)
+_OPENVINO_OP_REG(GroupQueryAttention, ov::op::v15)
+_OPENVINO_OP_REG(Null, ov::op::v15)
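The Null helper is intentionally tiny; its three is_null overloads make optional-input checks uniform across handle types. A quick illustration:

    auto missing = std::make_shared<ov::op::v15::Null>();
    bool a = ov::op::v15::Null::is_null(missing.get());       // raw Node*
    bool b = ov::op::v15::Null::is_null(missing);             // shared_ptr
    bool c = ov::op::v15::Null::is_null(missing->output(0));  // Output<Node>
    // a, b and c are all true; any real op yields false
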
diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp
new file mode 100644
index 00000000000000..475110e66bf5e3
--- /dev/null
+++ b/src/core/src/op/group_query_attention.cpp
@@ -0,0 +1,92 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/group_query_attention.hpp"
+
+#include "itt.hpp"
+#include "openvino/op/null.hpp"
+
+using namespace std;
+namespace ov {
+namespace op {
+namespace v15 {
+
+GroupQueryAttention::GroupQueryAttention(const OutputVector& args,
+                                         unsigned int num_heads,
+                                         unsigned int kv_num_heads,
+                                         float scale,
+                                         bool do_rotary,
+                                         bool rotary_interleaved)
+    : Op(args),
+      m_num_heads(num_heads),
+      m_kv_num_heads(kv_num_heads),
+      m_scale(scale),
+      m_do_rotary(do_rotary),
+      m_rotary_interleaved(rotary_interleaved) {
+    constructor_validate_and_infer_types();
+}
+
+int64_t get_head_size(const PartialShape& input_shape, int num_heads, int kv_num_heads) {
+    return input_shape[2].get_length() / (num_heads + kv_num_heads * 2);
+}
+
+std::vector<int64_t> get_qkv_sizes(const PartialShape& input_shape, int num_heads, int kv_num_heads) {
+    int64_t per_head_size = get_head_size(input_shape, num_heads, kv_num_heads);
+    const std::vector<int64_t> qkv_sizes = {num_heads * per_head_size,
+                                            kv_num_heads * per_head_size,
+                                            kv_num_heads * per_head_size};
+    return qkv_sizes;
+}
+
+void GroupQueryAttention::validate_and_infer_types() {
+    OV_OP_SCOPE(v15_GroupQueryAttention_validate_and_infer_types);
+    PartialShape input_shape = get_input_partial_shape(0);
+    Dimension batch_size = input_shape[0];
+    Dimension sequence_len = input_shape[1];
+    Dimension head_size;
+    if (Null::is_null(input_value(1)) && Null::is_null(input_value(2))) {
+        head_size = get_head_size(input_shape, m_num_heads, m_kv_num_heads);
+    } else {
+        head_size = input_shape[2].get_length() / m_num_heads;
+    }
+    Dimension output_kv_len;
+    PartialShape kv_past_shape = get_input_partial_shape(3);
+    // FIXME: https://github.com/openvinotoolkit/openvino/pull/27648
+    if (kv_past_shape[2].is_static()) {
+        output_kv_len = kv_past_shape[2] + sequence_len;
+    } else {
+        output_kv_len = ov::Dimension();
+    }
+    auto element_type = get_input_element_type(0);
+    NODE_VALIDATION_CHECK(this,
+                          element_type == element::f32 || element_type == element::f16,
+                          "GroupQueryAttention only supports f32 and f16");
+    set_output_type(0, element_type, PartialShape{batch_size, sequence_len, head_size * m_num_heads});
+    set_output_type(1, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size});
+    set_output_type(2, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size});
+}
+
+bool GroupQueryAttention::visit_attributes(AttributeVisitor& visitor) {
+    OV_OP_SCOPE(v15_GroupQueryAttention_visit_attributes);
+    visitor.on_attribute("do_rotary", m_do_rotary);
+    visitor.on_attribute("kv_num_heads", m_kv_num_heads);
+    visitor.on_attribute("num_heads", m_num_heads);
+    visitor.on_attribute("rotary_interleaved", m_rotary_interleaved);
+    visitor.on_attribute("scale", m_scale);
+    return true;
+}
+
+std::shared_ptr<ov::Node> GroupQueryAttention::clone_with_new_inputs(const ov::OutputVector& new_args) const {
+    OV_OP_SCOPE(v15_GroupQueryAttention_clone_with_new_inputs);
+    return std::make_shared<GroupQueryAttention>(new_args,
+                                                 m_num_heads,
+                                                 m_kv_num_heads,
+                                                 m_scale,
+                                                 m_do_rotary,
+                                                 m_rotary_interleaved);
+}
+
+}  // namespace v15
+}  // namespace op
+}  // namespace ov
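As a worked example of this inference logic (matching the test models added below): a packed query of hidden size 64 with num_heads = 2 and kv_num_heads = 1 gives

    // head_size     = 64 / (2 + 1 * 2) = 16
    // output 0      = (batch, seq, num_heads * head_size)         = (1, 1, 32)
    // present k / v = (batch, kv_num_heads, past + seq, head_size) = (1, 1, past + 1, 16)

which is exactly what the prototxt files below declare.
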
diff --git a/src/core/tests/opset.cpp b/src/core/tests/opset.cpp
index b4908288c347cf..f201261789654e 100644
--- a/src/core/tests/opset.cpp
+++ b/src/core/tests/opset.cpp
@@ -76,7 +76,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
                          OpsetTestParams{ov::get_opset12, 178},
                          OpsetTestParams{ov::get_opset13, 186},
                          OpsetTestParams{ov::get_opset14, 188},
-                         OpsetTestParams{ov::get_opset15, 199},
+                         OpsetTestParams{ov::get_opset15, 201},
                          OpsetTestParams{ov::get_opset16, 4}),
                          OpsetTestNameGenerator{});
diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
new file mode 100644
index 00000000000000..692d2529caf7e5
--- /dev/null
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -0,0 +1,49 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/group_query_attention.hpp"
+#include "openvino/op/null.hpp"
+
+#include "core/null_node.hpp"
+#include "core/operator_set.hpp"
+#include "openvino/frontend/exception.hpp"
+
+using namespace ov::op;
+using ov::Shape;
+
+namespace ov {
+namespace frontend {
+namespace onnx {
+namespace com_microsoft {
+
+namespace opset_1 {
+ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
+    const auto onnx_op_inputs = node.get_ov_inputs();
+    const auto num_heads = node.get_attribute_value<int64_t>("num_heads");
+    const auto kv_num_heads = node.get_attribute_value<int64_t>("kv_num_heads");
+    const auto scale = node.get_attribute_value<float>("scale", 0.0f);
+    const auto do_rotary = node.get_attribute_value<int64_t>("do_rotary", 0);
+    const auto rotary_interleaved = node.get_attribute_value<float>("rotary_interleaved", 0.0f);
+
+    OutputVector ov_op_inputs;
+    ov_op_inputs.reserve(onnx_op_inputs.size());
+    for (const auto& input : onnx_op_inputs) {
+        ov_op_inputs.push_back(ov::op::util::is_null(input) ? std::make_shared<v15::Null>()->output(0) : input);
+    }
+    return std::make_shared<v15::GroupQueryAttention>(ov_op_inputs,
+                                                      num_heads,
+                                                      kv_num_heads,
+                                                      scale,
+                                                      do_rotary,
+                                                      rotary_interleaved)
+        ->outputs();
+}
+
+ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN);
+
+}  // namespace opset_1
+}  // namespace com_microsoft
+}  // namespace onnx
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt
new file mode 100644
index 00000000000000..a7dacf0dc94ebc
--- /dev/null
+++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary.prototxt
@@ -0,0 +1,247 @@
+ir_version: 10
+graph {
+  node {
+    input: "query"
+    input: ""
+    input: ""
+    input: "past_key"
+    input: "past_value"
+    input: "seqlens_k"
+    input: "total_sequence_length"
+    input: "cos_cache"
+    input: "sin_cache"
+    output: "output"
+    output: "present_key"
+    output: "present_value"
+    name: "GroupQueryAttention_0"
+    op_type: "GroupQueryAttention"
+    attribute {
+      name: "do_rotary"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "kv_num_heads"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "local_window_size"
+      i: -1
+      type: INT
+    }
+    attribute {
+      name: "num_heads"
+      i: 2
+      type: INT
+    }
+    attribute {
+      name: "rotary_interleaved"
+      i: 0
+      type: INT
+    }
+    attribute {
+      name: "smooth_softmax"
+      i: 0
+      type: INT
+    }
+    attribute {
+      name: "softcap"
+      f: 0
+      type: FLOAT
+    }
+    domain: "com.microsoft"
+  }
+  name: "GroupQueryAttention_Graph"
+  input {
+    name: "query"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 64
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "past_key"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 0
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "past_value"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 0
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "seqlens_k"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "total_sequence_length"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "cos_cache"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "sin_cache"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "output"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 32
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "present_key"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "present_value"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 11
+}
+opset_import {
+  domain: "com.microsoft"
+  version: 1
+}
\ No newline at end of file
diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt
new file mode 100644
index 00000000000000..d1400ad344e717
--- /dev/null
+++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_0_input_1_rotary_interleaved.prototxt
@@ -0,0 +1,247 @@
+ir_version: 10
+graph {
+  node {
+    input: "query"
+    input: ""
+    input: ""
+    input: "past_key"
+    input: "past_value"
+    input: "seqlens_k"
+    input: "total_sequence_length"
+    input: "cos_cache"
+    input: "sin_cache"
+    output: "output"
+    output: "present_key"
+    output: "present_value"
+    name: "GroupQueryAttention_0"
+    op_type: "GroupQueryAttention"
+    attribute {
+      name: "do_rotary"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "kv_num_heads"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "local_window_size"
+      i: -1
+      type: INT
+    }
+    attribute {
+      name: "num_heads"
+      i: 2
+      type: INT
+    }
+    attribute {
+      name: "rotary_interleaved"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "smooth_softmax"
+      i: 0
+      type: INT
+    }
+    attribute {
+      name: "softcap"
+      f: 0
+      type: FLOAT
+    }
+    domain: "com.microsoft"
+  }
+  name: "GroupQueryAttention_Graph"
+  input {
+    name: "query"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 64
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "past_key"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 0
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "past_value"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 0
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "seqlens_k"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "total_sequence_length"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "cos_cache"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "sin_cache"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "output"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 32
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "present_key"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "present_value"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 11
+}
+opset_import {
+  domain: "com.microsoft"
+  version: 1
+}
\ No newline at end of file
"" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 0 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + domain: "" + version: 21 +} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt new file mode 100644 index 00000000000000..b61cf39552efc7 --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt @@ -0,0 +1,244 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: 
"rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + domain: "" + version: 21 +} diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index 47a336f1749417..37cd5449484543 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1682,3 +1682,415 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_qlinear_mul) { test_case.add_expected_output(Shape{2, 2}, expected_output); test_case.run(); } + +OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) { + const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary.onnx"); + + std::vector query = { + -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500, + 0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551, + -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459, + -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544, + 0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215, + -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873, + }; + std::vector past_key = {}; + std::vector past_value = {}; + std::vector seqlens_k = {0}; + std::vector total_sequence_length = {1}; + std::vector cos_cache = { + 0.8437, + -0.7849, + -0.7829, + 0.4581, + -0.9870, + 0.6273, + -0.9483, + -0.9962, + }; + std::vector sin_cache = { + 0.5368, + 0.6196, + -0.6222, + 0.8889, + 
diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp
index 47a336f1749417..37cd5449484543 100644
--- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp
+++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp
@@ -1682,3 +1682,415 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_qlinear_mul) {
     test_case.add_expected_output<float>(Shape{2, 2}, expected_output);
     test_case.run();
 }
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) {
+    const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487,  0.6920,  -0.3160, -2.1152, 0.3223,  -1.2633, 0.3500,
+        0.3081,  0.1198,  1.2377,  1.1168,  -0.2473, -1.3527, -1.6959, 0.5667,  0.7935,  0.5988,  -1.5551,
+        -0.3414, 1.8530,  0.7502,  -0.5855, -0.1734, 0.1835,  1.3894,  1.5863,  0.9463,  -0.8437, 1.6459,
+        -1.3602, 0.3446,  0.5199,  -2.6133, -1.6965, -0.2282, 0.2800,  0.2469,  0.0769,  0.3380,  0.4544,
+        0.4569,  -0.8654, 0.7813,  -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625,  0.3492,  -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239,  1.1648,  0.9234,  1.3873,
+    };
+    std::vector<float> past_key = {};
+    std::vector<float> past_value = {};
+    std::vector<int32_t> seqlens_k = {0};
+    std::vector<int32_t> total_sequence_length = {1};
+    std::vector<float> cos_cache = {0.8437, -0.7849, -0.7829, 0.4581, -0.9870, 0.6273, -0.9483, -0.9962};
+    std::vector<float> sin_cache = {0.5368, 0.6196, -0.6222, 0.8889, 0.1605, -0.7788, 0.3174, -0.0872};
+
+    std::vector<float> expected_output = {-0.2188, -2.4351, -0.0729, -0.034,  0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239, 1.1648, 0.9234,  1.3873,
+                                          -0.2188, -2.4351, -0.0729, -0.034,  0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239, 1.1648, 0.9234,  1.3873};
+
+    std::vector<float> expected_present_key = {1.2561098,   1.0199738,   -0.05948371, -0.16574995,
+                                               2.5059946,   -1.738188,   -0.03158256, -0.35975295,
+                                               1.0918287,   -0.90313876, -0.4790303,  0.67029977,
+                                               -0.87039495, 0.7783688,   -0.81333745, 0.89886224};
+
+    std::vector<float> expected_present_value = {-0.2188, -2.4351, -0.0729, -0.034,  0.9625, 0.3492, -0.9215, -0.0562,
+                                                 -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239, 1.1648, 0.9234,  1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input<float>(query);
+    test_case.add_input<float>(past_key);
+    test_case.add_input<float>(past_value);
+    test_case.add_input<int32_t>(seqlens_k);
+    test_case.add_input<int32_t>(total_sequence_length);
+    test_case.add_input<float>(cos_cache);
+    test_case.add_input<float>(sin_cache);
+    test_case.add_expected_output<float>(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_key);
+    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) {
+    const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary_interleaved.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487,  0.6920,  -0.3160, -2.1152, 0.3223,  -1.2633, 0.3500,
+        0.3081,  0.1198,  1.2377,  1.1168,  -0.2473, -1.3527, -1.6959, 0.5667,  0.7935,  0.5988,  -1.5551,
+        -0.3414, 1.8530,  0.7502,  -0.5855, -0.1734, 0.1835,  1.3894,  1.5863,  0.9463,  -0.8437, 1.6459,
+        -1.3602, 0.3446,  0.5199,  -2.6133, -1.6965, -0.2282, 0.2800,  0.2469,  0.0769,  0.3380,  0.4544,
+        0.4569,  -0.8654, 0.7813,  -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625,  0.3492,  -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239,  1.1648,  0.9234,  1.3873,
+    };
+    std::vector<float> past_key = {};
+    std::vector<float> past_value = {};
+    std::vector<int32_t> seqlens_k = {0};
+    std::vector<int32_t> total_sequence_length = {1};
+    std::vector<float> cos_cache = {0.8437, -0.7849, -0.7829, 0.4581, -0.9870, 0.6273, -0.9483, -0.9962};
+    std::vector<float> sin_cache = {0.5368, 0.6196, -0.6222, 0.8889, 0.1605, -0.7788, 0.3174, -0.0872};
+
+    std::vector<float> expected_output = {-0.2188, -2.4351, -0.0729, -0.034,  0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239, 1.1648, 0.9234,  1.3873,
+                                          -0.2188, -2.4351, -0.0729, -0.034,  0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239, 1.1648, 0.9234,  1.3873};
+
+    std::vector<float> expected_present_key = {2.118801,    -0.2640816,  -0.5926066,  -0.19455537,
+                                               0.9903903,   2.954185,    -0.35343042, -0.07457897,
+                                               -0.25603274, -0.03627284, 0.56591415,  0.02181074,
+                                               -0.1586003,  0.96567893,  -0.8591481,  0.85514885};
+
+    std::vector<float> expected_present_value = {-0.2188, -2.4351, -0.0729, -0.034,  0.9625, 0.3492, -0.9215, -0.0562,
+                                                 -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239, 1.1648, 0.9234,  1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input<float>(query);
+    test_case.add_input<float>(past_key);
+    test_case.add_input<float>(past_value);
+    test_case.add_input<int32_t>(seqlens_k);
+    test_case.add_input<int32_t>(total_sequence_length);
+    test_case.add_input<float>(cos_cache);
+    test_case.add_input<float>(sin_cache);
+    test_case.add_expected_output<float>(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_key);
+    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) {
+    const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487,  0.6920,  -0.3160, -2.1152, 0.3223,  -1.2633, 0.3500,
+        0.3081,  0.1198,  1.2377,  1.1168,  -0.2473, -1.3527, -1.6959, 0.5667,  0.7935,  0.5988,  -1.5551,
+        -0.3414, 1.8530,  0.7502,  -0.5855, -0.1734, 0.1835,  1.3894,  1.5863,  0.9463,  -0.8437, 1.6459,
+        -1.3602, 0.3446,  0.5199,  -2.6133, -1.6965, -0.2282, 0.2800,  0.2469,  0.0769,  0.3380,  0.4544,
+        0.4569,  -0.8654, 0.7813,  -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625,  0.3492,  -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239,  1.1648,  0.9234,  1.3873,
+    };
+    std::vector<float> past_key = {-0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412,
+                                   -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867};
+    std::vector<float> past_value = {-0.5692, 0.9200, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356,
+                                     -1.6293, -0.5497, -0.4798, -0.4997, -1.0670, 1.1149, -0.1407, 0.8058};
+    std::vector<int32_t> seqlens_k = {1};
+    std::vector<int32_t> total_sequence_length = {2};
+    std::vector<float> cos_cache = {0.8437, -0.7849, -0.7829, 0.4581, -0.9870, 0.6273, -0.9483, -0.9962,
+                                    -0.9635, -0.8046, 0.4139, 0.9863, 0.4117, 0.9874, -0.9743, 0.9494};
+    std::vector<float> sin_cache = {0.5368, 0.6196, -0.6222, 0.8889, 0.1605, -0.7788, 0.3174, -0.0872,
+                                    0.2677, -0.5938, -0.9103, -0.1650, -0.9113, -0.1583, 0.2253, 0.3140};
+
+    std::vector<float> expected_output = {
+        -0.53934956, 0.6341806,  1.0099611,   1.1771176,   -1.270278,   2.3782496, -0.511299,   0.30222273,
+        -1.5435482,  -0.5423737, -0.27520883, -0.4914196,  -0.96554786, 1.1191509, -0.05004983, 0.85533774,
+        -0.49356747, 0.19581467, 0.8553029,   1.0041412,   -0.9513843,  2.088453,  -0.5698854,  0.25103146,
+        -1.4120293,  -0.5311372, 0.03857604,  -0.47871974, -0.8099488,  1.1256707, 0.08898184,  0.93131447};
+
+    std::vector<float> expected_present_key = {
+        -0.6136,    0.0316,    -0.4927,    0.2484,    0.4397,      0.1124,     0.6408,      0.4412,
+        -0.1023,    0.7924,    -0.2897,    0.0525,    0.5229,      2.3022,     -1.4689,     -1.5867,
+        -1.6519198, 1.1400802, 0.45031136, 0.5877534, -0.65952265, -1.8121169, 0.04630837,  0.5568472,
+        0.20271924, 0.7458131, -0.17379119, 0.3623912, 2.5696063,  -0.58594,   -0.8126341,  -0.7919839};
+
+    std::vector<float> expected_present_value = {-0.5692, 0.92,    1.1108,  1.2899,  -1.4782, 2.5672, -0.4731, 0.3356,
+                                                 -1.6293, -0.5497, -0.4798, -0.4997, -1.067,  1.1149, -0.1407, 0.8058,
+                                                 -0.2188, -2.4351, -0.0729, -0.034,  0.9625,  0.3492, -0.9215, -0.0562,
+                                                 -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239,  1.1648, 0.9234,  1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input<float>(query);
+    test_case.add_input<float>(past_key);
+    test_case.add_input<float>(past_value);
+    test_case.add_input<int32_t>(seqlens_k);
+    test_case.add_input<int32_t>(total_sequence_length);
+    test_case.add_input<float>(cos_cache);
+    test_case.add_input<float>(sin_cache);
+    test_case.add_expected_output<float>(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output<float>(Shape{1, 1, 2, 16}, expected_present_key);
+    test_case.add_expected_output<float>(Shape{1, 1, 2, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) {
+    const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary_interleaved.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487,  0.6920,  -0.3160, -2.1152, 0.3223,  -1.2633, 0.3500,
+        0.3081,  0.1198,  1.2377,  1.1168,  -0.2473, -1.3527, -1.6959, 0.5667,  0.7935,  0.5988,  -1.5551,
+        -0.3414, 1.8530,  0.7502,  -0.5855, -0.1734, 0.1835,  1.3894,  1.5863,  0.9463,  -0.8437, 1.6459,
+        -1.3602, 0.3446,  0.5199,  -2.6133, -1.6965, -0.2282, 0.2800,  0.2469,  0.0769,  0.3380,  0.4544,
+        0.4569,  -0.8654, 0.7813,  -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625,  0.3492,  -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239,  1.1648,  0.9234,  1.3873,
+    };
+    std::vector<float> past_key = {-0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412,
+                                   -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867};
+    std::vector<float> past_value = {-0.5692, 0.9200, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356,
+                                     -1.6293, -0.5497, -0.4798, -0.4997, -1.0670, 1.1149, -0.1407, 0.8058};
+    std::vector<int32_t> seqlens_k = {1};
+    std::vector<int32_t> total_sequence_length = {2};
+    std::vector<float> cos_cache = {0.8437, -0.7849, -0.7829, 0.4581, -0.9870, 0.6273, -0.9483, -0.9962,
+                                    -0.9635, -0.8046, 0.4139, 0.9863, 0.4117, 0.9874, -0.9743, 0.9494};
+    std::vector<float> sin_cache = {0.5368, 0.6196, -0.6222, 0.8889, 0.1605, -0.7788, 0.3174, -0.0872,
+                                    0.2677, -0.5938, -0.9103, -0.1650, -0.9113, -0.1583, 0.2253, 0.3140};
+
+    std::vector<float> expected_output = {
+        -0.33396345, -1.332403,   0.31613833, 0.40111685,  0.16033238, 1.0781744, -0.7741276, 0.07257013,
+        -0.9535321,  -0.491965,   1.1324831,  -0.43444604, -0.2675047, 1.1483997, 0.57366973, 1.1961825,
+        -0.24709277, -2.164195,   0.02267693, 0.07289726,  0.7654276,  0.5282906, -0.8852943, -0.02456442,
+        -0.7039771,  -0.47064403, 1.7278847,  -0.41034833, 0.02774171, 1.1607709, 0.83748007, 1.3403473};
+
+    std::vector<float> expected_present_key = {
+        -0.6136,    0.0316,      -0.4927,    0.2484,      0.4397,    0.1124,     0.6408,      0.4412,
+        -0.1023,    0.7924,      -0.2897,    0.0525,      0.5229,    2.3022,     -1.4689,     -1.5867,
+        -1.2216992, 1.7511603,   0.03145146, -0.62293506, -2.625969, 1.6767058,  -0.17887366, 0.313817,
+        0.1717277,  -0.19334024, 0.4056727,  0.39516917,  -0.25018305, 0.9460988, 1.0327814,  -0.6345757};
+
+    std::vector<float> expected_present_value = {-0.5692, 0.92,    1.1108,  1.2899,  -1.4782, 2.5672, -0.4731, 0.3356,
+                                                 -1.6293, -0.5497, -0.4798, -0.4997, -1.067,  1.1149, -0.1407, 0.8058,
+                                                 -0.2188, -2.4351, -0.0729, -0.034,  0.9625,  0.3492, -0.9215, -0.0562,
+                                                 -0.6227, -0.4637, 1.9218,  -0.4025, 0.1239,  1.1648, 0.9234,  1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input<float>(query);
+    test_case.add_input<float>(past_key);
+    test_case.add_input<float>(past_value);
+    test_case.add_input<int32_t>(seqlens_k);
+    test_case.add_input<int32_t>(total_sequence_length);
+    test_case.add_input<float>(cos_cache);
+    test_case.add_input<float>(sin_cache);
+    test_case.add_expected_output<float>(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output<float>(Shape{1, 1, 2, 16}, expected_present_key);
+    test_case.add_expected_output<float>(Shape{1, 1, 2, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}
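End to end, nothing beyond a plain model read is required; a minimal sketch (model path hypothetical):

    #include "openvino/openvino.hpp"

    int main() {
        ov::Core core;
        // An ONNX model with a com.microsoft GroupQueryAttention node is imported
        // as ov::op::v15::GroupQueryAttention and, on plugins without a dedicated
        // kernel, decomposed by CommonOptimizations during compilation.
        auto model = core.read_model("gqa_model.onnx");  // hypothetical path
        auto compiled = core.compile_model(model, "CPU");
        return 0;
    }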