Skip to content

Commit

Permalink
[CPU] OV CPU plugin fails to infer SegNext model
Browse files Browse the repository at this point in the history
  • Loading branch information
nshchego committed Jan 23, 2025
1 parent 8dfb9cc commit aeb4a2f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
*/

auto get_aligned_shapes =
[shape_a, shape_b, rank_a, rank_b, &matmul]() -> std::tuple<bool, ov::PartialShape, ov::PartialShape> {
[shape_a, shape_b, rank_a, rank_b, &matmul, fc_input_a, fc_input_b]() -> std::tuple<bool, ov::PartialShape, ov::PartialShape> {
ov::PartialShape shape_a_aligned(shape_a), shape_b_aligned(shape_b);
size_t max_size = std::max(rank_a, rank_b);
for (size_t i = 0, cnt = max_size - rank_a; i < cnt; ++i) {
Expand All @@ -79,10 +79,10 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
shape_b_aligned.insert(shape_b_aligned.begin(), 1);
}

if (matmul->get_transpose_a()) {
if (matmul->get_transpose_a() && !is_type<op::v1::Transpose>(fc_input_a.get_node())) {
std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
}
if (matmul->get_transpose_b()) {
if (matmul->get_transpose_b() && !is_type<op::v1::Transpose>(fc_input_b.get_node())) {
std::swap(*(shape_b_aligned.end() - 1), *(shape_b_aligned.end() - 2));
}

Expand Down Expand Up @@ -138,12 +138,12 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
}

// Weights normalization
if (!matmul->get_transpose_b()) {
if (!matmul->get_transpose_b() && !is_type<op::v1::Transpose>(fc_input_b.get_node())) {
fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
}

// Input normalization
if (matmul->get_transpose_a()) {
if (matmul->get_transpose_a() && !is_type<op::v1::Transpose>(fc_input_a.get_node())) {
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,8 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
manager,
[](const_node_ptr& node) -> bool {
const auto consumers = node->get_output_target_inputs(0);
return std::all_of(consumers.begin(), consumers.end(), [](const ov::Input<ov::Node>& consumer) {
return !ov::is_type<ov::op::v0::MatMul>(consumer.get_node());
return !std::all_of(consumers.begin(), consumers.end(), [](const ov::Input<ov::Node>& consumer) {
return ov::is_type<ov::op::v0::MatMul>(consumer.get_node());
});
},
ov::pass::KeepConstAndDecompression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ namespace {

const auto testParams2D_FP16_2_smoke =
::testing::Combine(::testing::Values(static_shapes_to_test_representation({{2, 3}, {2, 3}, {3, 4}})),
::testing::Values(std::pair<bool, bool>{false, true}),
::testing::Values(std::pair<bool, bool>{false, true}, std::pair<bool, bool>{true, false}),
::testing::Values(ElementType::f16),
::testing::Values(emptyConfig),
::testing::ValuesIn(filter_specific_params(false)));
Expand Down

0 comments on commit aeb4a2f

Please sign in to comment.