diff --git a/src/cpp/include/openvino/genai/whisper_generation_config.hpp b/src/cpp/include/openvino/genai/whisper_generation_config.hpp
index 12ca32ebb2..18b4202609 100644
--- a/src/cpp/include/openvino/genai/whisper_generation_config.hpp
+++ b/src/cpp/include/openvino/genai/whisper_generation_config.hpp
@@ -6,6 +6,7 @@
 #include <filesystem>
 #include <optional>
 
+#include "generation_config.hpp"
 #include "openvino/genai/tokenizer.hpp"
 #include "openvino/runtime/compiled_model.hpp"
 
@@ -15,28 +16,14 @@ namespace genai {
 /**
  * @brief Structure to keep whisper generation config parameters.
  */
-class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
+class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
 public:
     WhisperGenerationConfig() = default;
     explicit WhisperGenerationConfig(const std::filesystem::path& json_path);
 
-    // Generic
-
-    // the maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-    // `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
-    size_t max_new_tokens = SIZE_MAX;
-    // the maximum numbers of tokens to generate, excluding the number of tokens in the prompt.
-    // max_new_tokens has priority over max_length.
-    size_t max_length = SIZE_MAX;
-
-    // Whisper specific
-
     // Corresponds to the ”<|startoftranscript|>” token.
     int64_t decoder_start_token_id = 50258;
 
-    // End of stream token id.
-    int64_t eos_token_id = 50257;
-
     // Padding token id.
     int64_t pad_token_id = 50257;
 
@@ -110,12 +97,6 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
     // A list containing the non-speech tokens that will be suppressed during generation.
     std::vector<int64_t> suppress_tokens;
 
-    /** @brief sets eos_token_id to tokenizer_eos_token_id if eos_token_id is less than 0.
-     * Otherwise verifies eos_token_id == tokenizer_eos_token_id.
- */ - void set_eos_token_id(int64_t tokenizer_eos_token_id); - size_t get_max_new_tokens(size_t prompt_length = 0) const; - void update_generation_config(const ov::AnyMap& config_map = {}); template diff --git a/src/cpp/src/debug_utils.hpp b/src/cpp/src/debug_utils.hpp index 4c301493ce..66b42f8640 100644 --- a/src/cpp/src/debug_utils.hpp +++ b/src/cpp/src/debug_utils.hpp @@ -12,7 +12,7 @@ template void print_array(T * array, size_t size) { std::cout << " => [ "; - for (size_t i = 0; i < size; ++i) { + for (size_t i = 0; i < std::min(size, size_t(10)); ++i) { std::cout << array[i] << " "; } std::cout << " ] " << std::endl; @@ -20,6 +20,7 @@ void print_array(T * array, size_t size) { inline void print_tensor(std::string name, ov::Tensor tensor) { std::cout << name; + std::cout << " " << tensor.get_shape().to_string(); if (tensor.get_element_type() == ov::element::i32) { print_array(tensor.data(), tensor.get_size()); } else if (tensor.get_element_type() == ov::element::i64) { diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index e2ec3a1b33..0a51a4ed4d 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -14,10 +14,13 @@ #include "lm_encoding.hpp" #include "openvino/genai/perf_metrics.hpp" +namespace { -namespace ov { -namespace genai { - +/** + * Set position ids tensor data for next token inference based on provided attention mask + * Supports multi batch + * Supports sparse attention_mask + */ void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) { const size_t batch_size = attention_mask.get_shape().at(0); const size_t sequence_length = attention_mask.get_shape().at(1); @@ -65,7 +68,10 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector()[result_prompt_offset + new_shape.at(1) - 1] = 1; } } +} +namespace ov { +namespace genai { std::pair> get_lm_encoded_results( ov::InferRequest& m_llm, diff --git a/src/cpp/src/logger.hpp b/src/cpp/src/logger.hpp index 503a419e5e..fbf0657a87 100644 --- a/src/cpp/src/logger.hpp +++ b/src/cpp/src/logger.hpp @@ -9,7 +9,7 @@ namespace ov::genai { class Logger { public: - static void warn(std::string message) { + static void warn(const std::string& message) { std::cout << "[WARN] " << message << '\n'; }; }; diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index 4c399ab641..7a1e079746 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -408,7 +408,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, } // check whether group has finished - group.is_done(m_parameters); + group.is_done(); // group cannot continue if there are no valid child beams if (child_beams_per_group[group_id].size() == 0) { @@ -549,14 +549,14 @@ std::vector Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen std::vector dropped_seq_ids; for (auto& running_sequence : sequence_group->get_running_sequences()) { const auto generated_len = running_sequence->get_generated_len(); - if (sampling_params.max_new_tokens <= generated_len || + if (sequence_group->get_max_new_tokens() <= generated_len || is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) { // stop sequence by max_new_tokens or stop token (eos included) running_sequence->set_status(SequenceStatus::FINISHED); if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) { 
                running_sequence->set_finish_reason(GenerationFinishReason::STOP);
-            } else if (sampling_params.max_new_tokens == generated_len) {
+            } else if (sequence_group->get_max_new_tokens() == generated_len) {
                 running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
             }
@@ -800,8 +800,8 @@ SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr>& sequence_g
                     // max counter of needed to be sampled tokens
                     OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset);
                     size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset;
-                    OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
-                    size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
+                    OPENVINO_ASSERT(sequence_group->get_max_new_tokens() >= generated_and_verified_len);
+                    size_t max_num_sampled_token = sequence_group->get_max_new_tokens() - generated_and_verified_len;
                     if (max_num_sampled_token == 0) {
                         stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
                         break;
                     }
@@ -887,7 +887,7 @@ SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr>& sequence_g
             // check max length stop criteria
             std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
             if (!sequence_group->has_finished() &&
-                running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
+                running_sequences[0]->get_generated_len() == sequence_group->get_max_new_tokens()) {
                 // stop sequence by max_new_tokens
                 m_beam_search_info.at(request_id).finalize(sampler_output);
             }
@@ -956,7 +956,10 @@ int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::Ge
     return preeempted_sequence_id;
 }
 
-void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
+void Sampler::GroupBeamSearcher::Group::is_done() {
+    const auto sequence_group = ongoing.front().m_sequence->get_sequence_group_ptr();
+    const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
+
     assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
            "number of beams should be divisible by number of groups");
     size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
@@ -977,7 +980,7 @@ void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfi
             return;
         }
         case ov::genai::StopCriteria::NEVER: {
-            size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
+            size_t length = sampling_params.length_penalty > 0.0 ?
sequence_group->get_max_new_tokens() : cur_len; float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty); done = worst_score >= highest_attainable_score; return; diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 73c656a41d..9768e0a7af 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -114,7 +114,7 @@ class Sampler::GroupBeamSearcher { bool done = false; int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params); - void is_done(const ov::genai::GenerationConfig& sampling_params); + void is_done(); }; SequenceGroup::Ptr m_sequence_group; diff --git a/src/cpp/src/scheduler.hpp b/src/cpp/src/scheduler.hpp index fb97c0b0fa..210d9acf6f 100644 --- a/src/cpp/src/scheduler.hpp +++ b/src/cpp/src/scheduler.hpp @@ -497,7 +497,7 @@ class Scheduler { for (auto idx = 0; idx < sequence_groups.size(); idx++) { auto seq_length = sequence_groups[idx]->get_prompt_len() * m_kv_blocks_initial_multiplier; auto gen_config = sequence_groups[idx]->get_sampling_parameters(); - seq_length = std::min(seq_length, sequence_groups[idx]->get_prompt_len() + gen_config.get_max_new_tokens(sequence_groups[idx]->get_prompt_len())); + seq_length = std::min(seq_length, sequence_groups[idx]->get_prompt_len() + sequence_groups[idx]->get_max_new_tokens()); size_t blocks_num = std::ceil((float)seq_length / m_block_manager->get_block_size()); if (gen_config.is_beam_search()) { blocks_num *= gen_config.num_beams; diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index fef9757b43..19d29c92ac 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -689,7 +689,11 @@ class SequenceGroup : public std::enable_shared_from_this { GenerationOutputs outputs; outputs.emplace(0, output); m_generation_stream->push(std::move(outputs)); - } + } + + size_t get_max_new_tokens() { + return m_sampling_params.get_max_new_tokens(get_prompt_len()); + } }; inline std::shared_ptr Sequence::get_sequence_group_ptr() const { diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 2d6dfd2ae5..a8cf844cb7 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -46,22 +46,6 @@ void print_tensor(const ov::Tensor& tensor) { std::cout << "]" << std::endl; } -int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) { - if (logits.get_shape()[0] <= batch_idx) { - OPENVINO_THROW("logits batch size doesn't match the number of beams"); - } - - size_t vocab_size = logits.get_shape().back(); - size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size; - size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size; - const float* logits_data = logits.data() + batch_offset + sequence_offset; - - int64_t out_token = std::max_element(logits_data, logits_data + vocab_size) - logits_data; - float max_logit = logits_data[out_token]; - - return out_token; -} - /** * Initializes position ids based on attention mask and starting position */ @@ -128,23 +112,6 @@ void set_attention_mask(ov::Tensor&& attention_mask, std::vector next_b } } -/** - * Set position ids tensor data for next token inference based on provided attention mask - * Supports multi batch - * Supports sparse attention_mask - */ -void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) { - const size_t batch_size = attention_mask.get_shape().at(0); - const size_t atten_length = attention_mask.get_shape().at(1); - position_ids.set_shape({batch_size, 1}); - - for (size_t batch = 0; batch 
< batch_size; batch++) { - int64_t* start = attention_mask.data() + batch * atten_length; - // todo: be careful with start + atten_length, probably need to replace with start + atten_length -1 - position_ids.data()[batch] = std::accumulate(start, start + atten_length, 0); - } -} - /** * Get attention mask tensor for next token inference * Supports multi batch diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index ff1aea1ae9..8c56c39a8c 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -47,14 +47,10 @@ Tensor init_attention_mask(const Tensor& position_ids); void print_tensor(const ov::Tensor& tensor); -int64_t argmax(const ov::Tensor& logits, const size_t batch_idx); - void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos = 0); ov::Tensor extend_attention(ov::Tensor attention_mask); -void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask); - template struct OmitOptional { using value = T; }; template struct OmitOptional> { using value = T; }; diff --git a/src/cpp/src/whisper/logit_processor.cpp b/src/cpp/src/whisper/logit_processor.cpp index 92ae39bd4c..ff185e2799 100644 --- a/src/cpp/src/whisper/logit_processor.cpp +++ b/src/cpp/src/whisper/logit_processor.cpp @@ -28,7 +28,6 @@ void process_whisper_timestamp_logits(ov::Tensor& logits, const std::vector& generated_tokens, bool initial_step = false) { const size_t batch_size = logits.get_shape().at(0); - OPENVINO_ASSERT(batch_size == 1, "Batch != 1 is not supported"); size_t vocab_size = logits.get_shape().back(); size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size; diff --git a/src/cpp/src/whisper/models/decoder.cpp b/src/cpp/src/whisper/models/decoder.cpp index 32a8f2eff6..c09a84ccdd 100644 --- a/src/cpp/src/whisper/models/decoder.cpp +++ b/src/cpp/src/whisper/models/decoder.cpp @@ -6,7 +6,7 @@ #include #include "statefull_decoder.hpp" -#include "utils.hpp" +#include "whisper/whisper_utils.hpp" #include "with_past_decoder.hpp" namespace ov::genai { @@ -22,5 +22,55 @@ std::shared_ptr WhisperDecoder::from_path(const std::filesystem: return std::make_shared(models_path, device, properties); } +std::pair WhisperDecoder::detect_language(const ov::Tensor& encoder_hidden_state, + const int64_t decoder_start_token_id) { + Tensor input_ids_tensor{ov::element::i64, {1, 1}}; + input_ids_tensor.data()[0] = decoder_start_token_id; + + Tensor beam_idx_tensor{ov::element::i32, {1}}; + beam_idx_tensor.data()[0] = 0; + + auto [output_tensor, infer_ms] = decode(encoder_hidden_state, input_ids_tensor, beam_idx_tensor); + + int64_t output_token = ov::genai::utils::argmax(output_tensor, 0); + + reset_state(); + + return {output_token, infer_ms}; +} + +/** + * Encoder hidden states expected to be with batch 1 + * Copy encoder hidden state tensor from batch 1 to requested batch_size. + * Set new encoder hidden states tensor to infer request. 
+ */ +void WhisperDecoder::_set_encoder_hidden_states_tensor(const Tensor& encoder_hidden_state, + const size_t batch_size, + InferRequest& request) { + const size_t current_batch_size = request.get_tensor("encoder_hidden_states").get_shape().at(0); + // batch hasn't changed, skip + if (current_batch_size == batch_size) { + return; + } + + OPENVINO_ASSERT(encoder_hidden_state.get_shape().at(0) == 1); + Shape shape{encoder_hidden_state.get_shape()}; + shape[0] = batch_size; + + Tensor new_encoder_hidden_states{ov::element::f32, shape}; + + auto new_encoder_hidden_states_data = new_encoder_hidden_states.data(); + auto encoder_hidden_state_data = encoder_hidden_state.data(); + + for (size_t batch = 0; batch < batch_size; batch++) { + const size_t batch_offset = batch * encoder_hidden_state.get_size(); + std::memcpy(new_encoder_hidden_states_data + batch_offset, + encoder_hidden_state_data, + encoder_hidden_state.get_byte_size()); + } + + request.set_tensor("encoder_hidden_states", new_encoder_hidden_states); +} + WhisperDecoder::~WhisperDecoder() = default; } // namespace ov::genai diff --git a/src/cpp/src/whisper/models/decoder.hpp b/src/cpp/src/whisper/models/decoder.hpp index cd58e54729..6eeba2b387 100644 --- a/src/cpp/src/whisper/models/decoder.hpp +++ b/src/cpp/src/whisper/models/decoder.hpp @@ -15,15 +15,19 @@ class WhisperDecoder { const std::string& device, const ov::AnyMap& properties); - virtual std::pair detect_language(const ov::Tensor& encoder_hidden_state, - const int64_t decoder_start_token_id) = 0; + std::pair detect_language(const Tensor& encoder_hidden_state, const int64_t decoder_start_token_id); - virtual std::pair decode(const ov::Tensor& encoder_hidden_state, - const std::vector& input_ids, - const size_t cache_position) = 0; + virtual std::pair decode(const Tensor& encoder_hidden_state, + const Tensor& input_ids, + const Tensor& beam_idx) = 0; virtual void reset_state() = 0; virtual ~WhisperDecoder(); + +protected: + void _set_encoder_hidden_states_tensor(const Tensor& encoder_hidden_state, + const size_t batch_size, + InferRequest& request); }; } // namespace ov::genai diff --git a/src/cpp/src/whisper/models/statefull_decoder.cpp b/src/cpp/src/whisper/models/statefull_decoder.cpp index 39b7c4b1e0..affdfffaf5 100644 --- a/src/cpp/src/whisper/models/statefull_decoder.cpp +++ b/src/cpp/src/whisper/models/statefull_decoder.cpp @@ -3,6 +3,7 @@ #include "statefull_decoder.hpp" +#include "debug_utils.hpp" #include "utils.hpp" namespace ov::genai { @@ -21,44 +22,49 @@ WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& mo m_request = compiled_model.create_infer_request(); } -std::pair WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state, - const int64_t decoder_start_token_id) { - auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0); +std::pair WhisperStatefullDecoder::decode(const Tensor& encoder_hidden_state, + const Tensor& input_ids, + const Tensor& beam_idx) { + const size_t batch_size = input_ids.get_shape().at(0); + const size_t seq_len = input_ids.get_shape().at(1); - int64_t output_token = ov::genai::utils::argmax(output_tensor, 0); + _set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, m_request); - reset_state(); + _set_cache_position_tensor(seq_len); + m_request.set_tensor("input_ids", input_ids); + m_request.set_tensor("beam_idx", beam_idx); - return {output_token, infer_ms}; -} + const auto infer_start = std::chrono::steady_clock::now(); + m_request.infer(); 
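Illustrative note, not part of the patch: the stateful decode(encoder_hidden_state, input_ids, beam_idx) call now derives cache_position from the request state instead of taking it as an argument. A rough sketch of how a caller might drive it; `decoder`, `encoder_hidden_state` and `sampled_token` are assumed to exist, and the init token values are placeholders.

    // Sketch only: consecutive decode() calls on WhisperStatefullDecoder.
    ov::Tensor init_ids{ov::element::i64, {1, 4}};   // e.g. start-of-transcript, language, task, no-timestamps ids
    // fill init_ids.data<int64_t>() with the init tokens (see prepare_init_tokens in whisper.cpp)
    ov::Tensor beam_idx{ov::element::i32, {1}};
    beam_idx.data<int32_t>()[0] = 0;

    // First call consumes the whole prompt: cache_position is set to [0, 1, 2, 3].
    auto [prompt_logits, prompt_ms] = decoder->decode(encoder_hidden_state, init_ids, beam_idx);

    // Later calls pass one token per running sequence: cache_position continues
    // from the last value, i.e. [4], then [5], ... until reset_state() clears it.
    ov::Tensor next_ids{ov::element::i64, {1, 1}};
    next_ids.data<int64_t>()[0] = sampled_token;
    auto [step_logits, step_ms] = decoder->decode(encoder_hidden_state, next_ids, beam_idx);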
+ const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); -std::pair WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state, - const std::vector& input_ids, - const size_t cache_position) { - m_request.set_tensor("encoder_hidden_states", encoder_hidden_state); + auto output_tensor = m_request.get_tensor("logits"); - ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data()); - m_request.set_tensor("input_ids", input_ids_tensor); + return {output_tensor, infer_ms}; +}; +void WhisperStatefullDecoder::_set_cache_position_tensor(const size_t seq_len) { ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position"); - cache_position_tensor.set_shape({input_ids.size()}); - auto cache_data = cache_position_tensor.data(); - std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position); + int64_t start_cache_position = 0; - m_request.get_tensor("beam_idx").set_shape({1}); - m_request.get_tensor("beam_idx").data()[0] = 0; + if (cache_position_tensor.get_size() != 0) { + start_cache_position = cache_position_tensor.data()[cache_position_tensor.get_size() - 1] + 1; + } - const auto infer_start = std::chrono::steady_clock::now(); - m_request.infer(); - const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); - - auto output_tensor = m_request.get_tensor("logits"); + cache_position_tensor.set_shape({seq_len}); - return {output_tensor, infer_ms}; + auto cache_data = cache_position_tensor.data(); + std::iota(cache_data, cache_data + seq_len, start_cache_position); }; void WhisperStatefullDecoder::reset_state() { m_request.reset_state(); -} + m_request.set_tensor("cache_position", ov::Tensor{ov::element::i64, {0}}); + + Shape encoder_hidden_states_shape{m_request.get_tensor("encoder_hidden_states").get_shape()}; + encoder_hidden_states_shape[0] = 0; + m_request.set_tensor("encoder_hidden_states", ov::Tensor{ov::element::f32, encoder_hidden_states_shape}); +}; + } // namespace ov::genai diff --git a/src/cpp/src/whisper/models/statefull_decoder.hpp b/src/cpp/src/whisper/models/statefull_decoder.hpp index 6f1c9eb002..c8c733e943 100644 --- a/src/cpp/src/whisper/models/statefull_decoder.hpp +++ b/src/cpp/src/whisper/models/statefull_decoder.hpp @@ -14,15 +14,15 @@ class WhisperStatefullDecoder : public WhisperDecoder { const std::string& device, const ov::AnyMap& properties); - std::pair detect_language(const ov::Tensor& encoder_hidden_state, - const int64_t decoder_start_token_id) override; - - std::pair decode(const ov::Tensor& encoder_hidden_state, - const std::vector& input_ids, - const size_t cache_position) override; + std::pair decode(const Tensor& encoder_hidden_state, + const Tensor& input_ids, + const Tensor& beam_idx) override; void reset_state() override; +private: + void _set_cache_position_tensor(const size_t seq_len); + private: ov::InferRequest m_request; }; diff --git a/src/cpp/src/whisper/models/with_past_decoder.cpp b/src/cpp/src/whisper/models/with_past_decoder.cpp index 7f62ea5657..1ade0dea6b 100644 --- a/src/cpp/src/whisper/models/with_past_decoder.cpp +++ b/src/cpp/src/whisper/models/with_past_decoder.cpp @@ -9,7 +9,32 @@ #include "utils.hpp" namespace { -void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) { + +void copy_with_beam_gather(const ov::Tensor& source, ov::Tensor& dest, const ov::Tensor& beam_idx) { + const size_t dest_batch_size = beam_idx.get_shape().at(0); + + 
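Illustrative note, not part of the patch: copy_with_beam_gather reorders the cached key/value tensors by parent beam. A minimal standalone analogue on plain arrays, with made-up sizes, showing the effect of beam_idx = {2, 2, 0}:

    #include <cstddef>
    #include <cstring>
    #include <vector>

    int main() {
        const size_t batch_dim_size = 4;                        // stands in for heads * seq_len * head_dim
        std::vector<float> source = {0, 0, 0, 0,                // source batch 0
                                     1, 1, 1, 1,                // source batch 1
                                     2, 2, 2, 2};               // source batch 2
        std::vector<int32_t> beam_idx = {2, 2, 0};              // parent beam for every destination batch
        std::vector<float> dest(beam_idx.size() * batch_dim_size);

        for (size_t dest_batch = 0; dest_batch < beam_idx.size(); ++dest_batch) {
            const size_t source_batch = beam_idx[dest_batch];   // whole batches are copied, mirroring the memcpy in copy_with_beam_gather
            std::memcpy(dest.data() + dest_batch * batch_dim_size,
                        source.data() + source_batch * batch_dim_size,
                        sizeof(float) * batch_dim_size);
        }
        // dest == {2,2,2,2, 2,2,2,2, 0,0,0,0}: two beams share parent 2, the third keeps parent 0.
    }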
ov::Shape dest_shape{source.get_shape()}; + dest_shape[0] = dest_batch_size; + dest.set_shape(dest_shape); + + OPENVINO_ASSERT(dest_shape.size() == 4); + + const size_t batch_dim_size = dest_shape[1] * dest_shape[2] * dest_shape[3]; + + const auto beam_idx_data = beam_idx.data(); + const auto source_data = source.data(); + auto dest_data = dest.data(); + + for (size_t dest_batch = 0; dest_batch < dest_batch_size; dest_batch++) { + const size_t source_batch = beam_idx_data[dest_batch]; + + const auto source_start = source_data + (source_batch * batch_dim_size); + const auto dest_start = dest_data + (dest_batch * batch_dim_size); + std::memcpy(dest_start, source_start, sizeof(float) * batch_dim_size); + } +} + +void copy_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, const ov::Tensor& beam_idx) { // source outputs: // present.0.decoder.key // present.0.decoder.value @@ -24,17 +49,33 @@ void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) { for (auto& source_output : source.get_compiled_model().outputs()) { std::string source_output_name = source_output.get_any_name(); - if (source_output_name.find("logits") != std::string::npos) { + if (source_output_name.find("present") == std::string::npos) { + continue; + } + + std::string dest_input_name = std::regex_replace(source_output_name, std::regex("present"), "past_key_values"); + + auto source_tensor = source.get_tensor(source_output_name); + auto dest_tensor = dest.get_tensor(dest_input_name); + + copy_with_beam_gather(source_tensor, dest_tensor, beam_idx); + } +} + +void link_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) { + for (auto& source_output : source.get_compiled_model().outputs()) { + std::string source_output_name = source_output.get_any_name(); + if (source_output_name.find("present") == std::string::npos) { continue; } - std::string with_past_input_name = - std::regex_replace(source_output_name, std::regex("present"), "past_key_values"); + std::string dest_input_name = std::regex_replace(source_output_name, std::regex("present"), "past_key_values"); + auto source_tensor = source.get_tensor(source_output_name); - auto kv_tensor = source.get_tensor(source_output_name); - dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor}); + dest.set_tensor(dest_input_name, source_tensor); } } + } // namespace namespace ov::genai { @@ -56,52 +97,80 @@ WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& mode m_request_decoder_with_past = compiled_model.create_infer_request(); } -std::pair WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state, - const int64_t decoder_start_token_id) { - auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0); - - int64_t output_token = ov::genai::utils::argmax(output_tensor, 0); - - reset_state(); - - return {output_token, infer_ms}; -} - -std::pair WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state, - const std::vector& input_ids, - const size_t cache_position) { - const bool initial_step = cache_position == 0; - ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past; +std::pair WhisperWithPastDecoder::decode(const Tensor& encoder_hidden_state, + const Tensor& input_ids, + const Tensor& beam_idx) { + const bool is_initial_step = m_cache_position == 0; + ov::InferRequest& request = is_initial_step ? 
m_request_decoder : m_request_decoder_with_past; - request.set_tensor("encoder_hidden_states", encoder_hidden_state); + const size_t batch_size = input_ids.get_shape().at(0); + const size_t seq_length = input_ids.get_shape().at(1); - const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data()); - request.set_tensor("input_ids", input_ids_tensor); + _set_encoder_hidden_states_tensor(encoder_hidden_state, batch_size, request); + request.set_tensor("input_ids", input_ids); - if (!initial_step) { + if (!is_initial_step) { ov::Tensor cache_position_tensor = request.get_tensor("cache_position"); cache_position_tensor.set_shape({1}); - cache_position_tensor.data()[0] = cache_position; + cache_position_tensor.data()[0] = m_cache_position; } + _set_past_key_value(beam_idx); + const auto infer_start = std::chrono::steady_clock::now(); request.infer(); const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); auto output_tensor = request.get_tensor("logits"); - if (initial_step) { - set_past_key_value(m_request_decoder, m_request_decoder_with_past); - } else if (!m_decoder_with_past_kv_value_set) { - set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past); - m_decoder_with_past_kv_value_set = true; - } + m_cache_position += seq_length; return {output_tensor, infer_ms}; } +void WhisperWithPastDecoder::_set_past_key_value(const Tensor& beam_idx) { + const bool is_initial_step = m_cache_position == 0; + if (is_initial_step) { + return; + } + + const size_t batch_size = beam_idx.get_shape().at(0); + // no copy needed, just 'link' output tensor with input tensor + const bool can_link_past_key_value = batch_size == 1 && beam_idx.data()[0] == 0; + + if (!m_initial_past_key_value_set) { + if (can_link_past_key_value) { + link_past_key_value(m_request_decoder, m_request_decoder_with_past); + } else { + copy_past_key_value(m_request_decoder, m_request_decoder_with_past, beam_idx); + } + + m_initial_past_key_value_set = true; + return; + } + + if (m_past_key_value_linked) { + return; + } + + if (can_link_past_key_value) { + link_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past); + m_past_key_value_linked = true; + } else { + copy_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past, beam_idx); + } +}; + void WhisperWithPastDecoder::reset_state() { m_request_decoder_with_past.reset_state(); - m_decoder_with_past_kv_value_set = false; + m_cache_position = 0; + m_initial_past_key_value_set = false; + m_past_key_value_linked = false; + + Shape encoder_hidden_states_shape{m_request_decoder_with_past.get_tensor("encoder_hidden_states").get_shape()}; + encoder_hidden_states_shape[0] = 0; + m_request_decoder.set_tensor("encoder_hidden_states", ov::Tensor{ov::element::f32, encoder_hidden_states_shape}); + m_request_decoder_with_past.set_tensor("encoder_hidden_states", + ov::Tensor{ov::element::f32, encoder_hidden_states_shape}); } } // namespace ov::genai diff --git a/src/cpp/src/whisper/models/with_past_decoder.hpp b/src/cpp/src/whisper/models/with_past_decoder.hpp index c7af1cdaa2..1610c60d4e 100644 --- a/src/cpp/src/whisper/models/with_past_decoder.hpp +++ b/src/cpp/src/whisper/models/with_past_decoder.hpp @@ -14,19 +14,20 @@ class WhisperWithPastDecoder : public WhisperDecoder { const std::string& device, const ov::AnyMap& properties); - std::pair detect_language(const ov::Tensor& encoder_hidden_state, - const int64_t decoder_start_token_id) override; - - std::pair 
decode(const ov::Tensor& encoder_hidden_state, - const std::vector& input_ids, - const size_t cache_position) override; + std::pair decode(const Tensor& encoder_hidden_state, + const Tensor& input_ids, + const Tensor& beam_idx) override; void reset_state() override; private: ov::InferRequest m_request_decoder; ov::InferRequest m_request_decoder_with_past; - bool m_decoder_with_past_kv_value_set = false; + size_t m_cache_position = 0; + bool m_initial_past_key_value_set = false; + bool m_past_key_value_linked = false; + + void _set_past_key_value(const Tensor& beam_idx); }; } // namespace ov::genai diff --git a/src/cpp/src/whisper/whisper.cpp b/src/cpp/src/whisper/whisper.cpp index 0f523e3bee..4031149163 100644 --- a/src/cpp/src/whisper/whisper.cpp +++ b/src/cpp/src/whisper/whisper.cpp @@ -14,6 +14,7 @@ #include "openvino/genai/perf_metrics.hpp" #include "openvino/genai/whisper_generation_config.hpp" #include "openvino/genai/whisper_pipeline.hpp" +#include "sampler.hpp" #include "timestamps.hpp" #include "utils.hpp" #include "whisper_config.hpp" @@ -25,6 +26,154 @@ using ov::genai::MicroSeconds; namespace { +void process_whisper_logits(ov::Tensor logits, + const ov::genai::WhisperGenerationConfig& config, + const bool return_timestamps, + const std::map>& batch_to_generated_ids) { + const bool initial_step = batch_to_generated_ids.empty(); + const size_t batch_size = logits.get_shape().at(0); + + for (size_t batch = 0; batch < batch_size; batch++) { + if (initial_step) { + ov::genai::do_suppress_tokens(logits, batch, config.begin_suppress_tokens); + } + + ov::genai::do_suppress_tokens(logits, batch, config.suppress_tokens); + + if (return_timestamps) { + const auto& generated_ids = initial_step ? std::vector{} : batch_to_generated_ids.at(batch); + ov::genai::process_whisper_timestamp_logits(logits, batch, config, generated_ids, initial_step); + } + } +} + +std::pair decode(std::shared_ptr decoder, + const std::vector& input_ids, + const ov::Tensor& encoder_hidden_state, + const std::shared_ptr streamer_ptr, + ov::genai::Sampler& sampler, + ov::genai::SequenceGroup::Ptr sequence_group, + const bool return_timestamps, + const ov::genai::WhisperGenerationConfig& config, + ov::genai::RawPerfMetrics& raw_metrics) { + const auto handle = std::make_shared(sequence_group->get_generation_stream(), + sequence_group->get_sampling_parameters()); + + auto stream_generated_tokens = [&streamer_ptr, &handle, &return_timestamps]() { + if (return_timestamps || !streamer_ptr || !handle->can_read()) { + return; + } + + std::unordered_map token = handle->back(); + for (const auto& gen_token : token.begin()->second.generated_ids) { + if (streamer_ptr->put(gen_token)) { + handle->drop(); + break; + } + } + }; + + const size_t batch_size = 1; + + ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size}); + std::fill_n(beam_idx.data(), batch_size, 0); + + const ov::Tensor input_ids_tensor{ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data()}; + + auto [logits, infer_ms] = decoder->decode(encoder_hidden_state, input_ids_tensor, beam_idx); + + const auto infer_end = std::chrono::steady_clock::now(); + raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); + raw_metrics.m_token_infer_durations.emplace_back(infer_ms); + raw_metrics.m_new_token_times.emplace_back(infer_end); + raw_metrics.m_batch_sizes.emplace_back(batch_size); + + process_whisper_logits(logits, config, return_timestamps, {}); + + // sample last token only + int64_t output_sequence_len = logits.get_shape().at(1); + 
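Illustrative note, not part of the patch: the prompt step below schedules the whole init-token prompt at once and samples only the last logits position; each later iteration schedules one token per running sequence and feeds beam_idx back into the decoder. Generation stops when the sampler hits eos/stop tokens or sequence_group->get_max_new_tokens(), which presumably resolves the generic limits the same way the helper removed from WhisperGenerationConfig further down in this patch did:

    // Resolution mirrored from the removed WhisperGenerationConfig::get_max_new_tokens()
    // (now provided by the GenerationConfig base class); shown for reference only.
    size_t resolve_max_new_tokens(size_t max_new_tokens, size_t max_length, size_t prompt_length) {
        // max_new_tokens has priority over max_length
        return max_new_tokens != SIZE_MAX ? max_new_tokens : max_length - prompt_length;
    }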
sequence_group->schedule_tokens(sequence_group->get_prompt_len()); + sequence_group->set_output_seq_len(output_sequence_len); + + sampler.sample({sequence_group}, logits); + stream_generated_tokens(); + + // "Generation" phase + while (!sequence_group->has_finished() && !sequence_group->handle_dropped()) { + std::map> batch_to_generated_ids{}; + + sequence_group->schedule_tokens(1); + // compute aggregated values + size_t num_sequences = sequence_group->num_running_seqs(); + size_t total_num_tokens = sequence_group->get_num_scheduled_tokens() * num_sequences; + + ov::Tensor new_input_ids(ov::element::i64, {total_num_tokens, 1}); + int64_t* input_ids_data = new_input_ids.data(); + + std::vector next_beams; + + std::vector running_sequences = sequence_group->get_running_sequences(); + size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens(); + size_t num_processed_tokens = sequence_group->get_num_processed_tokens(); + + std::map beam_idxs = sampler.get_beam_idxs(sequence_group); + + for (auto sequence : running_sequences) { + for (size_t batch = 0, position_id = num_processed_tokens; batch < num_scheduled_tokens; + ++batch, ++position_id) { + // compute token for current sequence + if (position_id < sequence_group->get_prompt_len()) { + input_ids_data[batch] = sequence_group->get_prompt_ids()[position_id]; + } else { + input_ids_data[batch] = + sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()]; + } + } + + // apply strides to shift to a next sequence + input_ids_data += num_scheduled_tokens; + + auto beam_idx = beam_idxs[sequence->get_id()]; + next_beams.push_back(beam_idx); + batch_to_generated_ids[next_beams.size() - 1] = sequence->get_generated_ids(); + } + + auto [logits, infer_ms] = decoder->decode(encoder_hidden_state, + new_input_ids, + ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()}); + + const auto infer_end = std::chrono::steady_clock::now(); + raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); + raw_metrics.m_token_infer_durations.emplace_back(infer_ms); + raw_metrics.m_new_token_times.emplace_back(infer_end); + raw_metrics.m_batch_sizes.emplace_back(batch_size); + + process_whisper_logits(logits, config, return_timestamps, batch_to_generated_ids); + + sampler.sample({sequence_group}, logits); + stream_generated_tokens(); + } + + ov::genai::EncodedResults results; + + const auto sampling_params = sequence_group->get_sampling_parameters(); + + // there is also check in generation config validate function + OPENVINO_ASSERT(config.num_return_sequences == 1); + const auto& sequences = sequence_group->get_finished_sequences(); + const auto& sequence = sequences[0]; + + const float score = sampling_params.is_beam_search() ? 
sequence->get_beam_search_score(sampling_params) + : sequence->get_cumulative_log_prob(); + + results.tokens.push_back(sequence->get_generated_ids()); + results.scores.push_back(score); + + sampler.clear_request_info(sequence_group->get_request_id()); + + return {results, sequence_group->handle_dropped()}; +} + ov::Tensor encode(ov::InferRequest& request, std::vector& mel_data, const size_t feature_size, @@ -54,41 +203,6 @@ ov::Tensor encode(ov::InferRequest& request, return request.get_tensor("last_hidden_state"); } -int64_t decode(ov::Tensor& encoder_hidden_state, - std::shared_ptr decoder, - const std::vector& input_ids, - const size_t cache_position, - const ov::genai::WhisperGenerationConfig& config, - ov::genai::RawPerfMetrics& raw_metrics, - const bool return_timestamps, - const bool initial_step, - const std::vector& generated_tokens) { - auto [output_tensor, infer_ms] = decoder->decode(encoder_hidden_state, input_ids, cache_position); - const auto infer_end = std::chrono::steady_clock::now(); - raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); - raw_metrics.m_token_infer_durations.emplace_back(infer_ms); - raw_metrics.m_new_token_times.emplace_back(infer_end); - raw_metrics.m_batch_sizes.emplace_back(1); - - if (initial_step) { - ov::genai::do_suppress_tokens(output_tensor, 0, config.begin_suppress_tokens); - } - - ov::genai::do_suppress_tokens(output_tensor, 0, config.suppress_tokens); - - if (return_timestamps) { - if (initial_step) { - ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, {}, true); - } else { - ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, generated_tokens); - } - } - - int64_t output_token = ov::genai::utils::argmax(output_tensor, 0); - - return output_token; -} - std::vector prepare_init_tokens(ov::Tensor& encoder_hidden_state, std::shared_ptr decoder, const ov::genai::WhisperGenerationConfig& config, @@ -129,52 +243,6 @@ std::vector prepare_init_tokens(ov::Tensor& encoder_hidden_state, config.no_timestamps_token_id}; } -std::pair> full_decode(ov::Tensor& encoder_hidden_state, - const ov::genai::WhisperGenerationConfig& config, - std::shared_ptr decoder, - const std::vector& init_tokens, - const size_t max_new_tokens, - const bool return_timestamps, - ov::genai::RawPerfMetrics& raw_metrics, - const std::shared_ptr streamer) { - int64_t output_token = - decode(encoder_hidden_state, decoder, init_tokens, 0, config, raw_metrics, return_timestamps, true, {}); - - std::vector output_tokens{output_token}; - - if (!return_timestamps && streamer && streamer->put(output_token)) { - return {true, output_tokens}; - } - - if (max_new_tokens == 1) { - return {false, output_tokens}; - } - - for (size_t i = 0; i < max_new_tokens - 1; i++) { - auto output_token = decode(encoder_hidden_state, - decoder, - {output_tokens.back()}, - init_tokens.size() + i, - config, - raw_metrics, - return_timestamps, - false, - output_tokens); - - if (output_token == config.eos_token_id) { - break; - } - - output_tokens.push_back(output_token); - - if (!return_timestamps && streamer && streamer->put(output_token)) { - return {true, output_tokens}; - } - } - - return {false, output_tokens}; -} - } // namespace namespace ov { @@ -187,7 +255,8 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& ov::InferRequest& encoder, std::shared_ptr decoder, WhisperFeatureExtractor& feature_extractor, - const std::shared_ptr streamer) { + const std::shared_ptr streamer, + Sampler& sampler) { size_t max_new_tokens = 
config.get_max_new_tokens(); WhisperGenerateResult result; @@ -216,10 +285,6 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& size_t segment_offset = 0; for (size_t chunk_offset = 0; chunk_offset < input_features.n_frames; chunk_offset += segment_offset) { - if (output_tokens.size() >= max_new_tokens) { - break; - } - auto input_features_chunk = input_features.get_data_with_offset(chunk_offset, feature_extractor.nb_max_frames); ov::Tensor hidden_state_tensor = encode(encoder, @@ -236,16 +301,19 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& std::vector chunk_init_tokens = ov::genai::get_prompt_tokens(context_tokens, config, chunk_offset); chunk_init_tokens.insert(chunk_init_tokens.end(), init_tokens.begin(), init_tokens.end()); - auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor, - config, - decoder, - chunk_init_tokens, - max_new_tokens - output_tokens.size(), - return_timestamps, - raw_metrics, - streamer); - + SequenceGroup::Ptr sequence_group = std::make_shared(0, chunk_init_tokens, config, 1); + + auto [result, cancelled] = decode(decoder, + chunk_init_tokens, + hidden_state_tensor, + streamer, + sampler, + sequence_group, + return_timestamps, + config, + raw_metrics); decoder->reset_state(); + std::vector chunk_output_tokens = result.tokens[0]; if (return_timestamps) { auto extracted_segments = ov::genai::extract_segments(chunk_output_tokens, diff --git a/src/cpp/src/whisper/whisper.hpp b/src/cpp/src/whisper/whisper.hpp index aed6487160..88957119e8 100644 --- a/src/cpp/src/whisper/whisper.hpp +++ b/src/cpp/src/whisper/whisper.hpp @@ -9,6 +9,7 @@ #include "models/decoder.hpp" #include "openvino/genai/whisper_generation_config.hpp" #include "openvino/genai/whisper_pipeline.hpp" +#include "sampler.hpp" #include "whisper_config.hpp" #include "whisper_feature_extractor.hpp" #include "whisper_models.hpp" @@ -35,7 +36,8 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& ov::InferRequest& encoder, std::shared_ptr decoder, WhisperFeatureExtractor& feature_extractor, - const std::shared_ptr streamer); + const std::shared_ptr streamer, + Sampler& sampler); } // namespace genai } // namespace ov diff --git a/src/cpp/src/whisper/whisper_utils.cpp b/src/cpp/src/whisper/whisper_utils.cpp index 6e56a1439d..f41d3d11d8 100644 --- a/src/cpp/src/whisper/whisper_utils.cpp +++ b/src/cpp/src/whisper/whisper_utils.cpp @@ -41,6 +41,22 @@ void filter_non_segment_metrics(ov::genai::RawPerfMetrics& raw_metrics, filter_by_ranges(raw_metrics.m_batch_sizes, offset, ranges); } +int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) { + if (logits.get_shape()[0] <= batch_idx) { + OPENVINO_THROW("logits batch size doesn't match the number of beams"); + } + + size_t vocab_size = logits.get_shape().back(); + size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size; + size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size; + const float* logits_data = logits.data() + batch_offset + sequence_offset; + + int64_t out_token = std::max_element(logits_data, logits_data + vocab_size) - logits_data; + float max_logit = logits_data[out_token]; + + return out_token; +} + } // namespace utils } // namespace genai } // namespace ov diff --git a/src/cpp/src/whisper/whisper_utils.hpp b/src/cpp/src/whisper/whisper_utils.hpp index 234feed6a8..8fd0a080c6 100644 --- a/src/cpp/src/whisper/whisper_utils.hpp +++ b/src/cpp/src/whisper/whisper_utils.hpp @@ -17,6 +17,8 @@ void 
filter_non_segment_metrics(ov::genai::RawPerfMetrics& raw_metrics, size_t offset, std::vector>& ranges); +int64_t argmax(const ov::Tensor& logits, const size_t batch_idx); + } // namespace utils } // namespace genai } // namespace ov diff --git a/src/cpp/src/whisper_generation_config.cpp b/src/cpp/src/whisper_generation_config.cpp index 1cc79ab0e6..ec12170cf9 100644 --- a/src/cpp/src/whisper_generation_config.cpp +++ b/src/cpp/src/whisper_generation_config.cpp @@ -14,7 +14,8 @@ namespace ov { namespace genai { -WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& json_path) { +WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& json_path) + : GenerationConfig::GenerationConfig(json_path) { using ov::genai::utils::read_json_param; std::ifstream f(json_path); @@ -22,12 +23,9 @@ WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& js nlohmann::json data = nlohmann::json::parse(f); - read_json_param(data, "max_new_tokens", max_new_tokens); - read_json_param(data, "max_length", max_length); read_json_param(data, "begin_suppress_tokens", begin_suppress_tokens); read_json_param(data, "suppress_tokens", suppress_tokens); read_json_param(data, "decoder_start_token_id", decoder_start_token_id); - read_json_param(data, "eos_token_id", eos_token_id); read_json_param(data, "pad_token_id", pad_token_id); read_json_param(data, "no_timestamps_token_id", no_timestamps_token_id); read_json_param(data, "max_initial_timestamp_index", max_initial_timestamp_index); @@ -42,28 +40,12 @@ WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& js read_json_param(data, "lang_to_id", lang_to_id); } -void WhisperGenerationConfig::set_eos_token_id(int64_t tokenizer_eos_token_id) { - if (eos_token_id < 0) { - eos_token_id = tokenizer_eos_token_id; - } else { - OPENVINO_ASSERT(eos_token_id == tokenizer_eos_token_id, - "EOS token ID is different in generation config (", - eos_token_id, - ") and tokenizer (", - tokenizer_eos_token_id, - ")"); - } -} - void WhisperGenerationConfig::update_generation_config(const ov::AnyMap& config_map) { using ov::genai::utils::read_anymap_param; - read_anymap_param(config_map, "max_new_tokens", max_new_tokens); - read_anymap_param(config_map, "max_length", max_length); read_anymap_param(config_map, "begin_suppress_tokens", begin_suppress_tokens); read_anymap_param(config_map, "suppress_tokens", suppress_tokens); read_anymap_param(config_map, "decoder_start_token_id", decoder_start_token_id); - read_anymap_param(config_map, "eos_token_id", eos_token_id); read_anymap_param(config_map, "pad_token_id", pad_token_id); read_anymap_param(config_map, "transcribe_token_id", transcribe_token_id); read_anymap_param(config_map, "translate_token_id", translate_token_id); @@ -76,27 +58,12 @@ void WhisperGenerationConfig::update_generation_config(const ov::AnyMap& config_ read_anymap_param(config_map, "return_timestamps", return_timestamps); read_anymap_param(config_map, "initial_prompt", initial_prompt); read_anymap_param(config_map, "hotwords", hotwords); -} -size_t WhisperGenerationConfig::get_max_new_tokens(size_t prompt_length) const { - // max_new_tokens has priority over max_length, only if max_new_tokens was not specified use max_length - if (max_new_tokens != SIZE_MAX) { - return max_new_tokens; - } else { - return max_length - prompt_length; - } + GenerationConfig::update_generation_config(config_map); } void WhisperGenerationConfig::validate() const { - OPENVINO_ASSERT(max_new_tokens > 0, 
"'max_new_tokens' must be greater than 0"); - - // max_new_tokens has priority over max_length - // if max_new_tokens is defined no need to check max_length - OPENVINO_ASSERT(max_new_tokens != SIZE_MAX || max_length > 0, - "'max_length' must be greater than 0 or 'max_new_tokens' should be defined"); - - OPENVINO_ASSERT(eos_token_id != -1 || max_new_tokens != SIZE_MAX || max_length != SIZE_MAX, - "Either 'eos_token_id', or 'max_new_tokens', or 'max_length' should be defined."); + GenerationConfig::validate(); if (is_multilingual && language.has_value()) { OPENVINO_ASSERT(lang_to_id.count(*language), @@ -114,6 +81,13 @@ void WhisperGenerationConfig::validate() const { OPENVINO_ASSERT(!language.has_value(), "Cannot specify 'language' for not multilingual model."); OPENVINO_ASSERT(!task.has_value(), "Cannot specify 'task' for not multilingual model."); } + + OPENVINO_ASSERT(num_return_sequences == 1, + "'num_return_sequences' must be 1. Provided: ", + num_return_sequences, + "."); + + OPENVINO_ASSERT(!is_assisting_generation(), "Assisted generation is not supported."); } } // namespace genai } // namespace ov diff --git a/src/cpp/src/whisper_pipeline.cpp b/src/cpp/src/whisper_pipeline.cpp index 70e3950536..19a1a4831d 100644 --- a/src/cpp/src/whisper_pipeline.cpp +++ b/src/cpp/src/whisper_pipeline.cpp @@ -51,7 +51,8 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi WhisperPipelineStatefulImpl(const std::filesystem::path& models_path, const std::string& device, const ov::AnyMap& properties) - : WhisperPipelineImplBase{models_path} { + : WhisperPipelineImplBase{models_path}, + m_sampler(m_tokenizer) { ov::Core core = utils::singleton_core(); ov::CompiledModel compiled_model = @@ -65,6 +66,8 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi if (m_generation_config.eos_token_id == -1) { m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id()); } + + m_sampler.set_seed(m_generation_config.rng_seed); } WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, @@ -96,7 +99,8 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi m_encoder, m_decoder, m_feature_extractor, - streamer_ptr); + streamer_ptr, + m_sampler); auto decode_start_time = std::chrono::steady_clock::now(); WhisperDecodedResults result{std::vector{m_tokenizer.decode(generate_result.output_tokens)}, std::vector{1.f}}; generate_result.perf_metrics.raw_metrics.detokenization_durations.emplace_back( @@ -135,6 +139,7 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi private: ov::InferRequest m_encoder; std::shared_ptr m_decoder; + Sampler m_sampler; }; std::pair streamer(ChunkStreamerVariant func) { diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index cf80e973f4..f1898d1232 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -354,11 +354,11 @@ class ChunkStreamerBase: """ End is called at the end of generation. It can be used to flush cache if your own streamer has one """ - def put(self, arg0: int) -> bool: + def put(self, token: int) -> bool: """ Put is called every time new token is generated. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops """ - def put_chunk(self, arg0: list[int]) -> bool: + def put_chunk(self, tokens: list[int]) -> bool: """ Put is called every time new token chunk is generated. 
Returns a bool flag to indicate whether generation should be stopped, if return true generation stops """ @@ -1953,22 +1953,12 @@ class WhisperDecodedResults: @property def texts(self) -> list[str]: ... -class WhisperGenerationConfig: +class WhisperGenerationConfig(GenerationConfig): """ WhisperGenerationConfig - :param max_length: the maximum length the generated tokens can have. Corresponds to the length of the input prompt + - `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. - :type max_length: int - - :param max_new_tokens: the maximum numbers of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. - :type max_new_tokens: int - - :param eos_token_id: End of stream token id. - :type eos_token_id: int - + Whisper specific parameters: - :param decoder_start_token_id: Corresponds to the ”<|startoftranscript|>” token. :type decoder_start_token_id: int @@ -2037,18 +2027,55 @@ class WhisperGenerationConfig: auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); // He has gone and gone for good answered Polychrome who... :type hotwords: Optional[str] + + Generic parameters: + max_length: the maximum length the generated tokens can have. Corresponds to the length of the input prompt + + max_new_tokens. Its effect is overridden by `max_new_tokens`, if also set. + max_new_tokens: the maximum numbers of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. + min_new_tokens: set 0 probability for eos_token_id for the first eos_token_id generated tokens. + ignore_eos: if set to true, then generation will not stop even if token is met. + eos_token_id: token_id of (end of sentence) + stop_strings: a set of strings that will cause pipeline to stop generating further tokens. + include_stop_str_in_output: if set to true stop string that matched generation will be included in generation output (default: false) + stop_token_ids: a set of tokens that will cause pipeline to stop generating further tokens. + echo: if set to true, the model will echo the prompt in the output. + logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned. + Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0). + + repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty. + presence_penalty: reduces absolute log prob if the token was generated at least once. + frequency_penalty: reduces absolute log prob as many times as the token was generated. + + Beam search specific parameters: + num_beams: number of beams for beam search. 1 disables beam search. + num_beam_groups: number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + diversity_penalty: value is subtracted from a beam's score if it generates the same token as any beam from other group at a particular time. + length_penalty: exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to + the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log + likelihood of the sequence (i.e. negative), length_penalty > 0.0 promotes longer sequences, while + length_penalty < 0.0 encourages shorter sequences. + num_return_sequences: the number of sequences to return for grouped beam search decoding. 
+ no_repeat_ngram_size: if set to int > 0, all ngrams of that size can only occur once. + stop_criteria: controls the stopping condition for grouped beam search. It accepts the following values: + "openvino_genai.StopCriteria.EARLY", where the generation stops as soon as there are `num_beams` complete candidates; + "openvino_genai.StopCriteria.HEURISTIC" is applied and the generation stops when is it very unlikely to find better candidates; + "openvino_genai.StopCriteria.NEVER", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). + + Random sampling parameters: + temperature: the value used to modulate token probabilities for random sampling. + top_p: if set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering. + do_sample: whether or not to use multinomial random sampling that add up to `top_p` or higher are kept. + num_return_sequences: the number of sequences to generate from a single prompt. """ begin_suppress_tokens: list[int] decoder_start_token_id: int - eos_token_id: int hotwords: str | None initial_prompt: str | None is_multilingual: bool lang_to_id: dict[str, int] language: str | None max_initial_timestamp_index: int - max_length: int - max_new_tokens: int no_timestamps_token_id: int pad_token_id: int prev_sot_token_id: int @@ -2065,8 +2092,6 @@ class WhisperGenerationConfig: @typing.overload def __init__(self, **kwargs) -> None: ... - def set_eos_token_id(self, tokenizer_eos_token_id: int) -> None: - ... def update_generation_config(self, **kwargs) -> None: ... class WhisperPerfMetrics(PerfMetrics): @@ -2119,18 +2144,8 @@ class WhisperPipeline: WhisperGenerationConfig - :param max_length: the maximum length the generated tokens can have. Corresponds to the length of the input prompt + - `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. - :type max_length: int - - :param max_new_tokens: the maximum numbers of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. - :type max_new_tokens: int - - :param eos_token_id: End of stream token id. - :type eos_token_id: int - + Whisper specific parameters: - :param decoder_start_token_id: Corresponds to the ”<|startoftranscript|>” token. :type decoder_start_token_id: int @@ -2199,6 +2214,46 @@ class WhisperPipeline: auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); // He has gone and gone for good answered Polychrome who... :type hotwords: Optional[str] + + Generic parameters: + max_length: the maximum length the generated tokens can have. Corresponds to the length of the input prompt + + max_new_tokens. Its effect is overridden by `max_new_tokens`, if also set. + max_new_tokens: the maximum numbers of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. + min_new_tokens: set 0 probability for eos_token_id for the first eos_token_id generated tokens. + ignore_eos: if set to true, then generation will not stop even if token is met. + eos_token_id: token_id of (end of sentence) + stop_strings: a set of strings that will cause pipeline to stop generating further tokens. 
+ include_stop_str_in_output: if set to true stop string that matched generation will be included in generation output (default: false) + stop_token_ids: a set of tokens that will cause pipeline to stop generating further tokens. + echo: if set to true, the model will echo the prompt in the output. + logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned. + Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0). + + repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty. + presence_penalty: reduces absolute log prob if the token was generated at least once. + frequency_penalty: reduces absolute log prob as many times as the token was generated. + + Beam search specific parameters: + num_beams: number of beams for beam search. 1 disables beam search. + num_beam_groups: number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + diversity_penalty: value is subtracted from a beam's score if it generates the same token as any beam from other group at a particular time. + length_penalty: exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to + the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log + likelihood of the sequence (i.e. negative), length_penalty > 0.0 promotes longer sequences, while + length_penalty < 0.0 encourages shorter sequences. + num_return_sequences: the number of sequences to return for grouped beam search decoding. + no_repeat_ngram_size: if set to int > 0, all ngrams of that size can only occur once. + stop_criteria: controls the stopping condition for grouped beam search. It accepts the following values: + "openvino_genai.StopCriteria.EARLY", where the generation stops as soon as there are `num_beams` complete candidates; + "openvino_genai.StopCriteria.HEURISTIC" is applied and the generation stops when is it very unlikely to find better candidates; + "openvino_genai.StopCriteria.NEVER", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). + + Random sampling parameters: + temperature: the value used to modulate token probabilities for random sampling. + top_p: if set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering. + do_sample: whether or not to use multinomial random sampling that add up to `top_p` or higher are kept. + num_return_sequences: the number of sequences to generate from a single prompt. """ def get_generation_config(self) -> WhisperGenerationConfig: ... 
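Since WhisperGenerationConfig now derives from GenerationConfig, the generic parameters documented above can be set directly on the whisper config. A hedged C++ usage sketch; the model directory and the audio samples are placeholders:

    #include "openvino/genai/whisper_pipeline.hpp"

    int main() {
        ov::genai::WhisperPipeline pipeline("whisper-base", "CPU");   // placeholder model dir

        ov::genai::WhisperGenerationConfig config = pipeline.get_generation_config();
        config.max_new_tokens = 100;       // generic parameter, inherited from GenerationConfig
        config.num_beams = 4;              // beam search settings also come from the base config
        config.return_timestamps = true;   // whisper-specific parameter, unchanged

        ov::genai::RawSpeechInput raw_speech = /* 16 kHz PCM samples */ {};
        auto result = pipeline.generate(raw_speech, config);
    }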
diff --git a/src/python/py_whisper_pipeline.cpp b/src/python/py_whisper_pipeline.cpp index 55728409e8..aac14c258a 100644 --- a/src/python/py_whisper_pipeline.cpp +++ b/src/python/py_whisper_pipeline.cpp @@ -17,6 +17,7 @@ namespace py = pybind11; using ov::genai::ChunkStreamerBase; using ov::genai::ChunkStreamerVariant; using ov::genai::DecodedResults; +using ov::genai::GenerationConfig; using ov::genai::OptionalWhisperGenerationConfig; using ov::genai::PerfMetrics; using ov::genai::RawSpeechInput; @@ -76,18 +77,8 @@ auto whisper_decoded_result_chunk = R"( auto whisper_generation_config_docstring = R"( WhisperGenerationConfig - :param max_length: the maximum length the generated tokens can have. Corresponds to the length of the input prompt + - `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. - :type max_length: int - - :param max_new_tokens: the maximum numbers of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. - :type max_new_tokens: int - - :param eos_token_id: End of stream token id. - :type eos_token_id: int - + Whisper specific parameters: - :param decoder_start_token_id: Corresponds to the ”<|startoftranscript|>” token. :type decoder_start_token_id: int @@ -156,6 +147,46 @@ auto whisper_generation_config_docstring = R"( auto result = pipeline.generate(raw_speech, ov::genai::hotwords("Polychrome")); // He has gone and gone for good answered Polychrome who... :type hotwords: Optional[str] + + Generic parameters: + max_length: the maximum length the generated tokens can have. Corresponds to the length of the input prompt + + max_new_tokens. Its effect is overridden by `max_new_tokens`, if also set. + max_new_tokens: the maximum number of tokens to generate, excluding the number of tokens in the prompt. max_new_tokens has priority over max_length. + min_new_tokens: set 0 probability for eos_token_id for the first min_new_tokens generated tokens. + ignore_eos: if set to true, then generation will not stop even if <eos> token is met. + eos_token_id: token_id of <eos> (end of sentence) + stop_strings: a set of strings that will cause pipeline to stop generating further tokens. + include_stop_str_in_output: if set to true, the stop string that matched generation will be included in the generation output (default: false) + stop_token_ids: a set of tokens that will cause pipeline to stop generating further tokens. + echo: if set to true, the model will echo the prompt in the output. + logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned. + Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0). + + repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty. + presence_penalty: reduces absolute log prob if the token was generated at least once. + frequency_penalty: reduces absolute log prob as many times as the token was generated. + + Beam search specific parameters: + num_beams: number of beams for beam search. 1 disables beam search. + num_beam_groups: number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + diversity_penalty: value is subtracted from a beam's score if it generates the same token as any beam from other group at a particular time. + length_penalty: exponential penalty to the length that is used with beam-based generation.
It is applied as an exponent to + the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log + likelihood of the sequence (i.e. negative), length_penalty > 0.0 promotes longer sequences, while + length_penalty < 0.0 encourages shorter sequences. + num_return_sequences: the number of sequences to return for grouped beam search decoding. + no_repeat_ngram_size: if set to int > 0, all ngrams of that size can only occur once. + stop_criteria: controls the stopping condition for grouped beam search. It accepts the following values: + "openvino_genai.StopCriteria.EARLY", where the generation stops as soon as there are `num_beams` complete candidates; + "openvino_genai.StopCriteria.HEURISTIC", where the generation stops when it is very unlikely to find better candidates; + "openvino_genai.StopCriteria.NEVER", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). + + Random sampling parameters: + temperature: the value used to modulate token probabilities for random sampling. + top_p: if set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering. + do_sample: whether or not to use multinomial random sampling for token selection. + num_return_sequences: the number of sequences to generate from a single prompt. )"; auto streamer_base_docstring = R"( @@ -264,27 +295,28 @@ void init_whisper_pipeline(py::module_& m) { .def("put", &ChunkStreamerBase::put, "Put is called every time new token is generated. Returns a bool flag to indicate whether generation " - "should be stopped, if return true generation stops") + "should be stopped, if return true generation stops", + py::arg("token")) .def("put_chunk", &ChunkStreamerBase::put_chunk, "Put is called every time new token chunk is generated. Returns a bool flag to indicate whether " - "generation should be stopped, if return true generation stops") + "generation should be stopped, if return true generation stops", + py::arg("tokens")) .def("end", &ChunkStreamerBase::end, "End is called at the end of generation.
It can be used to flush cache if your own streamer has one"); // Binding for WhisperGenerationConfig - py::class_<WhisperGenerationConfig>(m, "WhisperGenerationConfig", whisper_generation_config_docstring) + py::class_<WhisperGenerationConfig, GenerationConfig>(m, + "WhisperGenerationConfig", + whisper_generation_config_docstring) .def(py::init<std::filesystem::path>(), py::arg("json_path"), "path where generation_config.json is stored") .def(py::init([](const py::kwargs& kwargs) { return *update_whisper_config_from_kwargs(WhisperGenerationConfig(), kwargs); })) - .def_readwrite("max_new_tokens", &WhisperGenerationConfig::max_new_tokens) - .def_readwrite("max_length", &WhisperGenerationConfig::max_length) .def_readwrite("begin_suppress_tokens", &WhisperGenerationConfig::begin_suppress_tokens) .def_readwrite("suppress_tokens", &WhisperGenerationConfig::suppress_tokens) .def_readwrite("decoder_start_token_id", &WhisperGenerationConfig::decoder_start_token_id) - .def_readwrite("eos_token_id", &WhisperGenerationConfig::eos_token_id) .def_readwrite("pad_token_id", &WhisperGenerationConfig::pad_token_id) .def_readwrite("translate_token_id", &WhisperGenerationConfig::translate_token_id) .def_readwrite("transcribe_token_id", &WhisperGenerationConfig::transcribe_token_id) @@ -298,12 +330,9 @@ void init_whisper_pipeline(py::module_& m) { .def_readwrite("return_timestamps", &WhisperGenerationConfig::return_timestamps) .def_readwrite("initial_prompt", &WhisperGenerationConfig::initial_prompt) .def_readwrite("hotwords", &WhisperGenerationConfig::hotwords) - .def("set_eos_token_id", &WhisperGenerationConfig::set_eos_token_id, py::arg("tokenizer_eos_token_id")) - .def("update_generation_config", []( - ov::genai::WhisperGenerationConfig& config, - const py::kwargs& kwargs) { + .def("update_generation_config", [](ov::genai::WhisperGenerationConfig& config, const py::kwargs& kwargs) { config.update_generation_config(pyutils::kwargs_to_any_map(kwargs)); - });; + }); py::class_<WhisperRawPerfMetrics>(m, "WhisperRawPerfMetrics", raw_perf_metrics_docstring) .def(py::init<>()) diff --git a/tests/python_tests/test_whisper_pipeline.py b/tests/python_tests/test_whisper_pipeline.py index c3a13b6b87..683693857e 100644 --- a/tests/python_tests/test_whisper_pipeline.py +++ b/tests/python_tests/test_whisper_pipeline.py @@ -127,9 +127,14 @@ def run_huggingface( return pipeline( sample, - max_new_tokens=min(config.max_new_tokens, 444), return_timestamps=config.return_timestamps, - generate_kwargs={"language": config.language, "task": config.task}, + generate_kwargs={ + "language": config.language, + "task": config.task, + "max_new_tokens": min(config.max_new_tokens, 444), + "top_p": config.top_p, + "do_sample": config.do_sample, + }, ) @@ -148,6 +153,8 @@ def run_genai( genai_config.return_timestamps = config.return_timestamps genai_config.task = config.task genai_config.language = f"<|{config.language}|>" if config.language else None + genai_config.do_sample = config.do_sample + genai_config.top_p = config.top_p return pipeline.generate(sample, genai_config, streamer=streamer) @@ -546,6 +553,45 @@ def test_initial_prompt_hotwords(model_descr, sample_from_dataset): assert "Joel Kyton" in result.texts[0] +@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True)) +@pytest.mark.parametrize("sample_from_dataset", [{"language" : "en", "sample_id": 0}], indirect=True) +@pytest.mark.precommit +def test_random_sampling(model_descr, sample_from_dataset): + _, _, hf_pipe, genai_pipe = read_whisper_model(model_descr) + + config = ov_genai.WhisperGenerationConfig(do_sample=True, top_p=0.01) + + genai_result = run_genai(
genai_pipe, + sample_from_dataset, + config=config, + ) + + hf_result = run_huggingface( + hf_pipe, + sample_from_dataset, + config=config, + ) + + compare_results(hf_result, genai_result) + + config.top_p = 0.6 + + genai_result = run_genai( + genai_pipe, + sample_from_dataset, + config=config, + ) + + hf_result = run_huggingface( + hf_pipe, + sample_from_dataset, + config=config, + ) + + assert genai_result.texts[0] != hf_result["text"] + + @pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True)) @pytest.mark.parametrize("sample_from_dataset", [{"language" : "en", "sample_id": 0}], indirect=True) @pytest.mark.precommit