Fix in indices calculation. Apply comments
olpipi committed Jan 2, 2025
1 parent 0d172ef commit 0fef019
Showing 4 changed files with 20 additions and 17 deletions.
10 changes: 5 additions & 5 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -455,25 +455,25 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

size_t num_running_sequences = sequence_group->num_running_seqs();
OPENVINO_ASSERT(num_running_sequences == 1);
size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
size_t output_seq_len = sequence_group->get_output_seq_len();

const float * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;

size_t num_prompt_tokens_processed = sequence_group->get_num_processed_tokens();
OPENVINO_ASSERT(num_prompt_tokens_processed + actual_seq_len <= sequence_group->get_prompt_len());
OPENVINO_ASSERT(num_prompt_tokens_processed + output_seq_len <= sequence_group->get_prompt_len());

// if we processed the whole prompt we don't include last logprob as it will be processed by the sampler (it's already completion)
// otherwise we include it as it will be used in the next part of the prompt
int exclude_last_logprob = 1;
if (num_prompt_tokens_processed + actual_seq_len < sequence_group->get_prompt_len())
if (num_prompt_tokens_processed + output_seq_len < sequence_group->get_prompt_len())
exclude_last_logprob = 0;

// if we start processing the prompt we add "fake" log prob for the first position (begin of sequence)
if (num_prompt_tokens_processed == 0)
sequence_group->append_prompt_log_prob(1.0);

for (int token_logits_offset = 0, token_id_offset = num_prompt_tokens_processed + 1;
token_logits_offset < actual_seq_len - exclude_last_logprob;
token_logits_offset < output_seq_len - exclude_last_logprob;
token_logits_offset++, token_id_offset++) {

const float* token_logits = (sequence_group_logits_data + token_logits_offset * vocab_size);
@@ -498,7 +498,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

sequence_group->append_prompt_log_prob(token_logit - max_value - log_sum);
}
currently_processed_tokens += actual_seq_len * num_running_sequences;
currently_processed_tokens += output_seq_len * num_running_sequences;
// For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
sequence_group->notify_handle_echo_only();
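For context (not part of the commit): the hunks above make the number of prompt log-probs taken from a chunk follow the renamed getter. A minimal, self-contained C++ sketch of that bound, using a hypothetical helper name since the real logic lives inline in _fill_prompt_log_probs:

#include <cstddef>

// Hypothetical helper mirroring the loop bound above: how many prompt
// log-probs the current prompt chunk contributes.
size_t prompt_log_probs_in_chunk(size_t num_prompt_tokens_processed,
                                 size_t output_seq_len,
                                 size_t prompt_len) {
    // Drop the last logit only when this chunk finishes the prompt; that logit
    // already belongs to the first completion token and is left to the sampler.
    int exclude_last_logprob = 1;
    if (num_prompt_tokens_processed + output_seq_len < prompt_len)
        exclude_last_logprob = 0;
    return output_seq_len - exclude_last_logprob;
}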
12 changes: 9 additions & 3 deletions src/cpp/src/model_runner.hpp
@@ -139,6 +139,7 @@ class ModelRunner {
const size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();

for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
actual_seq_len = 0;
Sequence::CPtr sequence = running_sequences[seq_id];
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
// compute token for current sequence
@@ -148,10 +149,15 @@

position_ids_data[token_id] = position_id;

if (matmul_gathering_is_required && sampling_is_required) {
// Check if token gathering is required for the entire sequence group
if (matmul_gathering_is_required && (sampling_is_required || echo_output)) {
// Determine if the current token should be gathered
if (echo_output ||
// Skip gathering for prompt tokens
group_position_id + token_id >= prompt_len - 1 &&
group_position_id + token_id >= num_scheduled_tokens - tokens_to_sample_per_sequence) {
// Gather only the last scheduled token or 1 + num_tokens_to_validate tokens for SD
// In SD, tokens_to_sample_per_sequence may exceed num_scheduled_tokens
token_id + tokens_to_sample_per_sequence >= num_scheduled_tokens) {
gather_indice_values.push_back(gathering_current_index);
actual_seq_len++;
}
@@ -173,7 +179,7 @@
subsequence_begins_data += 1;
block_indices_begins_data += 1;
}
sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(actual_seq_len, num_scheduled_tokens) : num_scheduled_tokens);
sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? actual_seq_len : num_scheduled_tokens);
}

// typical LLM parameters
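The indices fix itself is the rewritten inner condition: with size_t arithmetic, num_scheduled_tokens - tokens_to_sample_per_sequence underflows whenever speculative decoding needs to validate more tokens than are scheduled in this step, so the comparison now uses an addition instead. A small standalone sketch with made-up values (the group_position_id term is dropped for brevity):

#include <cstddef>
#include <iostream>

int main() {
    // Hypothetical speculative-decoding step: 2 tokens scheduled, but
    // 1 + num_tokens_to_validate = 5 logit rows are wanted.
    size_t num_scheduled_tokens = 2;
    size_t tokens_to_sample_per_sequence = 5;
    size_t token_id = 0;

    // Old comparison: the unsigned subtraction wraps to a huge value,
    // so no scheduled token passes the check and nothing is gathered.
    bool old_check = token_id >= num_scheduled_tokens - tokens_to_sample_per_sequence;

    // New comparison from the diff: the addition cannot underflow,
    // so every scheduled token is gathered as intended.
    bool new_check = token_id + tokens_to_sample_per_sequence >= num_scheduled_tokens;

    std::cout << old_check << " vs " << new_check << "\n";  // prints "0 vs 1"
}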
10 changes: 5 additions & 5 deletions src/cpp/src/sampler.cpp
@@ -758,7 +758,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
size_t output_seq_len = sequence_group->get_output_seq_len();
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

const auto request_id = sequence_group->get_request_id();
@@ -773,13 +773,13 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
auto& stop_strings = m_stop_strings.at(request_id);
auto& logit_processor = m_logit_processors.at(request_id);
const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, output_seq_len, vocab_size}, (void *)sequence_group_logits_data);
size_t max_removed_tokens_per_request = 0, min_generated_len = std::numeric_limits<size_t>::max(), updated_validation_len = 0;
if (sequence_group->requires_sampling()) {
// get number of token to be validated
auto num_tokens_to_process = sequence_group->get_num_tokens_to_validate();
if (num_tokens_to_process > actual_seq_len - 1) {
auto delta = num_tokens_to_process - (actual_seq_len - 1);
if (num_tokens_to_process > output_seq_len - 1) {
auto delta = num_tokens_to_process - (output_seq_len - 1);
updated_validation_len = std::max(updated_validation_len, delta);
num_tokens_to_process -= delta;
}
@@ -913,7 +913,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}

// accumulate a number of processed tokens
currently_processed_tokens += actual_seq_len * num_running_sequences;
currently_processed_tokens += output_seq_len * num_running_sequences;
}

return sampler_output;
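For illustration only, a reduced sketch of the clamping in the hunk above, written as a hypothetical free function rather than the inline code in Sampler::sample: when more tokens need validation than the group produced logit rows for, the shortfall is recorded in updated_validation_len and only the available rows are processed this step.

#include <algorithm>
#include <cstddef>

// Hypothetical helper mirroring the validation-length clamping above.
size_t clamp_tokens_to_process(size_t num_tokens_to_validate,
                               size_t output_seq_len,
                               size_t& updated_validation_len) {
    size_t num_tokens_to_process = num_tokens_to_validate;
    if (num_tokens_to_process > output_seq_len - 1) {
        // Fewer logit rows than requested validations: defer the excess.
        size_t delta = num_tokens_to_process - (output_seq_len - 1);
        updated_validation_len = std::max(updated_validation_len, delta);
        num_tokens_to_process -= delta;
    }
    return num_tokens_to_process;
}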
5 changes: 1 addition & 4 deletions src/cpp/src/sequence_group.hpp
@@ -227,9 +227,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {

size_t m_num_streamed_tokens = 0, m_stream_window_size = 0;

// flag shows wheather last matmul was sliced
bool m_sliced_matmul = false;

SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size)
: m_request_id(request_id),
m_sampling_params(sampling_params),
@@ -399,7 +396,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return m_num_processed_tokens;
}

size_t get_seq_len_to_sample() const {
size_t get_output_seq_len() const {
return m_seq_len_to_sample;
}

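A minimal sketch of what this file's change amounts to: only the getter is renamed to say the value is the group's output (gathered) sequence length, while the backing member and its setter keep their old names, as the model_runner diff above still calls set_seq_len_to_sample. The reduced class below is illustrative, not the real SequenceGroup.

#include <cstddef>

// Hypothetical, stripped-down stand-in for SequenceGroup showing just the rename.
class SequenceGroupSketch {
    size_t m_seq_len_to_sample = 0;  // logit rows produced for this group in the current step
public:
    void set_seq_len_to_sample(size_t len) { m_seq_len_to_sample = len; }
    size_t get_output_seq_len() const { return m_seq_len_to_sample; }  // was get_seq_len_to_sample()
};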
