Skip to content

Commit

Permalink
Use sequence length axis in tensor trim (#723)
Browse files Browse the repository at this point in the history
* Use sequence length axis in `trimm_tensor`
  • Loading branch information
as-suvorov authored Aug 5, 2024
1 parent f50aaf7 commit 4e1e755
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/causal_lm_cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ jobs:
run: |
source ./ov/setupvars.sh
./build/samples/cpp/speculative_decoding_lm/speculative_decoding_lm ./dolly-v2-3b/ ./dolly-v2-7b/ "Alan Turing was a" > predictions_speculative.txt
./build/samples/cpp/speculative_decoding_lm/speculative_decoding_lm ./dolly-v2-3b/ ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
Expand Down Expand Up @@ -420,7 +420,7 @@ jobs:
A:' > ./prompt.txt
./build/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_prompt_lookup.txt
./build/samples/cpp/prompt_lookup_decoding_lm/prompt_lookup_decoding_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_greedy.txt
./build/samples/cpp/greedy_causal_lm/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_greedy.txt
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <string_view>
#include <openvino/core/parallel.hpp>
#include <openvino/openvino.hpp>
#include <string_view>

namespace {

Expand Down Expand Up @@ -58,7 +58,7 @@ struct TextStreamer {
void end() {
std::string text = detokenize(detokenizer, token_cache);
if (text.size() <= print_len)
return ;
return;
std::cout << std::string_view{text.data() + print_len, text.size() - print_len} << '\n';
token_cache.clear();
print_len = 0;
Expand All @@ -75,25 +75,24 @@ ov::Tensor trimm_tensor(ov::Tensor& tensor, uint64_t seq_len_axis, uint64_t new_

auto old_tensor_data = tensor.data<float>();
auto shape = tensor.get_shape();
size_t batch_size = shape[0];
size_t num_kv_heads = shape[1];
size_t old_seq_len = shape[2];
size_t head_size = shape[3];
size_t old_seq_len = shape[seq_len_axis];

OPENVINO_ASSERT(new_seq_len <= old_seq_len);

// if new_seq_len equal to old one no need to copy tensor, return as is
if (old_seq_len == new_seq_len)
return tensor;

shape[seq_len_axis] = new_seq_len;

if (seq_len_axis == 0) {
shape[0] = new_seq_len;
tensor.set_shape(shape);
return tensor;
}

ov::Coordinate new_shape_begin{0, 0, 0, 0};
ov::Coordinate new_shape_end{batch_size, num_kv_heads, new_seq_len, head_size};
ov::Coordinate new_shape_end{shape};

auto new_tensor = ov::Tensor(tensor, new_shape_begin, new_shape_end);

return new_tensor;
Expand Down
17 changes: 8 additions & 9 deletions samples/cpp/speculative_decoding_lm/speculative_decoding_lm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
#include <openvino/openvino.hpp>
#include <random>

constexpr size_t BATCH_SIZE = 1;
namespace {

constexpr size_t BATCH_SIZE = 1;
// sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
// therefore usually SEQ_LEN_AXIS = 2
constexpr size_t SEQ_LEN_AXIS = 2;

namespace {
std::pair<ov::Tensor, ov::Tensor> tokenize(ov::InferRequest& tokenizer, std::string&& prompt) {
tokenizer.set_input_tensor(ov::Tensor{ov::element::string, {BATCH_SIZE}, &prompt});
tokenizer.infer();
Expand Down Expand Up @@ -58,7 +58,7 @@ struct TextStreamer {
void end() {
std::string text = detokenize(detokenizer, token_cache);
if (text.size() <= print_len)
return ;
return;
std::cout << std::string_view{text.data() + print_len, text.size() - print_len} << '\n';
token_cache.clear();
print_len = 0;
Expand All @@ -75,25 +75,24 @@ ov::Tensor trimm_tensor(ov::Tensor& tensor, uint64_t seq_len_axis, uint64_t new_

auto old_tensor_data = tensor.data<float>();
auto shape = tensor.get_shape();
size_t batch_size = shape[0];
size_t num_kv_heads = shape[1];
size_t old_seq_len = shape[2];
size_t head_size = shape[3];
size_t old_seq_len = shape[seq_len_axis];

OPENVINO_ASSERT(new_seq_len <= old_seq_len);

// if new_seq_len equal to old one no need to copy tensor, return as is
if (old_seq_len == new_seq_len)
return tensor;

shape[seq_len_axis] = new_seq_len;

if (seq_len_axis == 0) {
shape[0] = new_seq_len;
tensor.set_shape(shape);
return tensor;
}

ov::Coordinate new_shape_begin{0, 0, 0, 0};
ov::Coordinate new_shape_end{batch_size, num_kv_heads, new_seq_len, head_size};
ov::Coordinate new_shape_end{shape};

auto new_tensor = ov::Tensor(tensor, new_shape_begin, new_shape_end);

return new_tensor;
Expand Down

0 comments on commit 4e1e755

Please sign in to comment.