From 856b8b04971f8ffa834eeebec9289abee0676e36 Mon Sep 17 00:00:00 2001
From: Dmitry Matveev
Date: Thu, 26 Sep 2024 15:37:13 +0100
Subject: [PATCH] NPUW: Support DQ for GQ & GPTQ prefills (#26794)

### Details:
- Nothing special, just a couple of new patterns (`DQMatMulGQiP` and `DQMatMulGQ2iP`) extending the existing DQ MatMul optimizations to GQ and GPTQ prefill subgraphs
- For the N-token (prefill) case, the per-split partial MatMuls are scaled with a Multiply under the loop and then reduced via a chain of Adds, as opposed to the 1-token (generate) case; see the sketch below
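A minimal self-contained sketch of the numerics the new N-token patterns rely
on (plain C++, no OpenVINO types; the sizes, the names `Wq`/`S`/`A`, and the
use of `float` as a stand-in for the i4/f16 types are illustrative assumptions,
not the actual pass code, which builds Reshape/Split/MatMul/Multiply/Add graph
nodes instead):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
    const std::size_t G = 4, K = 8, N = 5, T = 3;  // groups, group size, output dim, tokens

    // "Quantized" weights Wq[G][K][N], per-group scales S[G][1][N], activations A[T][G*K]
    std::vector<float> Wq(G * K * N), S(G * N), A(T * G * K);
    for (std::size_t i = 0; i < Wq.size(); i++) Wq[i] = float(i % 7) - 3.0f;
    for (std::size_t i = 0; i < S.size(); i++) S[i] = 0.5f + 0.1f * float(i % 3);
    for (std::size_t i = 0; i < A.size(); i++) A[i] = 0.01f * float(i % 11);

    // Reference: dequantize the whole W and run a single [T,G*K] x [G*K,N] MatMul
    std::vector<float> ref(T * N, 0.0f);
    for (std::size_t t = 0; t < T; t++)
        for (std::size_t n = 0; n < N; n++)
            for (std::size_t g = 0; g < G; g++)
                for (std::size_t k = 0; k < K; k++)
                    ref[t * N + n] += A[(t * G + g) * K + k] * (S[g * N + n] * Wq[(g * K + k) * N + n]);

    // Pattern form: Split(/G), per-split MatMul on the raw quantized values,
    // Multiply by that split's scale, then reduce the G partial results via Adds
    std::vector<float> acc(T * N, 0.0f);
    for (std::size_t g = 0; g < G; g++) {
        for (std::size_t t = 0; t < T; t++)
            for (std::size_t n = 0; n < N; n++) {
                float mm = 0.0f;  // one element of this split's MatMul output
                for (std::size_t k = 0; k < K; k++)
                    mm += A[(t * G + g) * K + k] * Wq[(g * K + k) * N + n];
                acc[t * N + n] += S[g * N + n] * mm;  // Multiply, then Add into the chain
            }
    }

    // The grouped form is algebraically identical to the single big MatMul
    for (std::size_t i = 0; i < ref.size(); i++)
        assert(std::fabs(ref[i] - acc[i]) < 1e-4f);
    return 0;
}
```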
### Tickets:
- E-138545
---
 .../plugin/npuw/partitioning/partitioning.cpp |   2 +
 .../plugin/npuw/partitioning/patterns/opt.cpp | 251 +++++++++++++++++-
 .../plugin/npuw/partitioning/patterns/opt.hpp |  10 +
 3 files changed, 253 insertions(+), 10 deletions(-)

diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp
index 130da23b3c35c5..cf82694e0601b7 100644
--- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp
+++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp
@@ -1630,6 +1630,8 @@ void Partitioner::optimize(const std::string& func_name) {
         rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>();
         rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQi>(std::ref(ctx));
         rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQ2i>(std::ref(ctx));
+        rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQiP>(std::ref(ctx));
+        rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQ2iP>(std::ref(ctx));
         rewr.run_on_model(f._model);
 
         ov::pass::Validate().run_on_model(f._model);
diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
index 7d44ae04835e1d..7fab6298bc989f 100644
--- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
+++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
@@ -144,6 +146,8 @@ DQMatMulCWi::DQMatMulCWi() {
     register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulCWi"), std::move(callback));
 }
 
+// 1 token case (generate)
+//
 // FROM:
 //     ???(Act) --------------------------------------------->
 //     Param(W) -> Convert(f16|f32) -> Multiply -> Reshape -> MatMul
 //     Param(S) --------------------->
@@ -193,23 +195,21 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {
         auto act_shape = matched_out_mmi.get_shape();
         auto out_shape = matched_node_matmul->output(0).get_shape();
 
-        if (ov::element::i4 == matched_qweight->get_element_type() &&
+        if (ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
             ov::element::f32 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
-            qweight_shape.size() == 3 && act_shape.size() == 3 && qcoeff_shape[0] == qweight_shape[0] &&
-            qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] && !matched_matmul->get_transpose_a() &&
-            !matched_matmul->get_transpose_b()) {
+            act_shape.size() == 3 && act_shape[1] == 1 &&  // single-token case
+            qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
+            !matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
             // Mark W closure to transpose, and transpose the respective parameter
-            ctx.get().permute(matched_qweight, {0, 2, 1});
-
-            // Mark S closure to be lowered fo f16
-            ctx.get().to_f16(matched_qcoeff);
-
             ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
             matched_qweight->set_partial_shape(tw_shape);
             matched_qweight->validate_and_infer_types();
+            ctx.get().permute(matched_qweight, {0, 2, 1});
 
+            // Mark S closure to be lowered to f16
             matched_qcoeff->set_element_type(ov::element::f16);
             matched_qcoeff->validate_and_infer_types();
+            ctx.get().to_f16(matched_qcoeff);
 
             // Reshape the Act to group format
             const auto NSPLIT = qweight_shape[0];
@@ -314,7 +314,7 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
 
         if (ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
             ov::element::f16 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
-            act_shape.size() == 3 && qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[2] == 1 &&
+            act_shape.size() == 3 && act_shape[1] == 1 && qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[2] == 1 &&
             qcoeff_shape[1] == qweight_shape[1] && !matched_matmul->get_transpose_a() &&
             matched_matmul->get_transpose_b()) {
             // Mark W closure to transpose, and transpose the respective parameter
@@ -383,6 +383,237 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
     register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulGQ2i"), std::move(callback));
 }
 
+// N token case (prompt)
+//
+// FROM:
+//     ???(Act) --------------------------------------------->
+//     Param(W) -> Convert(f16|f32) -> Multiply -> Reshape -> MatMul
+//     Param(S) --------------------->
+//
+// WHERE (example):
+//     Act: [ 1,  N, 4096]
+//     W:   [32,128,11008]
+//     S:   [32,  1,11008]
+//
+//                                                      [1, N ,128] x
+// TO:                                                  [1,11K,128]T =
+//            [N,32,128]     [1,N,128]                  [1, N ,11K]      [32,N,11K]
+//     ???(Act) -> Reshape -> Split(/32) ->[to(f16) ->   Reshape ->             ]}
+//     Param(W*) -----------> Split(/32) ->[to(f16) ------------->  MatMul    v ]}  32xAdd
+//     Param(S) ------------> Split(/32) ->[-------------------->  Multiply    ]}     v
+//                                                                                 to(f32)
+// WHERE:
+//     W* : [32,11008,128]
+DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {
+    auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
+    auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
+    auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
+    auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
+    auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
+    auto qmmi = opp::any_input();
+    auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qreshp});
+
+    // Note: Use [=] to make sure the above objects stay alive in the callback
+    auto callback = [=](ov::pass::pattern::Matcher& m) {
+        auto& node_to_output = m.get_pattern_value_map();
+
+        auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
+        auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
+        auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
+        auto matched_out_mmi = node_to_output.at(qmmi);
+
+        auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
+        auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qcoeff);
+        auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);
+
+        auto qweight_shape = matched_qweight->output(0).get_shape();
+        auto qcoeff_shape = matched_qcoeff->output(0).get_shape();
+        auto act_shape = matched_out_mmi.get_shape();
+        auto out_shape = matched_node_matmul->output(0).get_shape();
+
+        if (ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
+            ov::element::f32 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
+            act_shape.size() == 3 && act_shape[1] > 1 &&  // multi-token case
+            qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
+            !matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
+            // Mark W closure to transpose, and transpose the respective parameter
+            ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
+            matched_qweight->set_partial_shape(tw_shape);
+            matched_qweight->validate_and_infer_types();
+            ctx.get().permute(matched_qweight, {0, 2, 1});
+
+            // Mark S closure to be lowered to f16
+            matched_qcoeff->set_element_type(ov::element::f16);
+            matched_qcoeff->validate_and_infer_types();
+            ctx.get().to_f16(matched_qcoeff);
+
+            // Reshape the Act to group format
+            const auto NSPLIT = qweight_shape[0];
+            std::vector<std::size_t> rshp_act_v = {act_shape[1], NSPLIT, act_shape[2] / NSPLIT};
+            auto rshp_act_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, rshp_act_v);
+            auto rshp_act = std::make_shared<ov::op::v1::Reshape>(matched_out_mmi, rshp_act_c, false);
+
+            // Split the Act, W, and S tensors by NSPLIT
+            auto split_axis_a = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
+            auto split_a = std::make_shared<ov::op::v1::Split>(rshp_act, split_axis_a, NSPLIT);
+
+            auto split_axis_w = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
+            auto split_w = std::make_shared<ov::op::v1::Split>(matched_qweight, split_axis_w, NSPLIT);
+            auto split_s = std::make_shared<ov::op::v1::Split>(matched_qcoeff, split_axis_w, NSPLIT);
+
+            std::vector<std::size_t> r_a_v = {1, act_shape[1], act_shape[2] / NSPLIT};
+            auto r_a_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, r_a_v);
+
+            // Do the CW MM for every split
+            std::vector<std::shared_ptr<ov::Node>> to_concat;
+            for (std::size_t i = 0; i < NSPLIT; i++) {
+                auto a_f16 = std::make_shared<ov::op::v0::Convert>(split_a->output(i), ov::element::f16);
+                auto r_f16 = std::make_shared<ov::op::v1::Reshape>(a_f16, r_a_c, false);
+                auto w_f16 = std::make_shared<ov::op::v0::Convert>(split_w->output(i), ov::element::f16);
+                auto m_f16 = std::make_shared<ov::op::v0::MatMul>(r_f16, w_f16, false, true);
+                auto s_f16 = std::make_shared<ov::op::v1::Multiply>(m_f16, split_s->output(i));
+                to_concat.push_back(s_f16);
+            }
+
+            // Reduce via Add
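+            // (The chain below is a sequential fold: reduce[k] holds the sum
+            // of to_concat[0..k+1], so reduce.back() is the sum of all NSPLIT
+            // scaled partial products)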
+            std::vector<std::shared_ptr<ov::Node>> reduce;
+            reduce.push_back(std::make_shared<ov::op::v1::Add>(to_concat[0], to_concat[1]));
+            for (std::size_t i = 1; i < NSPLIT - 1; i++) {
+                reduce.push_back(std::make_shared<ov::op::v1::Add>(reduce[i - 1], to_concat[i + 1]));
+            }
+
+            // Convert the result to f32 to maintain the graph contracts. FIXME: should be avoided
+            auto out = std::make_shared<ov::op::v0::Convert>(reduce.back(), ov::element::f32);
+
+            // Now.. Reconnect the matmul readers to the new output (the last Add in the chain)
+            for (auto&& r : matched_matmul->output(0).get_target_inputs()) {
+                r.replace_source_output(out);
+            }
+            return true;  // root has changed
+        }
+        return false;  // did nothing here
+    };
+    register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulGQiP"), std::move(callback));
+}
+
+// N token case (prompt)
+//
+// FROM:
+//     ???(Act) -------------------------------------------------------->
+//     Param(W) -> Convert(f16|f32) -> Multiply -> Reshape -> Convert -> MatMul
+//     Param(S) --------------------->
+//
+// WHERE (example):
+//     Act: [    1, N,4096]
+//     W:   [11008,32, 128]
+//     S:   [11008,32,   1]
+//
+//                                                      [1, N ,128] x
+// TO:                                                  [1,11K,128]T =
+//            [N,32,128]     [1,N,128]                  [1, N ,11K]      [32,N,11K]
+//     ???(Act) -> Reshape -> Split(/32) ->[to(f16) ->   Reshape ->             ]}
+//     Param(W*) -----------> Split(/32) ->[to(f16) ------------->  MatMul    v ]}  32xAdd
+//     Param(S*) -----------> Split(/32) ->[-------------------->  Multiply    ]}     v
+//                                                                                 to(f32)
+// WHERE:
+//     W* : [32,11008, 128]
+//     S* : [32,    1,11008]
+DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
+    auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
+    auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
+    auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
+    auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
+    auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
+    auto qcvtm = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
+    auto qmmi = opp::any_input();
+    auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtm});
+
+    // Note: Use [=] to make sure the above objects stay alive in the callback
+    auto callback = [=](ov::pass::pattern::Matcher& m) {
+        auto& node_to_output = m.get_pattern_value_map();
+
+        auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
+        auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
+        auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
+        auto matched_out_mmi = node_to_output.at(qmmi);
+
+        auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
+        auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qcoeff);
+        auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);
+
+        auto qweight_shape = matched_qweight->output(0).get_shape();
+        auto qcoeff_shape = matched_qcoeff->output(0).get_shape();
+        auto act_shape = matched_out_mmi.get_shape();
+        auto out_shape = matched_node_matmul->output(0).get_shape();
+
+        if (ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
+            ov::element::f16 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
+            act_shape.size() == 3 && act_shape[1] > 1 &&  // multi-token case
+            qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == qweight_shape[1] && qcoeff_shape[2] == 1 &&
+            !matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b()) {
+            // Mark W closure to transpose, and transpose the respective parameter
+            ov::Shape tw_shape = {qweight_shape[1], qweight_shape[0], qweight_shape[2]};
+            matched_qweight->set_partial_shape(tw_shape);
+            matched_qweight->validate_and_infer_types();
+            ctx.get().permute(matched_qweight, {1, 0, 2});
+
+            // Also transpose S, but in a different way (see the diagram above)
+            ctx.get().permute(matched_qcoeff, {1, 2, 0});
+
+            ov::Shape ts_shape = {qcoeff_shape[1], qcoeff_shape[2], qcoeff_shape[0]};
+            matched_qcoeff->set_partial_shape(ts_shape);
+            matched_qcoeff->validate_and_infer_types();
+
+            // Reshape the Act to group format
+            const auto NSPLIT = qweight_shape[1];
+            std::vector<std::size_t> rshp_act_v = {act_shape[1], NSPLIT, act_shape[2] / NSPLIT};
+            auto rshp_act_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, rshp_act_v);
+            auto rshp_act = std::make_shared<ov::op::v1::Reshape>(matched_out_mmi, rshp_act_c, false);
+
+            // Split the Act, W, and S tensors by NSPLIT
+            auto split_axis_a = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
+            auto split_a = std::make_shared<ov::op::v1::Split>(rshp_act, split_axis_a, NSPLIT);
+
+            auto split_axis_w = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
+            auto split_w = std::make_shared<ov::op::v1::Split>(matched_qweight, split_axis_w, NSPLIT);
+            auto split_s = std::make_shared<ov::op::v1::Split>(matched_qcoeff, split_axis_w, NSPLIT);
+
+            std::vector<std::size_t> r_a_v = {1, act_shape[1], act_shape[2] / NSPLIT};
+            auto r_a_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, r_a_v);
+
+            // Do the CW MM for every split
+            std::vector<std::shared_ptr<ov::Node>> to_concat;
+            for (std::size_t i = 0; i < NSPLIT; i++) {
+                auto a_f16 = std::make_shared<ov::op::v0::Convert>(split_a->output(i), ov::element::f16);
+                auto r_f16 = std::make_shared<ov::op::v1::Reshape>(a_f16, r_a_c, false);
+                auto w_f16 = std::make_shared<ov::op::v0::Convert>(split_w->output(i), ov::element::f16);
+                auto m_f16 = std::make_shared<ov::op::v0::MatMul>(r_f16, w_f16, false, true);
+                auto s_f16 = std::make_shared<ov::op::v1::Multiply>(m_f16, split_s->output(i));
+                to_concat.push_back(s_f16);
+            }
+
+            // Reduce via Add
+            std::vector<std::shared_ptr<ov::Node>> reduce;
+            reduce.push_back(std::make_shared<ov::op::v1::Add>(to_concat[0], to_concat[1]));
+            for (std::size_t i = 1; i < NSPLIT - 1; i++) {
+                reduce.push_back(std::make_shared<ov::op::v1::Add>(reduce[i - 1], to_concat[i + 1]));
+            }
+
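+            // (The matched MatMul may already produce f16 here, since the
+            // pattern admits an optional Convert on the weight branch, so the
+            // cast back to f32 below is applied only when the original MatMul
+            // output was f32)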
+            ov::Output<ov::Node> out = reduce.back();
+            if (matched_matmul->output(0).get_element_type() == ov::element::f32) {
+                // Convert the result to f32 to maintain the graph contracts, if needed
+                out = std::make_shared<ov::op::v0::Convert>(out, ov::element::f32);
+            }
+
+            // Now.. Reconnect the matmul readers to the new output (the last Add in the chain)
+            for (auto&& r : matched_matmul->output(0).get_target_inputs()) {
+                r.replace_source_output(out);
+            }
+            return true;  // root has changed
+        }
+        return false;  // did nothing here
+    };
+    register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulGQ2iP"), std::move(callback));
+}
+
 // Identifies this pattern
 //
 // Multiply -----------------------------------> MatMul
diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp
index fb867f8a344234..b51b32df23f2a2 100644
--- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp
+++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp
@@ -61,6 +61,16 @@ class DQMatMulGQ2i : public ov::pass::MatcherPass {
     explicit DQMatMulGQ2i(Context::Ref ctx);
 };
 
+class DQMatMulGQiP : public ov::pass::MatcherPass {
+public:
+    explicit DQMatMulGQiP(Context::Ref ctx);
+};
+
+class DQMatMulGQ2iP : public ov::pass::MatcherPass {
+public:
+    explicit DQMatMulGQ2iP(Context::Ref ctx);
+};
+
 class DQParMMGQ : public ov::pass::MatcherPass {
 public:
     explicit DQParMMGQ(Context::Ref ctx);