Move beam search in case of chat scenario to sampler.cpp #1215

Merged
merged 2 commits on Dec 21, 2024
455 changes: 0 additions & 455 deletions src/cpp/src/group_beam_searcher.cpp

This file was deleted.

134 changes: 76 additions & 58 deletions src/cpp/src/llm_pipeline.cpp
@@ -24,28 +24,23 @@
namespace ov {
namespace genai {

std::pair<EncodedResults, int32_t> beam_search(
ov::InferRequest& lm,
ov::Tensor prompts,
ov::Tensor attention_mask,
GenerationConfig config,
std::optional<ov::Tensor> position_ids,
std::optional<int32_t> selected_beam_idx
);

class StatefulLLMPipeline final : public LLMPipelineImplBase {
public:
ov::InferRequest m_model_runner;
bool is_chat_conversation = false;
bool m_trust_encoded_history = true;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
std::vector<int64_t> m_tokenized_chat_history;
ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
size_t m_to_remove_from_hist = 0;
size_t m_kv_cache_seq_length_axis = 2;
Sampler m_sampler;
// The tail of the previous output in chat mode is missing from the KV cache, so let's keep it
std::optional<int64_t> m_last_disappeared_token = std::nullopt;
// If the sequence contains symbols that could be ambiguously encoded by the tokenizer, we need to trim the kv cache.
// If we use beam search sampling in chat mode, we need to remove the model's last answer from the kv cache and add the best answer to the history,
// so let's keep info about the number of tokens to trim from the kv cache and the number of tokens to keep in the history.
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0};

StatefulLLMPipeline(
const ov::InferRequest& request,
@@ -154,35 +149,44 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// some symbol combinations can be encoded by the tokenizer in different ways
// if we meet a sequence with such a combination of symbols, we cannot correctly subtract the new history from the old history
// so let's check for it, find the trusted part and use it on the next step
size_t last_same_hist_token = 0;
size_t trusted_history_length = 0;
if (!m_tokenized_chat_history.empty()) {
std::set<int64_t> stop_tokens = config.stop_token_ids;
last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = last_same_hist_token == SIZE_MAX;
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = trusted_history_length == SIZE_MAX;
}

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (last_same_hist_token != SIZE_MAX) {
m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) {
// does_kv_cache_need_to_update will be true here if beam search is activated
// in beam search mode we want to remove all history about the last model answer from the kv cache and add the best answer directly
// any difference between the model answer and the decoded answer will still be smaller than the entire history, so let's use the data from m_kv_history_manager
if (m_kv_history_manager.does_kv_cache_need_to_update()) {
trusted_history_length = m_kv_history_manager.trusted_history_length;
} else {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
// if the previous generation finished because the max length was reached, the kv cache is missing the last token, so let's keep it
m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;
}

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token},
new_chat_tokens.input_ids.data<int64_t>() + last_same_hist_token);
{1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length},
new_chat_tokens.input_ids.data<int64_t>() + trusted_history_length);

ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape());
std::fill_n(new_attention_mask.data<int64_t>(), new_tensor.get_shape()[1], 1);

encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token});
{1, new_chat_tokens.input_ids.get_shape().at(1) - trusted_history_length});
new_tensor.copy_to(encoded_input.input_ids);
encoded_input.attention_mask = new_attention_mask;

m_selected_beam = std::nullopt;
m_last_disappeared_token = std::nullopt;
} else {
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
}
m_templated_chat_history = new_templated_chat_history;

m_tokenized_chat_history.clear();
m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size());
std::copy_n(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.get_size(),
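To make the trusted-history logic above concrete, the following is a simplified, self-contained stand-in for ov::genai::utils::get_first_history_difference, which this diff uses but does not define. It operates on plain std::vector<int64_t> token sequences rather than ov::Tensor, and its stop-token handling is an assumption, so treat it purely as an illustration: SIZE_MAX means the previously encoded history is still a trusted prefix, otherwise the returned index is the length of the trusted part.

#include <cstdint>
#include <set>
#include <vector>

// Illustrative sketch only; the real helper lives in utils and may differ.
size_t first_history_difference(const std::vector<int64_t>& prev_encoded_history,
                                const std::vector<int64_t>& tokenized_chat_history,
                                const std::set<int64_t>& stop_tokens) {
    size_t idx = 0;
    while (idx < prev_encoded_history.size() && idx < tokenized_chat_history.size() &&
           prev_encoded_history[idx] == tokenized_chat_history[idx]) {
        ++idx;
    }
    if (idx == prev_encoded_history.size())
        return SIZE_MAX;  // the whole previous encoding matches, so it is trusted
    // a single trailing stop token is not treated as a real divergence in this sketch
    if (idx + 1 == prev_encoded_history.size() && stop_tokens.count(prev_encoded_history[idx]))
        return SIZE_MAX;
    return idx;  // first index where the re-encoded prompt and the stored history disagree
}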
Expand Down Expand Up @@ -264,6 +268,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

// The tail of the previous output in chat mode is missing from the KV cache.
if (m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1);
input_ids = ov::genai::utils::push_front_inputs(input_ids, *m_last_disappeared_token);
}

GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;

// If eos_token_id was not provided, take value from default m_generation_config
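The push_front_inputs calls above restore the token that was sampled at the end of the previous turn but never fed back through the model, so it has no KV-cache entry yet. The helper itself is defined in utils and is not part of this diff; the sketch below only shows what such a helper could look like, assuming int64 tensors of shape {1, seq_len}.

#include <algorithm>
#include <openvino/openvino.hpp>

// Hypothetical sketch, not the actual ov::genai::utils::push_front_inputs implementation.
ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front) {
    ov::Tensor new_tensor{ov::element::i64, {base_tensor.get_shape().at(0), base_tensor.get_shape().at(1) + 1}};
    int64_t* dst = new_tensor.data<int64_t>();
    dst[0] = add_to_front;  // the restored token id, or 1 for the attention mask
    std::copy_n(base_tensor.data<int64_t>(), base_tensor.get_size(), dst + 1);
    return new_tensor;
}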
Expand Down Expand Up @@ -294,7 +304,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
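m_kv_history_manager replaces the previous m_to_remove_from_hist counter, so trim_kv_cache above now takes the number of tokens to drop from it. Its definition lives in the utils header and is not shown in this diff; a minimal sketch consistent with how it is used here (aggregate-initializable with {0, 0}, plus does_kv_cache_need_to_update() and reset()) might look like this — the field names follow the usage above, the rest is assumed.

#include <cstddef>

// Assumed sketch of a HistoryRemoveManager-like helper; the real definition may differ.
struct HistoryRemoveManager {
    size_t num_tokens_to_remove_from_kv_cache = 0;  // how many trailing tokens to trim from the KV cache
    size_t trusted_history_length = 0;              // how much of the stored chat history to keep

    bool does_kv_cache_need_to_update() const {
        // true when beam search (or an earlier history mismatch) scheduled a trim for the next generate() call
        return num_tokens_to_remove_from_kv_cache > 0;
    }

    void reset() {
        num_tokens_to_remove_from_kv_cache = 0;
        trusted_history_length = 0;
    }
};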
@@ -304,10 +314,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// Between subsequent runs attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist;

kv_cache_len = atten_mask_history.get_shape()[1] - m_kv_history_manager.num_tokens_to_remove_from_kv_cache;

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
auto start_atten_hst = atten_mask_history.data<int64_t>();

std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
@@ -317,6 +329,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
concatenated_attention_mask = attention_mask;
}

size_t prev_attn_mask_size = concatenated_attention_mask.get_shape()[1];

bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
@@ -330,51 +344,55 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_kv_history_manager.reset();
}

ov::genai::EncodedResults result;
if (config.is_beam_search() && is_chat_conversation) {
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
config, position_ids, m_selected_beam);
} else {
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);

sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}

if (m_sampler.get_seed() != config.rng_seed) {
m_sampler.set_seed(config.rng_seed);
}
sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
}

std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
m_sampler, requests, position_ids, std::nullopt, m_selected_beam);
if (m_sampler.get_seed() != config.rng_seed) {
m_sampler.set_seed(config.rng_seed);
}

ov::genai::EncodedResults result;
std::tie(result, m_last_disappeared_token) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
streamer_ptr, m_sampler, requests, position_ids, std::nullopt);

if (is_chat_conversation) {
// forcibly remove the last answer from the kv cache
if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_kv_history_manager.trusted_history_length = m_tokenized_chat_history.size();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
} else {
reset_kv_state();
m_selected_beam = std::nullopt;
m_last_disappeared_token = std::nullopt;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));

auto stop_time = std::chrono::steady_clock::now();

// If called without tokenization, that stat will not be reported.
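With group_beam_searcher.cpp deleted, beam search in chat mode now goes through the same Sampler/SequenceGroup path as the other sampling modes, and the bookkeeping above schedules removal of the model's raw last answer from the KV cache so that only the selected best answer is kept in the chat history. A minimal usage sketch of the affected scenario (model path and prompts are placeholders):

#include <iostream>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");  // placeholder model directory

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 64;
    config.num_beams = 4;  // beam search, now handled by the Sampler in chat mode as well

    pipe.start_chat();
    std::cout << pipe.generate("What is OpenVINO?", config) << std::endl;
    std::cout << pipe.generate("What can it be used for?", config) << std::endl;  // second turn reuses the trimmed KV cache
    pipe.finish_chat();
    return 0;
}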
@@ -388,10 +406,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_history = {};
@@ -409,10 +427,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_history.clear();
39 changes: 18 additions & 21 deletions src/cpp/src/lm_encoding.cpp
@@ -9,12 +9,11 @@
#include <regex>
#include <vector>

#include "utils.hpp"
#include "debug_utils.hpp"
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"

#include "debug_utils.hpp"

#include "utils.hpp"

namespace ov {
namespace genai {
@@ -51,16 +50,15 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
}


std::pair<EncodedResults, int32_t> get_lm_encoded_results(
std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
ov::InferRequest& m_llm,
const ov::Tensor& input_ids,
const ov::Tensor& attention_mask,
const std::shared_ptr<StreamerBase>& streamer_ptr,
Sampler& sampler,
std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids,
std::optional<EmbeddingsModel> m_embedding,
std::optional<int32_t> selected_beam_idx
std::optional<EmbeddingsModel> m_embedding
) {
std::vector<GenerationHandle> generations;
for (SequenceGroup::Ptr sequence_group : sequence_groups) {
@@ -105,7 +103,7 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
m_llm.set_tensor("position_ids", *position_ids);

ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size});
std::fill_n(beam_idx.data<int32_t>(), batch_size, selected_beam_idx.has_value() ? *selected_beam_idx : 0);
std::fill_n(beam_idx.data<int32_t>(), batch_size, 0);
m_llm.set_tensor("beam_idx", beam_idx);

// "Prompt" phase
@@ -171,13 +169,13 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
// apply strides to shift to a next sequence
input_ids_data += num_scheduled_tokens;

// for different sequences iteration of beams started from 0, but we collect it to one input_ids#
// for different sequences iteration of beams started from 0, but we collect it to one input_ids
next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id()));
}
}

for (size_t i = 0; i < sequence_groups.size(); i++) {
beam_offets[sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]);
for (size_t i = 0; i < active_sequence_groups.size(); i++) {
beam_offets[active_sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (active_sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]);
}

if (m_embedding.has_value()) {
@@ -212,15 +210,10 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
streamer_ptr->end();
}

// Collect results

size_t next_selected_beam = 0;
for (size_t i = 0; i < sequence_groups.size(); i++) {
auto request = sequence_groups[i];
std::vector<GenerationOutput> generation_outputs;
auto sampling_params = request->get_sampling_parameters();
const auto& sequences = request->get_finished_sequences();
size_t num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, sequences.size());
for (auto& sequence_group : sequence_groups) {
auto sampling_params = sequence_group->get_sampling_parameters();
const auto& sequences = sequence_group->get_finished_sequences();
size_t num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, sequences.size());

for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) {
const auto & sequence = sequences[seq_id];
@@ -229,13 +222,17 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
results.tokens.push_back(sequence->get_generated_ids());
results.scores.push_back(score);
}
// next_selected_beam = sampler.last_selected_beam(request);
}

for (SequenceGroup::Ptr sequence_group : sequence_groups)
sampler.clear_request_info(sequence_group->get_request_id());

return {results, next_selected_beam};
// the last token of the best sequence is not saved in the KV cache, so we need to add it back in some cases
std::optional<int64_t> last_token_of_best_sequence = std::nullopt;
if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped())
last_token_of_best_sequence = results.tokens[0].back();

return {results, last_token_of_best_sequence};
}

} // namespace genai
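The second element of the returned pair replaces the beam index that was returned before: it carries the last token of the best sequence only when that token never went through a forward pass and therefore has no KV-cache entry. A small standalone illustration of the decision rule, with simplified stand-ins for the real GenerationFinishReason and handle-dropped checks:

#include <cstdint>
#include <optional>

// Simplified types for illustration only.
enum class FinishReason { NONE, STOP, LENGTH };

std::optional<int64_t> last_disappeared_token(FinishReason finish_reason,
                                              bool handle_dropped,
                                              int64_t last_generated_id) {
    // Generation stopped on the length limit, or the handle was dropped mid-stream:
    // the final sampled token was appended to the results but never fed back into the model,
    // so the caller has to restore it in front of the next chat input.
    if (finish_reason == FinishReason::LENGTH || handle_dropped)
        return last_generated_id;
    return std::nullopt;  // otherwise nothing needs to be carried over
}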