From f4af2b6cff99514b6e56c79ea51cb4d121bf8cbc Mon Sep 17 00:00:00 2001 From: mzegla Date: Wed, 24 Jul 2024 09:04:05 +0200 Subject: [PATCH] is_vector_initialized --- src/cpp/src/logit_processor.hpp | 6 +++--- src/cpp/src/sampler.hpp | 4 ++-- tests/cpp/logit_filtering.cpp | 16 ++++++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/cpp/src/logit_processor.hpp b/src/cpp/src/logit_processor.hpp index 7b6b9808a2..06ba819b9d 100644 --- a/src/cpp/src/logit_processor.hpp +++ b/src/cpp/src/logit_processor.hpp @@ -32,7 +32,7 @@ struct Logits { m_vector.emplace_back(m_data[i], i); } - bool vector_initialized() const { + bool is_vector_initialized() const { return m_vector.size() > 0; } @@ -59,7 +59,7 @@ class TopPFilter : public ILogitTransformer { TopPFilter(double top_p) : m_top_p(top_p) {} void apply(Logits& logits) override { - if (!logits.vector_initialized()) { + if (!logits.is_vector_initialized()) { // Initialize and sort vector logits.initialize_vector(); std::sort(logits.m_vector.begin(), logits.m_vector.end(), [](const Token& lhs, const Token& rhs) {return lhs.m_log_prob > rhs.m_log_prob; }); @@ -92,7 +92,7 @@ class TopKFilter : public ILogitTransformer { return; */ - if (!logits.vector_initialized()) { + if (!logits.is_vector_initialized()) { // Initialize and partially sort vector logits.initialize_vector(); // TODO: Uncommenting below requires uncommenting section above diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 22437e2e2a..ab8f81ab1c 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -232,7 +232,7 @@ class Sampler { // If top_p or top_k was applied we use sorted vector, if not we go with original buffer. std::vector multinomial_weights; multinomial_weights.reserve(logits.m_size); - if (logits.vector_initialized()) + if (logits.is_vector_initialized()) for (auto& logit: logits.m_vector) multinomial_weights.emplace_back(logit.m_log_prob); else multinomial_weights.assign(logits.m_data, logits.m_data + logits.m_size); @@ -242,7 +242,7 @@ class Sampler { std::vector out_tokens; for (size_t token_idx = 0; token_idx < num_tokens_per_sequence; ++token_idx) { size_t element_to_pick = dist(rng_engine); - if (logits.vector_initialized()) + if (logits.is_vector_initialized()) out_tokens.push_back(logits.m_vector[element_to_pick]); else out_tokens.emplace_back(logits.m_data[element_to_pick], element_to_pick); diff --git a/tests/cpp/logit_filtering.cpp b/tests/cpp/logit_filtering.cpp index 9b0c6ca385..a848683cf3 100644 --- a/tests/cpp/logit_filtering.cpp +++ b/tests/cpp/logit_filtering.cpp @@ -23,7 +23,7 @@ TEST_P(TemperatureTransformTest, TransformResultEqualToReference) { auto logits = Logits(test_struct.input, TemperatureTransformTestStruct::size); auto transform = TemperatureLogitTransform(test_struct.temperature); transform.apply(logits); - ASSERT_FALSE(logits.vector_initialized()); + ASSERT_FALSE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, TemperatureTransformTestStruct::size); // temperature transfrom should not change buffer size for (size_t i = 0; i < logits.m_size; i++) { EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6); @@ -58,7 +58,7 @@ TEST_P(TopPFilteringTest, FilterResultEqualToReference) { auto logits = Logits(test_struct.input, TopPTestStruct::size); auto transform = TopPFilter(test_struct.top_p); transform.apply(logits); - ASSERT_TRUE(logits.vector_initialized()); + ASSERT_TRUE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, logits.m_vector.size()); ASSERT_EQ(logits.m_size, test_struct.expected_output.size()); for (size_t i = 0; i < logits.m_vector.size(); i++) { @@ -94,7 +94,7 @@ TEST_P(TopKFilteringTest, FilterResultEqualToReference) { auto logits = Logits(test_struct.input, TopKTestStruct::size); auto transform = TopKFilter(test_struct.top_k); transform.apply(logits); - ASSERT_TRUE(logits.vector_initialized()); + ASSERT_TRUE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, logits.m_vector.size()); ASSERT_EQ(logits.m_size, test_struct.expected_output.size()); for (size_t i = 0; i < logits.m_vector.size(); i++) { @@ -123,7 +123,7 @@ TEST(TopKFilteringTest, FilterNotAppliedTopKGreaterThanInputSize) { auto logits = Logits(input, 3); auto transform = TopKFilter(top_k); transform.apply(logits); - ASSERT_FALSE(logits.vector_initialized()); + ASSERT_FALSE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, 3); for (size_t i = 0; i < logits.m_size; i++) { EXPECT_EQ(logits.m_data[i], expected_output[i]); @@ -147,7 +147,7 @@ TEST_P(RepetitionPenaltyTransformTest, TransformResultEqualToReference) { auto logits = Logits(test_struct.input, RepetitionPenaltyTransformTestStruct::size); auto transform = RepetitionPenaltyTransform(test_struct.penalty); transform.apply(logits, test_struct.input_ids); - ASSERT_FALSE(logits.vector_initialized()); + ASSERT_FALSE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, RepetitionPenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size for (size_t i = 0; i < logits.m_size; i++) { EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6); @@ -206,7 +206,7 @@ TEST_P(FrequencyPenaltyTransformTest, TransformResultEqualToReference) { auto logits = Logits(test_struct.input, FrequencyPenaltyTransformTestStruct::size); auto transform = FrequencyPenaltyTransform(test_struct.penalty); transform.apply(logits, test_struct.input_ids); - ASSERT_FALSE(logits.vector_initialized()); + ASSERT_FALSE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, FrequencyPenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size for (size_t i = 0; i < logits.m_size; i++) { EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6); @@ -265,7 +265,7 @@ TEST_P(PresencePenaltyTransformTest, TransformResultEqualToReference) { auto logits = Logits(test_struct.input, PresencePenaltyTransformTestStruct::size); auto transform = PresencePenaltyTransform(test_struct.penalty); transform.apply(logits, test_struct.input_ids); - ASSERT_FALSE(logits.vector_initialized()); + ASSERT_FALSE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, PresencePenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size for (size_t i = 0; i < logits.m_size; i++) { EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6); @@ -322,7 +322,7 @@ TEST_P(EOSPenaltyTransformTest, TransformResultEqualToReference) { auto logits = Logits(test_struct.input, EOSPenaltyTransformTestStruct::size); auto transform = EOSPenaltyTransform(test_struct.eos_token_id, std::numeric_limits::max()); transform.apply(logits); - ASSERT_FALSE(logits.vector_initialized()); + ASSERT_FALSE(logits.is_vector_initialized()); ASSERT_EQ(logits.m_size, EOSPenaltyTransformTestStruct::size); // penalty transfrom should not change buffer size for (size_t i = 0; i < logits.m_size; i++) { EXPECT_NEAR(logits.m_data[i], test_struct.expected_output[i], 1e-6);