Skip to content

Commit

Permalink
Merge branch 'different_tokenizers' of github.com:iefode/openvino.genai into different_tokenizers
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Jan 23, 2025
2 parents a98bba6 + 90e2ca3 commit 5420d9d
Showing 1 changed file with 39 additions and 15 deletions.
54 changes: 39 additions & 15 deletions src/cpp/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,50 +233,74 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token
}

namespace {
std::shared_ptr<ov::Node> find_llm_matmul(const std::shared_ptr<ov::Model>& model) {

bool has_op_with_type(const std::shared_ptr<const ov::Model>& function, const std::string& type_name) {
for (const auto& op : function->get_ops()) {
if (op->get_type_name() == type_name) {
return true;
}
}
return false;
}

std::tuple<std::shared_ptr<ov::Node>, int64_t> find_llm_matmul(const std::shared_ptr<ov::Model>& model) {
auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr();
std::shared_ptr<ov::Node> matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(last_node);
std::shared_ptr<ov::Node> matmul = ov::as_type_ptr<ov::op::v0::MatMul>(last_node);

// in case of PA all tokens are moved to batch dimension and we have to slice / gather accordingly
const bool pa_based_model = has_op_with_type(model, "PagedAttentionExtension");
int64_t slice_gather_dim = pa_based_model ? 0 : 1;

// There are several patterns for matmul we are looking for:
// Matmul -> Result
// Matmul -> Add -> Result
// Matmul -> Transpose -> Result
// MatMul -> Divide -> Tanh -> Multiply -> Result
if (!matmul) {
if(auto add = std::dynamic_pointer_cast<ov::op::v1::Add>(last_node)) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->input_value(0).get_node_shared_ptr());
} else if (auto transpose = std::dynamic_pointer_cast<ov::op::v1::Transpose>(last_node)) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(transpose->input_value(0).get_node_shared_ptr());
} else if (auto multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(last_node)) {
if (auto tanh = std::dynamic_pointer_cast<ov::op::v0::Tanh>(multiply->input_value(0).get_node_shared_ptr())) {
if (auto divide = std::dynamic_pointer_cast<ov::op::v1::Divide>(tanh->input_value(0).get_node_shared_ptr())) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(divide->input_value(0).get_node_shared_ptr());
if (auto add = ov::as_type_ptr<ov::op::v1::Add>(last_node)) {
matmul = ov::as_type_ptr<ov::op::v0::MatMul>(add->input_value(0).get_node_shared_ptr());
} else if (auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(last_node)) {
matmul = ov::as_type_ptr<ov::op::v0::MatMul>(transpose->input_value(0).get_node_shared_ptr());
auto order = ov::as_type_ptr<ov::op::v0::Constant>(transpose->input_value(1).get_node_shared_ptr())->get_axis_vector_val();
slice_gather_dim = order[slice_gather_dim];
} else if (auto multiply = ov::as_type_ptr<ov::op::v1::Multiply>(last_node)) {
if (auto tanh = ov::as_type_ptr<ov::op::v0::Tanh>(multiply->input_value(0).get_node_shared_ptr())) {
if (auto divide = ov::as_type_ptr<ov::op::v1::Divide>(tanh->input_value(0).get_node_shared_ptr())) {
matmul = as_type_ptr<ov::op::v0::MatMul>(divide->input_value(0).get_node_shared_ptr());
}
}
}
}
return matmul;
return std::make_tuple(matmul, slice_gather_dim);
}

} // namespace

/// Inserts a Slice in front of the LLM's final MatMul so that only the last
/// position along the token axis reaches it.
///
/// The model is left untouched when no MatMul is found by find_llm_matmul()
/// or when the MatMul's first input does not have a static rank of 3.
///
/// @param model  Model to modify in place.
void apply_slice_before_matmul_transformation(std::shared_ptr<ov::Model> model) {
    std::shared_ptr<ov::Node> matmul = nullptr;
    int64_t slice_gather_dim = -1;
    std::tie(matmul, slice_gather_dim) = find_llm_matmul(model);

    if (!matmul) {
        return;
    }

    // get_length() is only valid for a static rank, so check that first.
    const auto input_rank = matmul->input(0).get_partial_shape().rank();
    if (input_rank.is_dynamic() || input_rank.get_length() != 3) {
        return;
    }

    // start = -1, stop = -2, step = -1 selects exactly the last element along the axis.
    auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
    auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
    auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{slice_gather_dim});
    auto slice = std::make_shared<ov::op::v8::Slice>(matmul->input_value(0), start, stop, step, axis);
    matmul->input(0).replace_source_output(slice);
}

void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model) {
auto matmul = ov::genai::utils::find_llm_matmul(model);
std::shared_ptr<ov::Node> matmul = nullptr;
int64_t slice_gather_dim = -1;
std::tie(matmul, slice_gather_dim) = find_llm_matmul(model);

if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
auto indices = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{-1});
indices->set_friendly_name("sampled_tokens_indices");
indices->output(0).get_tensor().set_names({"sampled_tokens_indices"});
auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{0});
auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{slice_gather_dim});
auto gather = std::make_shared<ov::op::v8::Gather>(matmul->input_value(0), indices, axis);
matmul->input(0).replace_source_output(gather);
model->add_parameters({indices});
Expand Down

0 comments on commit 5420d9d

Please sign in to comment.