[GPU] allow scalar eltwise primitive fusion with gemm #28764

Open · wants to merge 1 commit into base: master
6 changes: 4 additions & 2 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -2672,14 +2672,16 @@ bool primitive_inst::is_valid_fusion() const {
         const auto& outer_dep = _deps[outer_dep_idx];

         const auto& outer_dep_pshape = outer_dep.first->_impl_params->get_output_layout().get_partial_shape();
+        size_t outer_dep_pshape_count = outer_dep_pshape.is_static() ? ov::shape_size(outer_dep_pshape.to_shape()) : 0;
         auto merged_shape = out_pshape;
         bool can_broadcast = true;
         if (fd.is_type<eltwise>())
             can_broadcast = ov::PartialShape::broadcast_merge_into(merged_shape, outer_dep_pshape, fd.typed_desc<eltwise>()->broadcast_spec);

         // Check if broadcast happens more than single axis.
         // Current gemm_tiled_opt kernel FUSED_OP_LOAD macro cannot support broadcast on dynamic dimension.
-        if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length()) {
+        if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length() &&
+            outer_dep_pshape_count != 1) {
             uint8_t broadcast_more_than_single_axis = 0;
             auto updated_outer_dep_pshape = ov::PartialShape(outer_dep_pshape);
@@ -2715,7 +2717,7 @@ bool primitive_inst::is_valid_fusion() const {
                                                        cldnn::format::dimension(data_layout.format),
                                                        false);

-            if (gemm_dims[0] != data_dims[0])
+            if (gemm_dims[0] != data_dims[0] && outer_dep_pshape_count != 1)
                 return false;
         } else if (_node->is_type<fully_connected>() && _node->get_preferred_impl_type() == impl_types::onednn) {
             const auto& fc_layout = _impl_params->get_output_layout();
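Note (not part of the diff): the crux of this change is that a fused eltwise operand whose static shape contains exactly one element is a scalar, so it broadcasts onto every element of the GEMM output and never needs the per-axis broadcast handling that the gemm_tiled_opt kernel's FUSED_OP_LOAD macro lacks for dynamic dimensions. Below is a minimal standalone sketch of that scalar check using the OpenVINO core shape API; the example shapes are hypothetical and the include paths are assumptions that may differ by OpenVINO version.

// Standalone sketch of the scalar check added in this PR (not plugin code).
#include <iostream>

#include "openvino/core/partial_shape.hpp"   // ov::PartialShape
#include "openvino/core/shape.hpp"           // ov::shape_size (assumed header)
#include "openvino/op/util/attr_types.hpp"   // ov::op::AutoBroadcastType (assumed header)

int main() {
    ov::PartialShape out_pshape{2, 4, 16, 32};   // hypothetical GEMM output shape
    ov::PartialShape dep_pshape{1, 1, 1, 1};     // fused eltwise operand: one element -> scalar

    // Mirrors outer_dep_pshape_count in the patch: 0 if dynamic, element count otherwise.
    size_t dep_count = dep_pshape.is_static() ? ov::shape_size(dep_pshape.to_shape()) : 0;

    auto merged_shape = out_pshape;
    bool can_broadcast = ov::PartialShape::broadcast_merge_into(
        merged_shape, dep_pshape, ov::op::AutoBroadcastType::NUMPY);

    // A scalar operand (dep_count == 1) applies uniformly to every output element,
    // so the multi-axis / dynamic-dimension broadcast restriction does not apply to it.
    std::cout << std::boolalpha
              << "can_broadcast=" << can_broadcast
              << " is_scalar=" << (dep_count == 1) << "\n";
    return 0;
}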
34 changes: 34 additions & 0 deletions src/plugins/intel_gpu/tests/unit/fusions/gemm_fusion_test.cpp
@@ -432,6 +432,40 @@ TEST_P(gemm_2in_add, eltwise_postop_scalar) {
     execute(p, false, true);
 }

+TEST_P(gemm_2in_add, eltwise_postop_scalar_dynamic) {
+    auto p = GetParam();
+
+    if (engine.get_device_info().supports_immad) {
+        ov::intel_gpu::ImplementationDesc gemmv_impl = { cldnn::format::type::any, "", impl_types::onednn };
+        cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "gemm_prim", gemmv_impl } }));
+        cfg_fused.set_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape(true));
+    }
+
+    auto add_data_layout = get_output_layout(p);
+    auto add_data_size = add_data_layout.get_partial_shape();
+    for (size_t i = 0; i < add_data_size.size(); i++)
+        add_data_size[i] = 1;
+    add_data_layout.set_partial_shape(add_data_size);
+
+    auto in_layout0 = get_input_layout(p, 0);
+    auto in_layout1 = get_input_layout(p, 1);
+
+    in_layout0.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[0].size()));
+    in_layout1.set_partial_shape(ov::PartialShape::dynamic(p.in_shapes[1].size()));
+
+    create_topologies(
+        input_layout("input0", in_layout0),
+        input_layout("input1", in_layout1),
+        data("add_data", get_mem(add_data_layout, 0.5f)),
+        gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
+        eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
+        reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
+    );
+
+    tolerance = default_tolerance(p.default_type);
+    execute(p, true, true);
+}
+
 INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vector<gemm_test_params>{
     // gemm_test_params{ CASE_GEMM_2IN_FP16_3, 3, 4, "", broadcast_kinds::none, eltwise_mode::sum }, // TODO: check why failed in eltwise_postop_dynamic
     gemm_test_params{ CASE_GEMM_2IN_FP16_4, 3, 4, "", broadcast_kinds::none, eltwise_mode::sum },