From ba0b4a7a44decc9d3ae0f7f4ef7c593fa1bfc7d5 Mon Sep 17 00:00:00 2001 From: ienkovich Date: Wed, 6 Sep 2023 22:07:12 -0500 Subject: [PATCH] Support nested dictionaries in StringDictionary::getBulk. Signed-off-by: ienkovich --- .../StringDictionary/StringDictionary.cpp | 129 +++++++++--- omniscidb/StringDictionary/StringDictionary.h | 2 + omniscidb/Tests/StringDictionaryTest.cpp | 80 +++++++++++ 3 files changed, 149 insertions(+), 62 deletions(-) diff --git a/omniscidb/StringDictionary/StringDictionary.cpp b/omniscidb/StringDictionary/StringDictionary.cpp index 7cfcc0d40..d767edbd0 100644 --- a/omniscidb/StringDictionary/StringDictionary.cpp +++ b/omniscidb/StringDictionary/StringDictionary.cpp @@ -298,81 +298,77 @@ template <class T> size_t StringDictionary::getBulk(const std::vector<std::string>& string_vec, T* encoded_vec, const int64_t generation) const { - CHECK(!base_dict_) << "Not implemented"; constexpr int64_t target_strings_per_thread{1000}; const int64_t num_lookup_strings = string_vec.size(); if (num_lookup_strings == 0) { return 0; } - const ThreadInfo thread_info( std::thread::hardware_concurrency(), num_lookup_strings, target_strings_per_thread); - CHECK_GE(thread_info.num_threads, 1L); - CHECK_GE(thread_info.num_elems_per_thread, 1L); - - std::vector<size_t> num_strings_not_found_per_thread(thread_info.num_threads, 0UL); + size_t base_num_strings_not_found = string_vec.size(); + if (base_dict_) { + auto base_generation_for_bulk = + generation >= 0 ? std::min(generation, base_generation_) : base_generation_; + base_num_strings_not_found = + base_dict_->getBulk(string_vec, encoded_vec, base_generation_for_bulk); + } mapd_shared_lock<mapd_shared_mutex> read_lock(rw_mutex_); const int64_t num_dict_strings = generation >= 0 ? 
generation : storageEntryCount(); - const bool dictionary_is_empty = (num_dict_strings == 0); - if (dictionary_is_empty) { - tbb::parallel_for(tbb::blocked_range<int64_t>(0, num_lookup_strings), - [&](const tbb::blocked_range<int64_t>& r) { - const int64_t start_idx = r.begin(); - const int64_t end_idx = r.end(); - for (int64_t string_idx = start_idx; string_idx < end_idx; - ++string_idx) { - encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID; - } - }); - return num_lookup_strings; + const int64_t num_dict_strings = generation >= 0 ? generation : entryCount(); + const bool skip_owned_string = + (num_dict_strings <= base_generation_) || !base_num_strings_not_found; + if (skip_owned_string) { + // Need to fill the resulting vector if it wasn't done by the base dictionary. + if (!base_dict_) { + tbb::parallel_for(tbb::blocked_range<int64_t>( + 0, num_lookup_strings, (size_t)64 << 10 /* 256KB chunks*/), + [&](const tbb::blocked_range<int64_t>& r) { + const int64_t start_idx = r.begin(); + const int64_t end_idx = r.end(); + for (int64_t string_idx = start_idx; string_idx < end_idx; + ++string_idx) { + encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID; + } + }); + } + return base_num_strings_not_found; } // If we're here the generation-capped dictionary has strings in it // that we need to look up against - - tbb::task_arena limited_arena(thread_info.num_threads); - limited_arena.execute([&] { - CHECK_LE(tbb::this_task_arena::max_concurrency(), thread_info.num_threads); - tbb::parallel_for( - tbb::blocked_range<int64_t>( - 0, num_lookup_strings, thread_info.num_elems_per_thread /* tbb grain_size */), - [&](const tbb::blocked_range<int64_t>& r) { - const int64_t start_idx = r.begin(); - const int64_t end_idx = r.end(); - size_t num_strings_not_found = 0; - for (int64_t string_idx = start_idx; string_idx != end_idx; ++string_idx) { - const auto& input_string = string_vec[string_idx]; - if (input_string.empty()) { - encoded_vec[string_idx] = inline_int_null_value<T>(); - continue; - } - if (input_string.size() > 
StringDictionary::MAX_STRLEN) { - throw_string_too_long_error(input_string, dict_ref_); - } - const uint32_t input_string_hash = hash_string(input_string); - uint32_t hash_bucket = - computeBucket(input_string_hash, input_string, string_id_uint32_table_); - // Will either be legit id or INVALID_STR_ID - const auto string_id = string_id_uint32_table_[hash_bucket]; - if (string_id == StringDictionary::INVALID_STR_ID || - string_id >= num_dict_strings) { - encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID; - num_strings_not_found++; - continue; - } - encoded_vec[string_idx] = string_id; + size_t found_owned = tbb::parallel_reduce( + tbb::blocked_range<int64_t>( + 0, num_lookup_strings, target_strings_per_thread /* tbb grain_size */), + (size_t)0, + [&](const tbb::blocked_range<int64_t>& r, size_t found) { + const int64_t start_idx = r.begin(); + const int64_t end_idx = r.end(); + for (int64_t string_idx = start_idx; string_idx != end_idx; ++string_idx) { + if (base_dict_ && encoded_vec[string_idx] != StringDictionary::INVALID_STR_ID) { + continue; } - const size_t tbb_thread_idx = tbb::this_task_arena::current_thread_index(); - num_strings_not_found_per_thread[tbb_thread_idx] = num_strings_not_found; - }, - tbb::simple_partitioner()); - }); + const auto& input_string = string_vec[string_idx]; + if (input_string.empty()) { + encoded_vec[string_idx] = inline_int_null_value<T>(); + ++found; + continue; + } + if (input_string.size() > StringDictionary::MAX_STRLEN) { + throw_string_too_long_error(input_string, dict_ref_); + } + // Will either be legit id or INVALID_STR_ID + const auto string_id = getOwnedUnlocked(input_string); + if (string_id == StringDictionary::INVALID_STR_ID || + string_id >= num_dict_strings) { + encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID; + continue; + } + encoded_vec[string_idx] = string_id; + ++found; + } + return found; + }, + std::plus<size_t>()); - size_t num_strings_not_found = 0; - for (int64_t thread_idx = 0; thread_idx < thread_info.num_threads; 
++thread_idx) { - num_strings_not_found += num_strings_not_found_per_thread[thread_idx]; - } - return num_strings_not_found; + return base_num_strings_not_found - found_owned; } template size_t StringDictionary::getBulk(const std::vector<std::string>& string_vec, @@ -565,6 +561,15 @@ int32_t StringDictionary::getUnlocked(const std::string_view sv, return base_res; } } + return getOwnedUnlocked(sv, hash); +} + +int32_t StringDictionary::getOwnedUnlocked(const std::string_view sv) const noexcept { + return getOwnedUnlocked(sv, hash_string(sv)); +} + +int32_t StringDictionary::getOwnedUnlocked(const std::string_view sv, + const uint32_t hash) const noexcept { auto str_id = string_id_uint32_table_[computeBucket(hash, sv, string_id_uint32_table_)]; return str_id; } diff --git a/omniscidb/StringDictionary/StringDictionary.h b/omniscidb/StringDictionary/StringDictionary.h index b90869711..ae7c7f939 100644 --- a/omniscidb/StringDictionary/StringDictionary.h +++ b/omniscidb/StringDictionary/StringDictionary.h @@ -179,6 +179,8 @@ class StringDictionary { int32_t getIdOfString(const String&, const uint32_t hash) const; int32_t getUnlocked(const std::string_view sv) const noexcept; int32_t getUnlocked(const std::string_view sv, const uint32_t hash) const noexcept; + int32_t getOwnedUnlocked(const std::string_view sv) const noexcept; + int32_t getOwnedUnlocked(const std::string_view sv, const uint32_t hash) const noexcept; std::string getStringUnlocked(int32_t string_id) const noexcept; std::string getOwnedStringChecked(const int string_id) const noexcept; std::pair<char*, size_t> getOwnedStringBytesChecked(const int string_id) const noexcept; diff --git a/omniscidb/Tests/StringDictionaryTest.cpp b/omniscidb/Tests/StringDictionaryTest.cpp index fa7f1e14a..4df9885b4 100644 --- a/omniscidb/Tests/StringDictionaryTest.cpp +++ b/omniscidb/Tests/StringDictionaryTest.cpp @@ -380,6 +380,86 @@ TEST(NestedStringDictionary, GetOrAddTransient) { } } +TEST(NestedStringDictionary, GetBulk) { + auto dict1 = + 
std::make_shared<StringDictionary>(DictRef{-1, 1}, -1, g_cache_string_hash); + ASSERT_EQ(dict1->getOrAdd("str1"), 0); + ASSERT_EQ(dict1->getOrAdd("str2"), 1); + ASSERT_EQ(dict1->getOrAdd("str3"), 2); + auto dict2 = std::make_shared<StringDictionary>(dict1, -1, g_cache_string_hash); + ASSERT_EQ(dict1->getOrAdd("str4"), 3); + ASSERT_EQ(dict2->getOrAdd("str5"), 3); + ASSERT_EQ(dict2->getOrAdd("str6"), 4); + + { + std::vector<int32_t> ids(5, -10); + auto missing = dict1->getBulk( + std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str5"s}, + ids.data(), + -1); + ASSERT_EQ(missing, (size_t)1); + ASSERT_EQ(ids, std::vector<int32_t>({0, 1, 2, 3, StringDictionary::INVALID_STR_ID})); + } + + { + std::vector<int32_t> ids(5, -10); + auto missing = dict1->getBulk( + std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str5"s}, + ids.data(), + 2); + ASSERT_EQ(missing, (size_t)3); + ASSERT_EQ(ids, + std::vector<int32_t>({0, + 1, + StringDictionary::INVALID_STR_ID, + StringDictionary::INVALID_STR_ID, + StringDictionary::INVALID_STR_ID})); + } + + { + std::vector<int32_t> ids(5, -10); + auto missing = dict2->getBulk( + std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str6"s}, + ids.data(), + -1); + ASSERT_EQ(missing, (size_t)1); + ASSERT_EQ(ids, std::vector<int32_t>({0, 1, 2, StringDictionary::INVALID_STR_ID, 4})); + } + + { + std::vector<int32_t> ids(5, -10); + auto missing = dict2->getBulk( + std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str5"s, "str6"s}, + ids.data(), + 4); + ASSERT_EQ(missing, (size_t)1); + ASSERT_EQ(ids, std::vector<int32_t>({0, 1, 2, 3, StringDictionary::INVALID_STR_ID})); + } + + { + std::vector<int32_t> ids(5, -10); + auto missing = dict2->getBulk( + std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str5"s}, + ids.data(), + 2); + ASSERT_EQ(missing, (size_t)3); + ASSERT_EQ(ids, + std::vector<int32_t>({0, + 1, + StringDictionary::INVALID_STR_ID, + StringDictionary::INVALID_STR_ID, + StringDictionary::INVALID_STR_ID})); + } + + { + std::vector<int32_t> ids(2, -10); + auto missing = + dict2->getBulk(std::vector<std::string>{"str1"s, "str2"s}, ids.data(), -1); + ASSERT_EQ(missing, (size_t)0); + ASSERT_EQ(ids, 
std::vector<int32_t>({0, 1})); + } +} + static std::shared_ptr<StringDictionary> create_and_fill_dictionary() { const DictRef dict_ref(-1, 1); std::shared_ptr<StringDictionary> string_dict =