
Commit

Apply comments
sbalandi committed Dec 12, 2024
1 parent 075823f commit 266c4b6
Showing 5 changed files with 71 additions and 53 deletions.
98 changes: 54 additions & 44 deletions src/cpp/src/llm_pipeline.cpp
@@ -43,6 +43,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
std::string m_templated_chat_history = {};
std::vector<int64_t> m_tokenized_chat_history;
ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
size_t m_to_remove_from_hist = 0;

StatefulLLMPipeline(
const ov::InferRequest& request,
@@ -111,6 +112,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

auto start_time = std::chrono::steady_clock::now();
GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
// If eos_token_id was not provided, take value from default m_generation_config
if (config.eos_token_id == -1)
config.set_eos_token_id(m_generation_config.eos_token_id);
config.validate();

TokenizedInputs encoded_input;

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
@@ -138,25 +144,34 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

// some symbol combinations can be encoded by the tokenizer in different ways
// if we meet a sequence with such a combination of symbols, we cannot correctly subtract the new history from the old history
// and find the difference as a prompt, so let's check it out and use the whole history in this case
// so let's check for that, find the trusted part and use it on the next step
size_t last_same_hist_token = 0;
if (!m_tokenized_chat_history.empty()) {
auto stop_tokens = config.stop_token_ids;
// the config could be reset by the user and stop_tokens could be empty,
// but the model/tokenizer will still rely on the eos token, so let's add it
stop_tokens.insert(m_tokenizer.get_eos_token_id());
size_t last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
std::set<int64_t> stop_tokens = config.stop_token_ids;
last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = last_same_hist_token == SIZE_MAX;
}

if (!m_trust_encoded_history) {
reset_kv_state();
m_selected_beam = std::nullopt;
}
if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (last_same_hist_token != SIZE_MAX) {
m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;

if (!m_tokenized_chat_history.empty() && m_trust_encoded_history) {
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token},
new_chat_tokens.input_ids.data<int64_t>() + last_same_hist_token);

ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape());
std::fill_n(new_attention_mask.data<int64_t>(), new_tensor.get_shape()[1], 1);

encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token});
new_tensor.copy_to(encoded_input.input_ids);
encoded_input.attention_mask = new_attention_mask;

m_selected_beam = std::nullopt;
} else {
encoded_input = new_chat_tokens;
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
}
m_templated_chat_history = new_templated_chat_history;
m_tokenized_chat_history.clear();
@@ -169,6 +184,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
encoded_input = m_tokenizer.encode(prompt);
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
auto encoded_results = generate(encoded_input, config, streamer);

Expand Down Expand Up @@ -270,57 +286,48 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
if (is_chat_conversation && !m_tokenized_chat_history.empty()) {
if (m_trust_encoded_history) {
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If the history is saved in the KV cache, concatenate the new attention_mask with the existing one.
// Between subsequent runs the attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1];

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
new_atten_mask.data<int64_t>() + kv_cache_len);
concatenated_attention_mask = new_atten_mask;
} else {
attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
concatenated_attention_mask = attention_mask;
}
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If the history is saved in the KV cache, concatenate the new attention_mask with the existing one.
// Between subsequent runs the attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist;

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
new_atten_mask.data<int64_t>() + kv_cache_len);
concatenated_attention_mask = new_atten_mask;
} else {
concatenated_attention_mask = attention_mask;
}

bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
if (is_chat_conversation && !m_trust_encoded_history) {
position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()};
} else {
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
}
utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len);
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len);
}

if(m_adapter_controller) {
m_adapter_controller->apply(m_model_runner, config.adapters);
}

auto input_tokens = input_ids;
if (is_chat_conversation && !m_trust_encoded_history) {
input_tokens = tokenized_chat_history;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
}

ov::genai::EncodedResults result;
if (config.is_beam_search() && is_chat_conversation) {
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_tokens, concatenated_attention_mask,
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
config, position_ids, m_selected_beam);
} else {
std::vector<SequenceGroup::Ptr> requests;
@@ -330,6 +337,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
@@ -345,8 +353,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
}

Sampler sampler = Sampler(m_tokenizer);
std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr,
sampler, requests, position_ids, std::nullopt, m_selected_beam);
std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
sampler, requests, position_ids, std::nullopt, m_selected_beam);
}

if (is_chat_conversation) {
@@ -371,6 +379,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
@@ -391,6 +400,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
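The llm_pipeline.cpp changes boil down to finding how far the stored token history can still be trusted and trimming everything after that point from the KV cache. As a reading aid, here is a minimal C++ sketch of that idea; it is not code from this repository (the real get_first_history_difference lives in utils.cpp and its exact stop-token handling may differ), and the helper name is invented for illustration.

// Illustrative sketch only -- mirrors the intent of the "trusted part" comment,
// not the library implementation.
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

size_t first_history_difference_sketch(const std::vector<int64_t>& encoded_history,
                                        const std::vector<int64_t>& tokenized_history,
                                        const std::set<int64_t>& stop_tokens) {
    size_t idx = 0;
    // Walk both sequences until they diverge.
    while (idx < encoded_history.size() && idx < tokenized_history.size() &&
           encoded_history[idx] == tokenized_history[idx]) {
        ++idx;
    }
    // Whole stored history matches, or the only leftover token is a stop/eos
    // token appended after the previous answer: the KV cache can be trusted.
    if (idx >= tokenized_history.size() ||
        (idx == encoded_history.size() && stop_tokens.count(tokenized_history[idx])))
        return SIZE_MAX;
    return idx;  // index of the first token that has to be re-generated
}

// With that index, the amount to drop from the cache is simply
//   m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;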
14 changes: 12 additions & 2 deletions src/cpp/src/utils.cpp
@@ -325,10 +325,17 @@ size_t get_first_history_difference(const ov::Tensor& encoded_history, const std
return idx;
}

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end) {
void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, std::optional<AdapterController> adapter_controller) {
// nothing to trim in this case
if (remove_from_end == 0)
return;

// TODO: add handling for case with LoRA adapters enabled
auto states = request.query_state();
for (auto& state : states) {
if(adapter_controller && adapter_controller->has_state_name(state.get_name()))
continue;

ov::Tensor old_tensor = state.get_state();
// [BATCH_SIZE, num_kv_heads, seq_len, head_size]
auto shape = old_tensor.get_shape();
@@ -337,7 +344,10 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end) {
ov::Coordinate new_shape_begin{0, 0, 0, 0};
ov::Coordinate new_shape_end{shape};

auto new_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);
auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);

ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
trimmed_tensor.copy_to(new_tensor);

state.set_state(new_tensor);
}
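The copy added in trim_kv_cache reflects that the ROI constructor only creates a view over the existing state buffer, and the commit no longer hands that view to set_state directly. Below is a standalone sketch of the slicing pattern, with invented shapes and outside any pipeline code; it is an illustration of the OpenVINO tensor APIs used above, not a drop-in replacement for the function.

// Minimal, self-contained illustration of the KV-cache slicing pattern.
#include <openvino/openvino.hpp>

int main() {
    // Stand-in for one KV state: [BATCH_SIZE, num_kv_heads, seq_len, head_size]
    ov::Tensor old_tensor(ov::element::f32, {1, 8, 16, 64});

    uint64_t remove_from_end = 4;          // cached positions to drop
    auto shape = old_tensor.get_shape();
    shape[2] -= remove_from_end;           // keep only the trusted prefix

    ov::Coordinate new_shape_begin{0, 0, 0, 0};
    ov::Coordinate new_shape_end{shape};

    // ROI tensor: a view over the kept region of old_tensor, no data copied yet.
    ov::Tensor trimmed_tensor(old_tensor, new_shape_begin, new_shape_end);

    // Materialize the view into a contiguous tensor of the reduced shape,
    // mirroring what the commit now passes to state.set_state(...).
    ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
    trimmed_tensor.copy_to(new_tensor);
    return 0;
}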
2 changes: 1 addition & 1 deletion src/cpp/src/utils.hpp
@@ -100,7 +100,7 @@ void read_rt_info(std::shared_ptr<ov::Model>& model, const char* name, T& value)

size_t get_first_history_difference(const ov::Tensor& encoded_history, const std::vector<int64_t> tokenized_history, std::set<int64_t> stop_tokens);

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end);
void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, std::optional<AdapterController> adapter_controller);

} // namespace utils
} // namespace genai
6 changes: 3 additions & 3 deletions src/cpp/src/visual_language/inputs_embedder.cpp
@@ -158,8 +158,8 @@ class InputsEmbedder::IInputsEmbedder {
new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt, chat_template_fallback);
}
auto start_tokenizer_time = std::chrono::steady_clock::now();
ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids;
TokenizedInputs prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history);
ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false)).input_ids;
TokenizedInputs prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));

// some symbol combinations can be encoded by the tokenizer in different ways
// if we meet a sequence with such a combination of symbols, we cannot correctly subtract the new history from the old history
@@ -173,7 +173,7 @@ class InputsEmbedder::IInputsEmbedder {
if (m_tokenized_chat_history.empty()) {
encoded_input_ids = new_chat_tokens;
} else if (last_same_hist_token != SIZE_MAX) {
m_to_remove_from_hist = m_tokenized_chat_history.size() - 1 - last_same_hist_token;
m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(),
{1, new_chat_tokens.get_shape().at(1) - last_same_hist_token},
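The one-line change in inputs_embedder.cpp fixes an off-by-one in the amount of history to trim. Taking last_same_hist_token to be the index of the first diverging token, as the call site in llm_pipeline.cpp uses it, a small worked example (with made-up numbers) shows the difference:

// Illustrative arithmetic only; the variable names mirror the diff.
#include <cstddef>

int main() {
    size_t history_size = 10;         // stored chat history: tokens 0..9
    size_t last_same_hist_token = 7;  // tokens 7, 8, 9 no longer match

    size_t fixed = history_size - last_same_hist_token;        // 3: drops tokens 7, 8, 9
    size_t previous = history_size - 1 - last_same_hist_token; // 2: left stale token 7 in the KV cache
    (void)fixed;
    (void)previous;
    return 0;
}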
4 changes: 1 addition & 3 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -154,9 +154,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
auto end_get_inputs_embeds = std::chrono::steady_clock::now();

auto to_remove_from_hist = m_inputs_embedder->get_amount_to_remove_from_hist();
if (to_remove_from_hist > 0) {
ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist);
}
ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, std::nullopt);

Sampler sampler = Sampler(m_tokenizer);

