Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper pipeline: use Sampler #1615

Open
wants to merge 30 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
68d3e48
use decoder interface
as-suvorov Jan 3, 2025
ade7313
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov Jan 3, 2025
b2df4a6
remove reshape
as-suvorov Jan 3, 2025
17e3ea7
use stateful seq2seq barnch
as-suvorov Jan 3, 2025
806b01a
Address review comments
as-suvorov Jan 3, 2025
e041a33
Rename
as-suvorov Jan 3, 2025
9502d9b
Use commit
as-suvorov Jan 3, 2025
6c30fa4
Set tests reqs
as-suvorov Jan 3, 2025
7600072
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov Jan 8, 2025
f870a4c
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov Jan 9, 2025
e38cf5c
Add with_past model tests
as-suvorov Jan 9, 2025
acc656f
remove comment
as-suvorov Jan 9, 2025
aa0f742
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov Jan 9, 2025
3728884
bump tokenizers
as-suvorov Jan 9, 2025
445ce5a
Fix typo
as-suvorov Jan 9, 2025
5bdd695
Add deprecation message
as-suvorov Jan 10, 2025
7f2a153
Use sampler for whisper pipeline
as-suvorov Jan 21, 2025
c368401
Add with past decoder
as-suvorov Jan 21, 2025
2e061aa
Refactor with past decoder
as-suvorov Jan 22, 2025
4eaa9a7
Do not copy encoder_hidden_states if not needed
as-suvorov Jan 22, 2025
3139a43
Merge remote-tracking branch 'upstream/master' into as/sampler_for_wh…
as-suvorov Jan 22, 2025
50fb829
Add stubs
as-suvorov Jan 22, 2025
5021742
Remove comment
as-suvorov Jan 22, 2025
04318d1
Add args name
as-suvorov Jan 22, 2025
9e08d18
add tests
as-suvorov Jan 22, 2025
56bf11c
Apply review comments
as-suvorov Jan 22, 2025
50eb509
move set_encoder_states to base class
as-suvorov Jan 23, 2025
e802584
Merge remote-tracking branch 'upstream/master' into as/sampler_for_wh…
as-suvorov Jan 23, 2025
abde309
Move detect_language to base decoder
as-suvorov Jan 23, 2025
3348ad5
revert sample
as-suvorov Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cpp/include/openvino/genai/generation_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {
OPENVINO_DEPRECATED("Please, use `is_assisting_generation()` instead of `is_speculative_decoding()`. This method will be removed in 2026.0.0 release")
bool is_speculative_decoding() const;

void update_generation_config(const ov::AnyMap& properties);
virtual void update_generation_config(const ov::AnyMap& properties);

template <typename... Properties>
util::EnableIfAllStringAny<void, Properties...> update_generation_config(Properties&&... properties) {
Expand All @@ -152,7 +152,7 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {

/// @brief checks that are no conflicting parameters, e.g. do_sample=true and num_beams > 1.
/// @throws Exception if config is invalid.
void validate() const;
virtual void validate() const;
ilya-lavrenov marked this conversation as resolved.
Show resolved Hide resolved
};

/*
Expand Down
27 changes: 4 additions & 23 deletions src/cpp/include/openvino/genai/whisper_generation_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <filesystem>
#include <optional>

#include "generation_config.hpp"
#include "openvino/genai/tokenizer.hpp"
#include "openvino/runtime/compiled_model.hpp"

Expand All @@ -15,28 +16,14 @@ namespace genai {
/**
* @brief Structure to keep whisper generation config parameters.
*/
class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
public:
WhisperGenerationConfig() = default;
explicit WhisperGenerationConfig(const std::filesystem::path& json_path);

// Generic

// the maximum length the generated tokens can have. Corresponds to the length of the input prompt +
// `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
size_t max_new_tokens = SIZE_MAX;
// the maximum numbers of tokens to generate, excluding the number of tokens in the prompt.
// max_new_tokens has priority over max_length.
size_t max_length = SIZE_MAX;

// Whisper specific

// Corresponds to the ”<|startoftranscript|>” token.
int64_t decoder_start_token_id = 50258;

// End of stream token id.
int64_t eos_token_id = 50257;
ilya-lavrenov marked this conversation as resolved.
Show resolved Hide resolved

// Padding token id.
int64_t pad_token_id = 50257;

Expand Down Expand Up @@ -110,13 +97,7 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
// A list containing the non-speech tokens that will be suppressed during generation.
std::vector<int64_t> suppress_tokens;

/** @brief sets eos_token_id to tokenizer_eos_token_id if eos_token_id is less than 0.
* Otherwise verifies eos_token_id == tokenizer_eos_token_id.
*/
void set_eos_token_id(int64_t tokenizer_eos_token_id);
size_t get_max_new_tokens(size_t prompt_length = 0) const;

void update_generation_config(const ov::AnyMap& config_map = {});
void update_generation_config(const ov::AnyMap& config_map = {}) override;

template <typename... Properties>
util::EnableIfAllStringAny<void, Properties...> update_generation_config(Properties&&... properties) {
Expand All @@ -125,7 +106,7 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {

/// @brief checks that are no conflicting parameters.
/// @throws Exception if config is invalid.
void validate() const;
void validate() const override;
};

/*
Expand Down
3 changes: 2 additions & 1 deletion src/cpp/src/debug_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
template <typename T>
void print_array(T * array, size_t size) {
std::cout << " => [ ";
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < std::min(size, size_t(10)); ++i) {
std::cout << array[i] << " ";
}
std::cout << " ] " << std::endl;
}

inline void print_tensor(std::string name, ov::Tensor tensor) {
std::cout << name;
std::cout << " " << tensor.get_shape().to_string();
if (tensor.get_element_type() == ov::element::i32) {
print_array(tensor.data<int>(), tensor.get_size());
} else if (tensor.get_element_type() == ov::element::i64) {
Expand Down
12 changes: 9 additions & 3 deletions src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"

namespace {

namespace ov {
namespace genai {

/**
* Set position ids tensor data for next token inference based on provided attention mask
* Supports multi batch
* Supports sparse attention_mask
*/
void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t sequence_length = attention_mask.get_shape().at(1);
Expand Down Expand Up @@ -65,7 +68,10 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
attention_mask.data<int64_t>()[result_prompt_offset + new_shape.at(1) - 1] = 1;
}
}
}

namespace ov {
namespace genai {

std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
ov::InferRequest& m_llm,
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace ov::genai {

class Logger {
public:
static void warn(std::string message) {
static void warn(const std::string& message) {
std::cout << "[WARN] " << message << '\n';
};
};
Expand Down
16 changes: 8 additions & 8 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits,
}

// check whether group has finished
group.is_done(m_parameters);
group.is_done(m_parameters, m_sequence_group->get_prompt_len());

// group cannot continue if there are no valid child beams
if (child_beams_per_group[group_id].size() == 0) {
Expand Down Expand Up @@ -549,14 +549,14 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
std::vector<int64_t> dropped_seq_ids;
for (auto& running_sequence : sequence_group->get_running_sequences()) {
const auto generated_len = running_sequence->get_generated_len();
if (sampling_params.max_new_tokens <= generated_len ||
if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) <= generated_len ||
is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
// stop sequence by max_new_tokens or stop token (eos included)
running_sequence->set_status(SequenceStatus::FINISHED);

if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
} else if (sampling_params.max_new_tokens == generated_len) {
} else if (sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) == generated_len) {
running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
}

Expand Down Expand Up @@ -800,8 +800,8 @@ SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr> & sequence_g
// max counter of needed to be sampled tokens
OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset);
size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset;
OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
OPENVINO_ASSERT(sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.get_max_new_tokens(sequence_group->get_prompt_len()) - generated_and_verified_len;
if (max_num_sampled_token == 0) {
stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
break;
Expand Down Expand Up @@ -887,7 +887,7 @@ SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr> & sequence_g
// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
if (!sequence_group->has_finished() &&
running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
running_sequences[0]->get_generated_len() == sampling_params.get_max_new_tokens(sequence_group->get_prompt_len())) {
// stop sequence by max_new_tokens
m_beam_search_info.at(request_id).finalize(sampler_output);
}
Expand Down Expand Up @@ -956,7 +956,7 @@ int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::Ge
return preeempted_sequence_id;
}

void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_length) {
assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
Expand All @@ -977,7 +977,7 @@ void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfi
return;
}
case ov::genai::StopCriteria::NEVER: {
size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.get_max_new_tokens(prompt_length) : cur_len;
float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
done = worst_score >= highest_attainable_score;
return;
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class Sampler::GroupBeamSearcher {
bool done = false;

int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
void is_done(const ov::genai::GenerationConfig& sampling_params);
void is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_length);
};

SequenceGroup::Ptr m_sequence_group;
Expand Down
17 changes: 0 additions & 17 deletions src/cpp/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,6 @@ void set_attention_mask(ov::Tensor&& attention_mask, std::vector<int32_t> next_b
}
}

/**
* Set position ids tensor data for next token inference based on provided attention mask
* Supports multi batch
* Supports sparse attention_mask
*/
void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t atten_length = attention_mask.get_shape().at(1);
position_ids.set_shape({batch_size, 1});

for (size_t batch = 0; batch < batch_size; batch++) {
int64_t* start = attention_mask.data<int64_t>() + batch * atten_length;
// todo: be careful with start + atten_length, probably need to replace with start + atten_length -1
position_ids.data<int64_t>()[batch] = std::accumulate(start, start + atten_length, 0);
}
}

/**
* Get attention mask tensor for next token inference
* Supports multi batch
Expand Down
2 changes: 0 additions & 2 deletions src/cpp/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attenti

ov::Tensor extend_attention(ov::Tensor attention_mask);

void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask);

template <typename T> struct OmitOptional { using value = T; };
template <typename T> struct OmitOptional<std::optional<T>> { using value = T; };

Expand Down
1 change: 0 additions & 1 deletion src/cpp/src/whisper/logit_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void process_whisper_timestamp_logits(ov::Tensor& logits,
const std::vector<int64_t>& generated_tokens,
bool initial_step = false) {
const size_t batch_size = logits.get_shape().at(0);
OPENVINO_ASSERT(batch_size == 1, "Batch != 1 is not supported");

size_t vocab_size = logits.get_shape().back();
size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size;
Expand Down
8 changes: 4 additions & 4 deletions src/cpp/src/whisper/models/decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ class WhisperDecoder {
const std::string& device,
const ov::AnyMap& properties);

virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
virtual std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) = 0;

virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) = 0;
virtual std::pair<Tensor, float> decode(const Tensor& encoder_hidden_state,
const Tensor& input_ids,
const Tensor& beam_idx) = 0;

virtual void reset_state() = 0;

Expand Down
105 changes: 89 additions & 16 deletions src/cpp/src/whisper/models/statefull_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "statefull_decoder.hpp"

#include "debug_utils.hpp"
#include "utils.hpp"

namespace ov::genai {
Expand All @@ -19,7 +20,13 @@ WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& mo

std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);
Tensor input_ids_tensor{ov::element::i64, {1, 1}};
input_ids_tensor.data<int64_t>()[0] = decoder_start_token_id;

Tensor beam_idx_tensor{ov::element::i32, {1}};
beam_idx_tensor.data<int32_t>()[0] = 0;

auto [output_tensor, infer_ms] = decode(encoder_hidden_state, input_ids_tensor, beam_idx_tensor);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

Expand All @@ -28,22 +35,17 @@ std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Ten
return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) {
m_request.set_tensor("encoder_hidden_states", encoder_hidden_state);

ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
m_request.set_tensor("input_ids", input_ids_tensor);
std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const Tensor& encoder_hidden_state,
const Tensor& input_ids,
const Tensor& beam_idx) {
const size_t batch_size = input_ids.get_shape().at(0);
const size_t seq_len = input_ids.get_shape().at(1);

ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position");
cache_position_tensor.set_shape({input_ids.size()});

auto cache_data = cache_position_tensor.data<int64_t>();
std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position);
_set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size);

m_request.get_tensor("beam_idx").set_shape({1});
m_request.get_tensor("beam_idx").data<int32_t>()[0] = 0;
_set_cache_position_tensor(seq_len);
m_request.set_tensor("input_ids", input_ids);
m_request.set_tensor("beam_idx", beam_idx);

const auto infer_start = std::chrono::steady_clock::now();
m_request.infer();
Expand All @@ -54,7 +56,78 @@ std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& e
return {output_tensor, infer_ms};
};

/**
* Encoder hidden states expected to be with batch 1
* Copy encoder hidden state tensor from batch 1 to requested batch_size.
* Set new encoder hidden states tensor to infer request.
*/
void WhisperStatefullDecoder::_set_encoder_hidden_states_tensor(const Tensor& encoder_hidden_state,
const size_t batch_size) {
const size_t current_batch_size = m_request.get_tensor("encoder_hidden_states").get_shape().at(0);
// batch hasn't changed, skip
if (current_batch_size == batch_size) {
return;
}

if (current_batch_size != 0) {
_reset_encoder_past_key_values_states(encoder_hidden_state, batch_size);
}

OPENVINO_ASSERT(encoder_hidden_state.get_shape().at(0) == 1);
Shape shape{encoder_hidden_state.get_shape()};
shape[0] = batch_size;

Tensor new_encoder_hidden_states{ov::element::f32, shape};

auto new_encoder_hidden_states_data = new_encoder_hidden_states.data<float>();
auto encoder_hidden_state_data = encoder_hidden_state.data<float>();

for (size_t batch = 0; batch < batch_size; batch++) {
const size_t batch_offset = batch * encoder_hidden_state.get_size();
std::memcpy(new_encoder_hidden_states_data + batch_offset,
encoder_hidden_state_data,
encoder_hidden_state.get_byte_size());
}

m_request.set_tensor("encoder_hidden_states", new_encoder_hidden_states);
}

// Past_key value states are not shring/grow when batch is changed. Reset past_key values states as a workaround.
// Ticket:
void WhisperStatefullDecoder::_reset_encoder_past_key_values_states(const Tensor& encoder_hidden_state,
const size_t batch_size) {
const size_t encoder_state_length_dim = encoder_hidden_state.get_shape().at(1);
for (auto& state : m_request.query_state()) {
// find encoder states by dimension
const Shape& state_shape = state.get_state().get_shape();
if (state_shape.at(2) == encoder_state_length_dim) {
state.reset();
}
}
}

void WhisperStatefullDecoder::_set_cache_position_tensor(const size_t seq_len) {
ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position");

int64_t start_cache_position = 0;

if (cache_position_tensor.get_size() != 0) {
start_cache_position = cache_position_tensor.data<int64_t>()[cache_position_tensor.get_size() - 1] + 1;
}

cache_position_tensor.set_shape({seq_len});

auto cache_data = cache_position_tensor.data<int64_t>();
std::iota(cache_data, cache_data + seq_len, start_cache_position);
};

void WhisperStatefullDecoder::reset_state() {
m_request.reset_state();
}
m_request.set_tensor("cache_position", ov::Tensor{ov::element::i64, {0}});

Shape encoder_hidden_states_shape{m_request.get_tensor("encoder_hidden_states").get_shape()};
encoder_hidden_states_shape[0] = 0;
m_request.set_tensor("encoder_hidden_states", ov::Tensor{ov::element::f32, encoder_hidden_states_shape});
};

} // namespace ov::genai
Loading
Loading