Skip to content

Commit

Permalink
NPUW: Support DQ for GQ & GPTQ prefills (#26794)
Browse files Browse the repository at this point in the history
### Details:
 - Nothing special, just a couple of new patterns
 - Reduce via Add (with Multiply applied per group inside the loop) is used for
   the N-token cases, as opposed to the 1-token case

### Tickets:
 - E-138545
  • Loading branch information
dmatveev authored Sep 26, 2024
1 parent 1fd9c2b commit 856b8b0
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,8 @@ void Partitioner::optimize(const std::string& func_name) {
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>();
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQi>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQ2i>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQiP>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQ2iP>(std::ref(ctx));
rewr.run_on_model(f._model);
ov::pass::Validate().run_on_model(f._model);

Expand Down
251 changes: 241 additions & 10 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ DQMatMulCWi::DQMatMulCWi() {
register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulCWi"), std::move(callback));
}

// 1 token case (generate)
//
// FROM:
// ???(Act) -------------------------------------------->
// Param(W) -> Convert(f16|f32) -> Multiply -> Reshape -> MatMul
Expand Down Expand Up @@ -193,23 +195,21 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {
auto act_shape = matched_out_mmi.get_shape();
auto out_shape = matched_node_matmul->output(0).get_shape();

if (ov::element::i4 == matched_qweight->get_element_type() &&
if (ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
ov::element::f32 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
qweight_shape.size() == 3 && act_shape.size() == 3 && qcoeff_shape[0] == qweight_shape[0] &&
qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] && !matched_matmul->get_transpose_a() &&
!matched_matmul->get_transpose_b()) {
act_shape.size() == 3 && act_shape[1] == 1 && // single-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
!matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
// Mark W closure to transpose, and transpose the respective parameter
ctx.get().permute(matched_qweight, {0, 2, 1});

// Mark S closure to be lowered to f16
ctx.get().to_f16(matched_qcoeff);

ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

// Mark S closure to be lowered to f16
matched_qcoeff->set_element_type(ov::element::f16);
matched_qcoeff->validate_and_infer_types();
ctx.get().to_f16(matched_qcoeff);

// Reshape the Act to group format
const auto NSPLIT = qweight_shape[0];
Expand Down Expand Up @@ -314,7 +314,7 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {

if (ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
ov::element::f16 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
act_shape.size() == 3 && qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[2] == 1 &&
act_shape.size() == 3 && act_shape[1] == 1 && qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[2] == 1 &&
qcoeff_shape[1] == qweight_shape[1] && !matched_matmul->get_transpose_a() &&
matched_matmul->get_transpose_b()) {
// Mark W closure to transpose, and transpose the respective parameter
Expand Down Expand Up @@ -383,6 +383,237 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulGQ2i"), std::move(callback));
}

// N token case (prompt)
//
// FROM:
// ???(Act) -------------------------------------------->
// Param(W) -> Convert(f16|f32) -> Multiply -> Reshape -> MatMul
// Param(S) --------------------->
//
// WHERE (example):
// Act: [ 1, N, 4096]
// W: [32,128,11008]
// S: [32, 1,11008]
// [1, N ,128] x
// TO: [1,11K,128]T =
// [N,32,128] [1,N,128] [1, N ,11K] [32,N,11K]
// ???(Act) -> Reshape > Split(/32) ->[to(f16) -> Reshape -> ]}
// Param(W*) -----------> Split(/32) ->[to(f16) ------------> MatMul v ]} 32xAdd
// Param(S) -------------Split(/32) ->[--------------------> Multiply ]} v
// to(f32)
// WHERE:
// W* : [32,11008,128]
//
// Applicability (checked in the callback): W is i4 and 3D with a non-zero
// group count W[0]; S is f32 and 3D with shape [W[0], 1, W[2]]; Act is 3D with
// more than one token (Act[1] > 1); the MatMul has neither transpose_a nor
// transpose_b set. Anything else is left untouched.
//
// The per-group partial products are reduced with a left-to-right chain of
// Adds. A single group (W[0] == 1) is valid too: the chain is then empty and
// the only partial product is the result.
DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {
    auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
    auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
    auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
    auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
    auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
    auto qmmi = opp::any_input();
    auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qreshp});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();

        auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
        auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
        auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
        auto matched_out_mmi = node_to_output.at(qmmi);

        auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
        auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qcoeff);
        auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);

        // Shapes are taken by value on purpose: the W parameter is reshaped
        // below, while the original dimensions are still needed afterwards.
        const auto qweight_shape = matched_qweight->output(0).get_shape();
        const auto qcoeff_shape = matched_qcoeff->output(0).get_shape();
        const auto act_shape = matched_out_mmi.get_shape();

        const bool applicable =
            ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
            qweight_shape[0] != 0 &&  // need at least one group to split by
            ov::element::f32 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
            act_shape.size() == 3 && act_shape[1] > 1 &&  // multi-token case
            qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
            !matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b();
        if (!applicable) {
            return false;  // did nothing here
        }

        // Mark W closure to transpose, and transpose the respective parameter
        ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
        matched_qweight->set_partial_shape(tw_shape);
        matched_qweight->validate_and_infer_types();
        ctx.get().permute(matched_qweight, {0, 2, 1});

        // Mark S closure to be lowered to f16
        matched_qcoeff->set_element_type(ov::element::f16);
        matched_qcoeff->validate_and_infer_types();
        ctx.get().to_f16(matched_qcoeff);

        // Reshape the Act to group format
        const auto NSPLIT = qweight_shape[0];
        std::vector<std::size_t> rshp_act_v = {act_shape[1], NSPLIT, act_shape[2] / NSPLIT};
        auto rshp_act_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, rshp_act_v);
        auto rshp_act = std::make_shared<ov::op::v1::Reshape>(matched_out_mmi, rshp_act_c, false);

        // Split Act and W, and S tensors by NSPLIT
        auto split_axis_a = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
        auto split_a = std::make_shared<ov::op::v1::Split>(rshp_act, split_axis_a, NSPLIT);

        auto split_axis_w = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
        auto split_w = std::make_shared<ov::op::v1::Split>(matched_qweight, split_axis_w, NSPLIT);
        auto split_s = std::make_shared<ov::op::v1::Split>(matched_qcoeff, split_axis_w, NSPLIT);

        // Target shape for every Act slice: move the (now unit) group axis to the front
        std::vector<std::size_t> r_a_v = {1, act_shape[1], act_shape[2] / NSPLIT};
        auto r_a_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, r_a_v);

        // Do the CW MM for every split
        std::vector<ov::Output<ov::Node>> partials;
        partials.reserve(NSPLIT);
        for (std::size_t i = 0; i < NSPLIT; i++) {
            auto a_f16 = std::make_shared<ov::op::v0::Convert>(split_a->output(i), ov::element::f16);
            auto r_f16 = std::make_shared<ov::op::v1::Reshape>(a_f16, r_a_c, false);
            auto w_f16 = std::make_shared<ov::op::v0::Convert>(split_w->output(i), ov::element::f16);
            auto m_f16 = std::make_shared<ov::op::v0::MatMul>(r_f16, w_f16, false, true);
            auto s_f16 = std::make_shared<ov::op::v1::Multiply>(m_f16, split_s->output(i));
            partials.push_back(s_f16);
        }

        // Reduce via Add: (((p0 + p1) + p2) + ...). Folding from the first
        // partial keeps the same Add chain for NSPLIT >= 2 and never indexes
        // past the end when there is only one group.
        ov::Output<ov::Node> sum = partials.front();
        for (std::size_t i = 1; i < partials.size(); i++) {
            sum = std::make_shared<ov::op::v1::Add>(sum, partials[i]);
        }

        // Convert the result to f32 to maintain the graph contracts. FIXME should be avoided
        auto out = std::make_shared<ov::op::v0::Convert>(sum, ov::element::f32);

        // Now.. Reconnect the matmul readers to the new output (reducesum)
        for (auto&& r : matched_matmul->output(0).get_target_inputs()) {
            r.replace_source_output(out);
        }
        return true;  // root has changed
    };
    register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulGQiP"), std::move(callback));
}

// N token case (prompt)
//
// FROM:
// ???(Act) ------------------------------------------------------->
// Param(W) -> Convert(f16|f32) -> Multiply -> Reshape -> Convert -> MatMul
// Param(S) --------------------->
//
// WHERE (example):
// Act: [ 1, N,4096]
// W: [11008,32, 128]
// S: [11008,32, 1]
// [1, N ,128] x
// TO: [1,11K,128]T =
// [N,32,128] [1,N,128] [1, N ,11K] [32,N,11K]
// ???(Act) -> Reshape > Split(/32) ->[to(f16) - Reshape -> ]}
// Param(W*) -----------> Split(/32) ->[to(f16) -----------> MatMul v ]} 32xAdd
// Param(S*) -----------> Split(/32) ->[-------------------> Multiply ]} v
// to(f32)
// WHERE:
// W* : [32,11008, 128]
// S* : [32, 1,11008]
//
// Applicability (checked in the callback): W is i4 and 3D with a non-zero
// group count W[1]; S is f16 and 3D with shape [W[0], W[1], 1]; Act is 3D with
// more than one token (Act[1] > 1); the MatMul has transpose_b set and
// transpose_a not set. Anything else is left untouched.
//
// The per-group partial products are reduced with a left-to-right chain of
// Adds. A single group (W[1] == 1) is valid too: the chain is then empty and
// the only partial product is the result.
DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
    auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
    auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
    auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
    auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
    auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
    // The Convert between Reshape and MatMul may or may not be present
    auto qcvtm = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
    auto qmmi = opp::any_input();
    auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtm});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();

        auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
        auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
        auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
        auto matched_out_mmi = node_to_output.at(qmmi);

        auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
        auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qcoeff);
        auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);

        // Shapes are taken by value on purpose: both parameters are reshaped
        // below, while the original dimensions are still needed afterwards.
        const auto qweight_shape = matched_qweight->output(0).get_shape();
        const auto qcoeff_shape = matched_qcoeff->output(0).get_shape();
        const auto act_shape = matched_out_mmi.get_shape();

        const bool applicable =
            ov::element::i4 == matched_qweight->get_element_type() && qweight_shape.size() == 3 &&
            qweight_shape[1] != 0 &&  // need at least one group to split by
            ov::element::f16 == matched_qcoeff->get_element_type() && qcoeff_shape.size() == 3 &&
            act_shape.size() == 3 && act_shape[1] > 1 &&  // multi-token case
            qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == qweight_shape[1] && qcoeff_shape[2] == 1 &&
            !matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b();
        if (!applicable) {
            return false;  // did nothing here
        }

        // Mark W closure to transpose, and transpose the respective parameter
        ov::Shape tw_shape = {qweight_shape[1], qweight_shape[0], qweight_shape[2]};
        matched_qweight->set_partial_shape(tw_shape);
        matched_qweight->validate_and_infer_types();
        ctx.get().permute(matched_qweight, {1, 0, 2});

        // Also transpose S, but in a different way (see diagram above)
        ctx.get().permute(matched_qcoeff, {1, 2, 0});

        ov::Shape ts_shape = {qcoeff_shape[1], qcoeff_shape[2], qcoeff_shape[0]};
        matched_qcoeff->set_partial_shape(ts_shape);
        matched_qcoeff->validate_and_infer_types();

        // Reshape the Act to group format
        const auto NSPLIT = qweight_shape[1];
        std::vector<std::size_t> rshp_act_v = {act_shape[1], NSPLIT, act_shape[2] / NSPLIT};
        auto rshp_act_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, rshp_act_v);
        auto rshp_act = std::make_shared<ov::op::v1::Reshape>(matched_out_mmi, rshp_act_c, false);

        // Split Act and W, and S tensors by NSPLIT
        auto split_axis_a = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
        auto split_a = std::make_shared<ov::op::v1::Split>(rshp_act, split_axis_a, NSPLIT);

        auto split_axis_w = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
        auto split_w = std::make_shared<ov::op::v1::Split>(matched_qweight, split_axis_w, NSPLIT);
        auto split_s = std::make_shared<ov::op::v1::Split>(matched_qcoeff, split_axis_w, NSPLIT);

        // Target shape for every Act slice: move the (now unit) group axis to the front
        std::vector<std::size_t> r_a_v = {1, act_shape[1], act_shape[2] / NSPLIT};
        auto r_a_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, r_a_v);

        // Do the CW MM for every split
        std::vector<ov::Output<ov::Node>> partials;
        partials.reserve(NSPLIT);
        for (std::size_t i = 0; i < NSPLIT; i++) {
            auto a_f16 = std::make_shared<ov::op::v0::Convert>(split_a->output(i), ov::element::f16);
            auto r_f16 = std::make_shared<ov::op::v1::Reshape>(a_f16, r_a_c, false);
            auto w_f16 = std::make_shared<ov::op::v0::Convert>(split_w->output(i), ov::element::f16);
            auto m_f16 = std::make_shared<ov::op::v0::MatMul>(r_f16, w_f16, false, true);
            auto s_f16 = std::make_shared<ov::op::v1::Multiply>(m_f16, split_s->output(i));
            partials.push_back(s_f16);
        }

        // Reduce via Add: (((p0 + p1) + p2) + ...). Folding from the first
        // partial keeps the same Add chain for NSPLIT >= 2 and never indexes
        // past the end when there is only one group.
        ov::Output<ov::Node> out = partials.front();
        for (std::size_t i = 1; i < partials.size(); i++) {
            out = std::make_shared<ov::op::v1::Add>(out, partials[i]);
        }

        if (matched_matmul->output(0).get_element_type() == ov::element::f32) {
            // Convert the result to f32 to maintain the graph contracts, if needed
            out = std::make_shared<ov::op::v0::Convert>(out, ov::element::f32);
        }

        // Now.. Reconnect the matmul readers to the new output (reducesum)
        for (auto&& r : matched_matmul->output(0).get_target_inputs()) {
            r.replace_source_output(out);
        }
        return true;  // root has changed
    };
    register_matcher(std::make_shared<opp::Matcher>(qmm, "OptDQMatMulGQ2iP"), std::move(callback));
}

// Identifies this pattern
//
// Multiply -----------------------------------> MatMul
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ class DQMatMulGQ2i : public ov::pass::MatcherPass {
explicit DQMatMulGQ2i(Context::Ref ctx);
};

// Matcher pass for the N-token (prompt) case of the group-quantized
// Param(W) -> Convert -> Multiply(S) -> Reshape -> MatMul pattern
// (i4 weights, f32 scales, no MatMul transposes).
// The pass receives a Context reference, which its callback uses to record
// the permute / to-f16 changes it applies to the matched parameters.
class DQMatMulGQiP : public ov::pass::MatcherPass {
public:
explicit DQMatMulGQiP(Context::Ref ctx);
};

// Matcher pass for the N-token (prompt) case of the group-quantized
// Param(W) -> Convert -> Multiply(S) -> Reshape -> [Convert] -> MatMul pattern
// (i4 weights, f16 scales, MatMul with transpose_b set).
// The pass receives a Context reference, which its callback uses to record
// the permutations it applies to the matched W and S parameters.
class DQMatMulGQ2iP : public ov::pass::MatcherPass {
public:
explicit DQMatMulGQ2iP(Context::Ref ctx);
};

class DQParMMGQ : public ov::pass::MatcherPass {
public:
explicit DQParMMGQ(Context::Ref ctx);
Expand Down

0 comments on commit 856b8b0

Please sign in to comment.