Skip to content

Commit

Permalink
updated to use a separate flag
Browse files Browse the repository at this point in the history
  • Loading branch information
e-ddykim committed Jan 31, 2025
1 parent 22a871a commit a237417
Showing 1 changed file with 10 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -317,23 +317,14 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion(bool fuse_mlp_swi
return ov::shape_size(shape.to_shape()) == 1;
};

auto get_const_value = [](const std::shared_ptr<ov::op::v0::Constant>& const_layer) -> float {
float const_value = -1.f;
if (const_layer->get_element_type() == ov::element::f16) {
const_value = std::stof(const_layer->get_data_ptr<ov::float16>()->to_string());
} else if (const_layer->get_element_type() == ov::element::f32) {
const_value = *const_layer->get_data_ptr<float>();
}
return const_value;
};

float const_value = -1.f;
std::vector<float> const_values;
bool can_be_merged = true;
std::shared_ptr<ov::op::v0::Constant> const_node = nullptr;
for (auto& output : output_split->outputs()) {
auto target_node = output.get_target_inputs().begin()->get_node();
if (output.get_target_inputs().size() > 1 ||
!ov::is_type<ov::op::v1::Multiply>(target_node)) {
const_value = -1.f;
can_be_merged = false;
break;
}

Expand All @@ -342,25 +333,21 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion(bool fuse_mlp_swi
if (is_scalar_const(input.get_source_output())) {
const_node = std::dynamic_pointer_cast<ov::op::v0::Constant>(
input.get_source_output().get_node_shared_ptr());

if (const_value < 0.f) {
const_value = get_const_value(const_node);
} else if (const_value != get_const_value(const_node)) {
const_value = -1.f;
break;
}
const_values.emplace_back(const_node->cast_vector<float>()[0]);
} else {
const_value = -1.f;
can_be_merged = false;
break;
}
}
}
}

if (const_value < 0.f)
break;
if (const_values.size() != split_size ||
!std::equal(const_values.begin() + 1, const_values.end(), const_values.begin())) {
can_be_merged = false;
}

if (const_value > 0.f) {
if (can_be_merged) {
auto new_mul = std::make_shared<ov::op::v1::Multiply>(new_fc, const_node);
new_mul->set_friendly_name(new_fc->get_friendly_name() + "_mul");
ov::NodeVector fused_mul_nodes;
Expand Down

0 comments on commit a237417

Please sign in to comment.