diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index c9daa89803572f..b61be61922e2a1 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -49,9 +49,7 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" -std::shared_ptr get_present_state(const std::shared_ptr& K, - const std::shared_ptr& V, - const ov::OutputVector& op_inputs); + std::shared_ptr rotaryEmbedding(ov::Output input, ov::Output past_seqlen, std::shared_ptr seqlen_k, @@ -113,7 +111,7 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( const auto hidden_size = get_dimensions(node_shape, {2}); const auto total_num_heads_node = v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); - auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); + auto head_size_node = std::make_shared(hidden_size, total_num_heads_node); // should be equal to the last dim of past_key // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size) auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); @@ -175,8 +173,8 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( K = construct_kv_cache(past_key, K); V = construct_kv_cache(past_value, V); - auto present_k = K.get_node_shared_ptr(); - auto present_v = V.get_node_shared_ptr(); + auto present_k = K; + auto present_v = V; const size_t kv_num_heads_factor = num_heads / kv_num_heads; if (kv_num_heads_factor > 1) { @@ -232,7 +230,7 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}); // reshape the result from (batch_size, sequence_length, num_heads, head_size) // to (batch_size, sequence_length, num_heads * head_size) - auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true); + auto output = std::make_shared(qga_output_transposed, dim_merge_shape, true)->output(0); return {output, present_k, present_v}; } diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp index 965a46954f6b44..365dac5119dd84 100644 --- a/src/core/src/op/group_query_attention.cpp +++ b/src/core/src/op/group_query_attention.cpp @@ -51,7 +51,6 @@ std::vector get_qkv_sizes(const PartialShape& input_shape, int num_head return qkv_sizes; } -// TODO void GroupQueryAttention::validate_and_infer_types() { OV_OP_SCOPE(GroupQueryAttention_validate_and_infer_types); PartialShape input_shape = get_input_partial_shape(0); @@ -67,6 +66,7 @@ void GroupQueryAttention::validate_and_infer_types() { PartialShape kv_past_shape = get_input_partial_shape(3); // FIXME: Original GQA spec depends on the identical tensor set for input/output, but we cannot know it in advance, // hence we base on sequence dimension static/dynamic + // https://github.com/openvinotoolkit/openvino/pull/27648 if (kv_past_shape[2].is_dynamic()) { output_kv_len = kv_past_shape[2] + sequence_len; } else {