Use sequence length axis in tensor trim #723

Merged
Changes from 3 commits
21 changes: 20 additions & 1 deletion samples/cpp/prompt_lookup_decoding_lm/CMakeLists.txt
ilya-lavrenov (Contributor), Aug 1, 2024:
Should we add such models with a different sequence length dimension ID to GHA CI validation?

Contributor Author:
Would we keep stateful mode after the CB merge? If yes, I think it makes sense to add tests.

Collaborator:
Yes, CB is already merged and we keep stateful mode.

Contributor:
Tests can be added separately.
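For reference, a minimal sketch of what such a validation could look like, assuming get_seq_len_axis is made visible to a standalone test binary (in the samples it sits in an anonymous namespace, so a small refactor would be needed); the model directory names below are placeholders, not paths used by the existing CI.

#include <cassert>
#include <string>

// Declared in the sample; reused here for illustration only.
size_t get_seq_len_axis(const std::string model_dir);

int main() {
    // Hypothetical model folders prepared by the CI job before the test runs.
    assert(get_seq_len_axis("TinyLlama-1.1B-Chat-v1.0") == 2);  // model_type "llama"   -> axis 2
    assert(get_seq_len_axis("chatglm3-6b") == 0);               // model_type "chatglm" -> axis 0
    assert(get_seq_len_axis("model-without-config") == 2);      // fallback to the default axis
    return 0;
}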

@@ -6,16 +6,35 @@ if(TARGET openvino_tokenizers)
else()
message(FATAL_ERROR "multinomial_causal_lm must be compiled as part of OpenVIINOGenAI to have the path to openvino_tokenizers hardcoded.")
endif()

include(FetchContent)

if(POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
endif()

if(NOT TARGET nlohmann_json)
FetchContent_Declare(nlohmann_json
URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
URL_HASH SHA256=0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406)
FetchContent_MakeAvailable(nlohmann_json)
endif()

find_package(OpenVINO REQUIRED COMPONENTS Runtime)
find_package(TBB REQUIRED COMPONENTS tbb)

add_executable(prompt_lookup_decoding_lm prompt_lookup_decoding_lm.cpp)
target_link_libraries(prompt_lookup_decoding_lm PRIVATE openvino::runtime TBB::tbb)

target_link_libraries(prompt_lookup_decoding_lm PRIVATE openvino::runtime TBB::tbb nlohmann_json::nlohmann_json)

set_target_properties(prompt_lookup_decoding_lm PROPERTIES
COMPILE_PDB_NAME prompt_lookup_decoding_lm
# Ensure out of box LC_RPATH on macOS with SIP
INSTALL_RPATH_USE_LINK_PATH ON)

target_compile_definitions(prompt_lookup_decoding_lm PRIVATE OPENVINO_TOKENIZERS_PATH="${OPENVINO_TOKENIZERS_PATH}")
target_compile_features(prompt_lookup_decoding_lm PRIVATE cxx_std_17)

install(TARGETS prompt_lookup_decoding_lm
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
58 changes: 46 additions & 12 deletions samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm.cpp
@@ -1,17 +1,50 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <string_view>
#include <fstream>
#include <nlohmann/json.hpp>
#include <openvino/core/parallel.hpp>
#include <openvino/openvino.hpp>
#include <string_view>

namespace {

// only batch_size = 1 currently supported
constexpr size_t BATCH_SIZE = 1;
// sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
// therefore usually SEQ_LEN_AXIS = 2
constexpr size_t SEQ_LEN_AXIS = 2;

size_t get_seq_len_axis(const std::string model_dir) {
// get sequence length axis based on config.json model_type
// return DEFAILT_SEQ_LEN_AXIS if no model_type found or if there is no predefined seq len axis for this model type

// sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
// therefore usually DEFAILT_SEQ_LEN_AXIS = 2
constexpr size_t DEFAILT_SEQ_LEN_AXIS = 2;

std::ifstream f(model_dir + "/config.json");

if (!f.is_open()) {
return DEFAILT_SEQ_LEN_AXIS;
}

nlohmann::json data = nlohmann::json::parse(f);

if (!data.contains("model_type")) {
return DEFAILT_SEQ_LEN_AXIS;
}

const std::string model_type = data["model_type"].get<std::string>();

const std::map<std::string, size_t> model_type_to_seq_len_axis{
{"chatglm", 0},
{"llama", 2},
};
Contributor Author (inline comment):
@ilya-lavrenov, @Wovchena, this approach is not reliable: chatglm3 and chatglm2 have seq_len_axis=0, but glm-4-9b-chat has seq_len_axis=2, and all of them have the same model_type='chatglm'.
I'll try to detect seq_len_axis by comparing the KV tensor shape dimensions with the number of generated tokens.
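A rough sketch of that shape-based detection, under the assumption that exactly one KV-cache dimension matches the number of tokens processed so far; the helper name detect_seq_len_axis is illustrative and not part of this PR.

#include <optional>
#include <openvino/openvino.hpp>

// Hypothetical helper: pick the axis of a key/value state tensor whose extent
// equals the number of tokens processed so far. Returns empty if no axis or
// more than one axis matches, so the caller can fall back to the default axis.
std::optional<size_t> detect_seq_len_axis(const ov::Tensor& kv_tensor, size_t processed_tokens) {
    const ov::Shape shape = kv_tensor.get_shape();
    std::optional<size_t> axis;
    for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] == processed_tokens) {
            if (axis.has_value())
                return std::nullopt;  // ambiguous: several dimensions have this extent
            axis = i;
        }
    }
    return axis;
}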


if (!model_type_to_seq_len_axis.count(model_type)) {
return DEFAILT_SEQ_LEN_AXIS;
}

return model_type_to_seq_len_axis.at(model_type);
}

std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {BATCH_SIZE}, &prompt});
@@ -58,7 +91,7 @@ struct TextStreamer {
void end() {
std::string text = detokenize(detokenizer, token_cache);
if (text.size() <= print_len)
return ;
return;
std::cout << std::string_view{text.data() + print_len, text.size() - print_len} << '\n';
token_cache.clear();
print_len = 0;
@@ -75,25 +108,24 @@ ov::Tensor trimm_tensor(ov::Tensor& tensor, uint64_t seq_len_axis, uint64_t new_

auto old_tensor_data = tensor.data<float>();
auto shape = tensor.get_shape();
size_t batch_size = shape[0];
size_t num_kv_heads = shape[1];
size_t old_seq_len = shape[2];
size_t head_size = shape[3];
size_t old_seq_len = shape[seq_len_axis];

OPENVINO_ASSERT(new_seq_len <= old_seq_len);

// if new_seq_len equals the old one, no need to copy the tensor; return as is
if (old_seq_len == new_seq_len)
return tensor;

shape[seq_len_axis] = new_seq_len;

if (seq_len_axis == 0) {
shape[0] = new_seq_len;
tensor.set_shape(shape);
return tensor;
}

ov::Coordinate new_shape_begin{0, 0, 0, 0};
ov::Coordinate new_shape_end{batch_size, num_kv_heads, new_seq_len, head_size};
ov::Coordinate new_shape_end{shape};

auto new_tensor = ov::Tensor(tensor, new_shape_begin, new_shape_end);

return new_tensor;
@@ -228,6 +260,8 @@ int main(int argc, char* argv[]) try {

const int64_t EOS_TOKEN = get_eos_token(tokenizer_model);

const size_t seq_len_axis = get_seq_len_axis(model_dir);

// Prompt lookup decoding is a speculative decoding technique where the draft model is replaced
// with string matching in the prompt to generate candidate token sequences.
int max_sequence_length = 100;
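For context, a simplified sketch of the prompt-lookup idea described in the comment above: match the tail of the already generated tokens against the prompt and propose the tokens that follow the match as candidates. The function prompt_lookup_candidates is illustrative only and does not mirror the exact implementation in this sample.

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative only: propose up to max_candidates tokens by finding the last
// ngram_size generated tokens inside the prompt and returning what follows them.
std::vector<int64_t> prompt_lookup_candidates(const std::vector<int64_t>& prompt,
                                              const std::vector<int64_t>& generated,
                                              size_t ngram_size,
                                              size_t max_candidates) {
    if (generated.size() < ngram_size)
        return {};
    const auto ngram_begin = generated.end() - ngram_size;
    for (size_t pos = 0; pos + ngram_size <= prompt.size(); ++pos) {
        if (std::equal(ngram_begin, generated.end(), prompt.begin() + pos)) {
            const size_t start = pos + ngram_size;
            const size_t count = std::min(max_candidates, prompt.size() - start);
            return {prompt.begin() + start, prompt.begin() + start + count};
        }
    }
    return {};
}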
@@ -288,7 +322,7 @@ int main(int argc, char* argv[]) try {
// Increment the sequence length by the number of matched tokens, and
// trim the KV cache to match the new sequence length.
seq_len += accepted_tokens_number;
update_kv_cache(model, SEQ_LEN_AXIS, seq_len);
update_kv_cache(model, seq_len_axis, seq_len);

first_token = out_token;
}
23 changes: 21 additions & 2 deletions samples/cpp/speculative_decoding_lm/CMakeLists.txt
@@ -6,16 +6,35 @@ if(TARGET openvino_tokenizers)
else()
message(FATAL_ERROR "multinomial_causal_lm must be compiled as part of OpenVIINOGenAI to have the path to openvino_tokenizers hardcoded.")
endif()

include(FetchContent)

if(POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
endif()

if(NOT TARGET nlohmann_json)
FetchContent_Declare(nlohmann_json
URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
URL_HASH SHA256=0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406)
FetchContent_MakeAvailable(nlohmann_json)
endif()

find_package(OpenVINO REQUIRED COMPONENTS Runtime)
find_package(TBB REQUIRED COMPONENTS tbb)

add_executable(speculative_decoding_lm speculative_decoding_lm.cpp)
target_link_libraries(speculative_decoding_lm PRIVATE openvino::runtime TBB::tbb)

target_link_libraries(speculative_decoding_lm PRIVATE openvino::runtime TBB::tbb nlohmann_json::nlohmann_json)

set_target_properties(speculative_decoding_lm PROPERTIES
COMPILE_PDB_NAME speculative_decoding_lm
# Ensure out of box LC_RPATH on macOS with SIP
INSTALL_RPATH_USE_LINK_PATH ON)
target_compile_definitions(speculative_decoding_lm PRIVATE OPENVINO_TOKENIZERS_PATH="${OPENVINO_TOKENIZERS_PATH}")

target_compile_definitions(speculative_decoding_lm PRIVATE OPENVINO_TOKENIZERS_PATH="${OPENVINO_TOKENIZERS_PATH}")
target_compile_features(speculative_decoding_lm PRIVATE cxx_std_17)

install(TARGETS speculative_decoding_lm
RUNTIME DESTINATION samples_bin/
COMPONENT samples_bin
83 changes: 63 additions & 20 deletions samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
@@ -2,17 +2,50 @@
// SPDX-License-Identifier: Apache-2.0

#include <cmath>
#include <fstream>
#include <nlohmann/json.hpp>
#include <openvino/core/parallel.hpp>
#include <openvino/openvino.hpp>
#include <random>

namespace {

constexpr size_t BATCH_SIZE = 1;

// sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
// therefore usually SEQ_LEN_AXIS = 2
constexpr size_t SEQ_LEN_AXIS = 2;
size_t get_seq_len_axis(const std::string model_dir) {
// get sequence length axis based on config.json model_type
// return DEFAILT_SEQ_LEN_AXIS if no model_type found or if there is no predefined seq len axis for this model type

// sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
// therefore usually DEFAILT_SEQ_LEN_AXIS = 2
constexpr size_t DEFAILT_SEQ_LEN_AXIS = 2;

std::ifstream f(model_dir + "/config.json");

if (!f.is_open()) {
return DEFAILT_SEQ_LEN_AXIS;
}

nlohmann::json data = nlohmann::json::parse(f);

if (!data.contains("model_type")) {
return DEFAILT_SEQ_LEN_AXIS;
}

const std::string model_type = data["model_type"].get<std::string>();

const std::map<std::string, size_t> model_type_to_seq_len_axis{
{"chatglm", 0},
{"llama", 2},
};

if (!model_type_to_seq_len_axis.count(model_type)) {
return DEFAILT_SEQ_LEN_AXIS;
}

return model_type_to_seq_len_axis.at(model_type);
}

namespace {
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {BATCH_SIZE}, &prompt});
tokenizer.infer();
@@ -58,7 +91,7 @@ struct TextStreamer {
void end() {
std::string text = detokenize(detokenizer, token_cache);
if (text.size() <= print_len)
return ;
return;
std::cout << std::string_view{text.data() + print_len, text.size() - print_len} << '\n';
token_cache.clear();
print_len = 0;
@@ -75,25 +108,24 @@ ov::Tensor trimm_tensor(ov::Tensor& tensor, uint64_t seq_len_axis, uint64_t new_

auto old_tensor_data = tensor.data<float>();
auto shape = tensor.get_shape();
size_t batch_size = shape[0];
size_t num_kv_heads = shape[1];
size_t old_seq_len = shape[2];
size_t head_size = shape[3];
size_t old_seq_len = shape[seq_len_axis];

OPENVINO_ASSERT(new_seq_len <= old_seq_len);

// if new_seq_len equals the old one, no need to copy the tensor; return as is
if (old_seq_len == new_seq_len)
return tensor;

shape[seq_len_axis] = new_seq_len;

if (seq_len_axis == 0) {
shape[0] = new_seq_len;
tensor.set_shape(shape);
return tensor;
}

ov::Coordinate new_shape_begin{0, 0, 0, 0};
ov::Coordinate new_shape_end{batch_size, num_kv_heads, new_seq_len, head_size};
ov::Coordinate new_shape_end{shape};

auto new_tensor = ov::Tensor(tensor, new_shape_begin, new_shape_end);

return new_tensor;
@@ -114,14 +146,19 @@ class AssistedCandidateGenerator {
size_t max_seq_length;
size_t num_pred_tokens = 5;
const size_t max_pred_tokens = 10;
const size_t seq_len_axis;
int64_t out_of_kv_cache_token = -1;
size_t draft_model_seq_length = 0;

public:
AssistedCandidateGenerator(ov::InferRequest draft_model, const size_t max_seq_length, const size_t num_pred_tokens)
AssistedCandidateGenerator(ov::InferRequest draft_model,
const size_t max_seq_length,
const size_t num_pred_tokens,
const size_t seq_len_axis)
: draft_model{draft_model},
max_seq_length{max_seq_length},
num_pred_tokens{num_pred_tokens} {};
num_pred_tokens{num_pred_tokens},
seq_len_axis{seq_len_axis} {};

int64_t generate_next_token(const std::vector<int64_t> tokens) {
size_t tokens_size = tokens.size();
Expand Down Expand Up @@ -199,7 +236,7 @@ class AssistedCandidateGenerator {
}

out_of_kv_cache_token = -1;
::update_kv_cache(draft_model, SEQ_LEN_AXIS, seq_length);
::update_kv_cache(draft_model, seq_len_axis, seq_length);
draft_model_seq_length = seq_length;
}
};
@@ -224,27 +261,32 @@ int main(int argc, char* argv[]) try {
// tokenizer model
ov::Core core;
core.add_extension(OPENVINO_TOKENIZERS_PATH); // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt
auto tokenizer_model = core.read_model(std::string{argv[1]} + "/openvino_tokenizer.xml");

const std::string draft_model_dir = std::string{argv[1]};

auto tokenizer_model = core.read_model(draft_model_dir + "/openvino_tokenizer.xml");
// tokenizer and detokenizer work on CPU only
ov::InferRequest tokenizer = core.compile_model(tokenizer_model, "CPU").create_infer_request();
auto [input_ids, attention_mask] = tokenize(tokenizer, argv[3]);
ov::InferRequest detokenizer =
core.compile_model(std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
core.compile_model(draft_model_dir + "/openvino_detokenizer.xml", "CPU").create_infer_request();
TextStreamer text_streamer{std::move(detokenizer)};

// draft model (which is smaller, less accurate but faster)
ov::InferRequest draft_model =
core.compile_model(std::string{argv[1]} + "/openvino_model.xml", "CPU").create_infer_request();
core.compile_model(draft_model_dir + "/openvino_model.xml", "CPU").create_infer_request();

uint64_t seq_len = input_ids.get_shape()[1];

const std::string main_model_dir = std::string{argv[2]};
// main model (which is bigger, more accurate but slower)
ov::InferRequest main_model =
core.compile_model(std::string{argv[2]} + "/openvino_model.xml", "CPU").create_infer_request();
core.compile_model(main_model_dir + "/openvino_model.xml", "CPU").create_infer_request();

size_t max_sequence_length = 100;

AssistedCandidateGenerator candidateGenerator{draft_model, max_sequence_length, 5};
const size_t draft_model_seq_len_axis = get_seq_len_axis(draft_model_dir);
AssistedCandidateGenerator candidateGenerator{draft_model, max_sequence_length, 5, draft_model_seq_len_axis};

main_model.set_tensor("input_ids", input_ids);
main_model.set_tensor("attention_mask", attention_mask);
@@ -275,6 +317,7 @@ int main(int argc, char* argv[]) try {
text_streamer.put(out_token);

const int64_t EOS_TOKEN = get_eos_token(tokenizer_model);
const size_t main_model_seq_len_axis = get_seq_len_axis(main_model_dir);

/* Speculative decoding works the following way. The draft model predicts the next K
tokens one by one in an autoregressive manner, while the main model validates these
@@ -347,7 +390,7 @@ int main(int argc, char* argv[]) try {
}

candidateGenerator.update_kv_cache(seq_len);
update_kv_cache(main_model, SEQ_LEN_AXIS, seq_len);
update_kv_cache(main_model, main_model_seq_len_axis, seq_len);

candidates.clear();
}