Skip to content

Commit

Permalink
Added a test case with partially dynamic SDPA
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 23, 2025
1 parent b65a324 commit a193890
Showing 1 changed file with 30 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,36 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBroadca
}
}

TEST_F(TransformationTestsF, ScaledDotProductAttentionPartiallyDynamic) {
const PartialShape query_shape{-1, 24, -1, 64};
const PartialShape key_shape{-1, 24, -1, 64};
const PartialShape value_shape{-1, 24, -1, 64};
const PartialShape attention_mask_shape{-1, 24, -1, 64};
const PartialShape scale_shape{};

const auto query = std::make_shared<ov::op::v0::Parameter>(element::f32, query_shape);
const auto key = std::make_shared<ov::op::v0::Parameter>(element::f32, key_shape);
const auto value = std::make_shared<ov::op::v0::Parameter>(element::f32, value_shape);
const auto attention_mask = std::make_shared<ov::op::v0::Parameter>(element::f32, attention_mask_shape);
const auto scale = std::make_shared<ov::op::v0::Parameter>(element::f32, scale_shape);
const auto casual = false;
{
const auto scaled_dot_product_attention =
std::make_shared<ov::op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, casual);

model = std::make_shared<ov::Model>(NodeVector{scaled_dot_product_attention},
ParameterVector{query, key, value, attention_mask, scale});
manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
}

{
const auto scaled_dot_product_attention =
scaled_dot_product_attention_decomposition(query, key, value, attention_mask, scale, casual);
model_ref = std::make_shared<ov::Model>(NodeVector{scaled_dot_product_attention},
ParameterVector{query, key, value, attention_mask, scale});
}
}

TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionDynamic) {
const PartialShape query_shape{-1, -1, -1};
const PartialShape key_shape{-1, -1, -1};
Expand Down

0 comments on commit a193890

Please sign in to comment.