This repository has been archived by the owner on May 9, 2024. It is now read-only.

Commit

Support nested dictionaries in StringDictionary::getBulk.
Signed-off-by: ienkovich <[email protected]>
ienkovich committed Oct 9, 2023
1 parent 87977e8 commit ba0b4a7
Showing 3 changed files with 149 additions and 62 deletions.
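The change makes StringDictionary::getBulk work on nested dictionaries: the whole batch is first delegated to the base dictionary (capped at base_generation_), and only the strings the base could not resolve are then probed against this dictionary's own hash table; the return value is the number of strings still unresolved. In the new dict2 test case below, for example, the base resolves str1, str2, and str3 and reports two strings missing, the nested lookup then finds str6, and getBulk returns 1. A minimal usage sketch, assuming the construction pattern used in StringDictionaryTest.cpp (the DictRef arguments, the cache-hash flag, and the include path are assumptions, not part of this commit):

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "StringDictionary/StringDictionary.h"  // include path is an assumption

// Sketch only: mirrors the NestedStringDictionary.GetBulk test added by this commit.
void nested_get_bulk_sketch(bool cache_string_hash) {
  // Base dictionary with two strings (expected ids 0 and 1).
  auto base = std::make_shared<StringDictionary>(DictRef{-1, 1}, -1, cache_string_hash);
  base->getOrAdd("str1");
  base->getOrAdd("str2");

  // Nested dictionary: ids below the captured base generation belong to the base,
  // ids at or above it are owned by the nested dictionary.
  auto nested = std::make_shared<StringDictionary>(base, -1, cache_string_hash);
  nested->getOrAdd("str3");  // owned by the nested dictionary, expected id 2

  std::vector<std::string> strings{"str1", "str3", "unknown"};
  std::vector<int> ids(strings.size(), -10);
  // generation = -1 means "use the current entry counts of both dictionaries".
  const size_t num_missing = nested->getBulk(strings, ids.data(), -1);
  // Expected: ids == {0, 2, StringDictionary::INVALID_STR_ID}, num_missing == 1.
}

Passing a non-negative generation instead caps lookups to ids below that generation, which the new tests also exercise.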
129 changes: 67 additions & 62 deletions omniscidb/StringDictionary/StringDictionary.cpp
@@ -298,81 +298,77 @@ template <class T, class String>
 size_t StringDictionary::getBulk(const std::vector<String>& string_vec,
                                  T* encoded_vec,
                                  const int64_t generation) const {
-  CHECK(!base_dict_) << "Not implemented";
   constexpr int64_t target_strings_per_thread{1000};
   const int64_t num_lookup_strings = string_vec.size();
   if (num_lookup_strings == 0) {
     return 0;
   }
 
-  const ThreadInfo thread_info(
-      std::thread::hardware_concurrency(), num_lookup_strings, target_strings_per_thread);
-  CHECK_GE(thread_info.num_threads, 1L);
-  CHECK_GE(thread_info.num_elems_per_thread, 1L);
-
-  std::vector<size_t> num_strings_not_found_per_thread(thread_info.num_threads, 0UL);
+  size_t base_num_strings_not_found = string_vec.size();
+  if (base_dict_) {
+    auto base_generation_for_bulk =
+        generation >= 0 ? std::min(generation, base_generation_) : base_generation_;
+    base_num_strings_not_found =
+        base_dict_->getBulk(string_vec, encoded_vec, base_generation_for_bulk);
+  }
 
   mapd_shared_lock<mapd_shared_mutex> read_lock(rw_mutex_);
-  const int64_t num_dict_strings = generation >= 0 ? generation : storageEntryCount();
-  const bool dictionary_is_empty = (num_dict_strings == 0);
-  if (dictionary_is_empty) {
-    tbb::parallel_for(tbb::blocked_range<int64_t>(0, num_lookup_strings),
-                      [&](const tbb::blocked_range<int64_t>& r) {
-                        const int64_t start_idx = r.begin();
-                        const int64_t end_idx = r.end();
-                        for (int64_t string_idx = start_idx; string_idx < end_idx;
-                             ++string_idx) {
-                          encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
-                        }
-                      });
-    return num_lookup_strings;
+  const int64_t num_dict_strings = generation >= 0 ? generation : entryCount();
+  const bool skip_owned_string =
+      (num_dict_strings <= base_generation_) || !base_num_strings_not_found;
+  if (skip_owned_string) {
+    // Need to fill the resulting vector if it wasn't done by the base dictionary.
+    if (!base_dict_) {
+      tbb::parallel_for(tbb::blocked_range<int64_t>(
+                            0, num_lookup_strings, (size_t)64 << 10 /* 256KB chunks*/),
+                        [&](const tbb::blocked_range<int64_t>& r) {
+                          const int64_t start_idx = r.begin();
+                          const int64_t end_idx = r.end();
+                          for (int64_t string_idx = start_idx; string_idx < end_idx;
+                               ++string_idx) {
+                            encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
+                          }
+                        });
+    }
+    return base_num_strings_not_found;
   }
   // If we're here the generation-capped dictionary has strings in it
   // that we need to look up against
 
-  tbb::task_arena limited_arena(thread_info.num_threads);
-  limited_arena.execute([&] {
-    CHECK_LE(tbb::this_task_arena::max_concurrency(), thread_info.num_threads);
-    tbb::parallel_for(
-        tbb::blocked_range<int64_t>(
-            0, num_lookup_strings, thread_info.num_elems_per_thread /* tbb grain_size */),
-        [&](const tbb::blocked_range<int64_t>& r) {
-          const int64_t start_idx = r.begin();
-          const int64_t end_idx = r.end();
-          size_t num_strings_not_found = 0;
-          for (int64_t string_idx = start_idx; string_idx != end_idx; ++string_idx) {
-            const auto& input_string = string_vec[string_idx];
-            if (input_string.empty()) {
-              encoded_vec[string_idx] = inline_int_null_value<T>();
-              continue;
-            }
-            if (input_string.size() > StringDictionary::MAX_STRLEN) {
-              throw_string_too_long_error(input_string, dict_ref_);
-            }
-            const uint32_t input_string_hash = hash_string(input_string);
-            uint32_t hash_bucket =
-                computeBucket(input_string_hash, input_string, string_id_uint32_table_);
-            // Will either be legit id or INVALID_STR_ID
-            const auto string_id = string_id_uint32_table_[hash_bucket];
-            if (string_id == StringDictionary::INVALID_STR_ID ||
-                string_id >= num_dict_strings) {
-              encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
-              num_strings_not_found++;
-              continue;
-            }
-            encoded_vec[string_idx] = string_id;
-          }
-          const size_t tbb_thread_idx = tbb::this_task_arena::current_thread_index();
-          num_strings_not_found_per_thread[tbb_thread_idx] = num_strings_not_found;
-        },
-        tbb::simple_partitioner());
-  });
+  size_t found_owned = tbb::parallel_reduce(
+      tbb::blocked_range<int64_t>(
+          0, num_lookup_strings, target_strings_per_thread /* tbb grain_size */),
+      (size_t)0,
+      [&](const tbb::blocked_range<int64_t>& r, size_t found) {
+        const int64_t start_idx = r.begin();
+        const int64_t end_idx = r.end();
+        for (int64_t string_idx = start_idx; string_idx != end_idx; ++string_idx) {
+          if (base_dict_ && encoded_vec[string_idx] != StringDictionary::INVALID_STR_ID) {
+            continue;
+          }
+          const auto& input_string = string_vec[string_idx];
+          if (input_string.empty()) {
+            encoded_vec[string_idx] = inline_int_null_value<T>();
+            ++found;
+            continue;
+          }
+          if (input_string.size() > StringDictionary::MAX_STRLEN) {
+            throw_string_too_long_error(input_string, dict_ref_);
+          }
+          // Will either be legit id or INVALID_STR_ID
+          const auto string_id = getOwnedUnlocked(input_string);
+          if (string_id == StringDictionary::INVALID_STR_ID ||
+              string_id >= num_dict_strings) {
+            encoded_vec[string_idx] = StringDictionary::INVALID_STR_ID;
+            continue;
+          }
+          encoded_vec[string_idx] = string_id;
+          ++found;
+        }
+        return found;
+      },
+      std::plus<size_t>());
 
-  size_t num_strings_not_found = 0;
-  for (int64_t thread_idx = 0; thread_idx < thread_info.num_threads; ++thread_idx) {
-    num_strings_not_found += num_strings_not_found_per_thread[thread_idx];
-  }
-  return num_strings_not_found;
+  return base_num_strings_not_found - found_owned;
 }
 
 template size_t StringDictionary::getBulk(const std::vector<std::string>& string_vec,
@@ -565,6 +561,15 @@ int32_t StringDictionary::getUnlocked(const std::string_view sv,
return base_res;
}
}
return getOwnedUnlocked(sv, hash);
}

int32_t StringDictionary::getOwnedUnlocked(const std::string_view sv) const noexcept {
return getOwnedUnlocked(sv, hash_string(sv));
}

int32_t StringDictionary::getOwnedUnlocked(const std::string_view sv,
const uint32_t hash) const noexcept {
auto str_id = string_id_uint32_table_[computeBucket(hash, sv, string_id_uint32_table_)];
return str_id;
}
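The second StringDictionary.cpp hunk factors the single-dictionary probe out of getUnlocked into the new getOwnedUnlocked helper, which hashes the string (when no hash is supplied) and reads the id straight from string_id_uint32_table_. getUnlocked itself, only partially visible above, consults the base dictionary first and falls back to getOwnedUnlocked for strings this dictionary owns. A free-standing sketch of that lookup order, with stand-in callables instead of the member functions (the exact guard the real code applies to the base result is not shown in this hunk):

#include <cstdint>
#include <functional>
#include <string_view>

// Illustrative only: base_lookup stands in for base_dict_->getUnlocked() and owned_lookup
// for getOwnedUnlocked(); invalid_id corresponds to StringDictionary::INVALID_STR_ID.
int32_t two_level_lookup(std::string_view sv,
                         int32_t invalid_id,
                         const std::function<int32_t(std::string_view)>& base_lookup,
                         const std::function<int32_t(std::string_view)>& owned_lookup) {
  if (base_lookup) {
    const int32_t base_res = base_lookup(sv);
    if (base_res != invalid_id) {
      return base_res;  // the string lives in the shared base dictionary
    }
  }
  return owned_lookup(sv);  // probe only this dictionary's own hash table
}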
2 changes: 2 additions & 0 deletions omniscidb/StringDictionary/StringDictionary.h
@@ -179,6 +179,8 @@ class StringDictionary {
int32_t getIdOfString(const String&, const uint32_t hash) const;
int32_t getUnlocked(const std::string_view sv) const noexcept;
int32_t getUnlocked(const std::string_view sv, const uint32_t hash) const noexcept;
int32_t getOwnedUnlocked(const std::string_view sv) const noexcept;
int32_t getOwnedUnlocked(const std::string_view sv, const uint32_t hash) const noexcept;
std::string getStringUnlocked(int32_t string_id) const noexcept;
std::string getOwnedStringChecked(const int string_id) const noexcept;
std::pair<char*, size_t> getOwnedStringBytesChecked(const int string_id) const noexcept;
80 changes: 80 additions & 0 deletions omniscidb/Tests/StringDictionaryTest.cpp
@@ -380,6 +380,86 @@ TEST(NestedStringDictionary, GetOrAddTransient) {
}
}

TEST(NestedStringDictionary, GetBulk) {
auto dict1 =
std::make_shared<StringDictionary>(DictRef{-1, 1}, -1, g_cache_string_hash);
ASSERT_EQ(dict1->getOrAdd("str1"), 0);
ASSERT_EQ(dict1->getOrAdd("str2"), 1);
ASSERT_EQ(dict1->getOrAdd("str3"), 2);
auto dict2 = std::make_shared<StringDictionary>(dict1, -1, g_cache_string_hash);
ASSERT_EQ(dict1->getOrAdd("str4"), 3);
ASSERT_EQ(dict2->getOrAdd("str5"), 3);
ASSERT_EQ(dict2->getOrAdd("str6"), 4);

{
std::vector<int> ids(5, -10);
auto missing = dict1->getBulk(
std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str5"s},
ids.data(),
-1);
ASSERT_EQ(missing, (size_t)1);
ASSERT_EQ(ids, std::vector<int>({0, 1, 2, 3, StringDictionary::INVALID_STR_ID}));
}

{
std::vector<int> ids(5, -10);
auto missing = dict1->getBulk(
std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str5"s},
ids.data(),
2);
ASSERT_EQ(missing, (size_t)3);
ASSERT_EQ(ids,
std::vector<int>({0,
1,
StringDictionary::INVALID_STR_ID,
StringDictionary::INVALID_STR_ID,
StringDictionary::INVALID_STR_ID}));
}

{
std::vector<int> ids(5, -10);
auto missing = dict2->getBulk(
std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str6"s},
ids.data(),
-1);
ASSERT_EQ(missing, (size_t)1);
ASSERT_EQ(ids, std::vector<int>({0, 1, 2, StringDictionary::INVALID_STR_ID, 4}));
}

{
std::vector<int> ids(5, -10);
auto missing = dict2->getBulk(
std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str5"s, "str6"s},
ids.data(),
4);
ASSERT_EQ(missing, (size_t)1);
ASSERT_EQ(ids, std::vector<int>({0, 1, 2, 3, StringDictionary::INVALID_STR_ID}));
}

{
std::vector<int> ids(5, -10);
auto missing = dict2->getBulk(
std::vector<std::string>{"str1"s, "str2"s, "str3"s, "str4"s, "str5"s},
ids.data(),
2);
ASSERT_EQ(missing, (size_t)3);
ASSERT_EQ(ids,
std::vector<int>({0,
1,
StringDictionary::INVALID_STR_ID,
StringDictionary::INVALID_STR_ID,
StringDictionary::INVALID_STR_ID}));
}

{
std::vector<int> ids(2, -10);
auto missing =
dict2->getBulk(std::vector<std::string>{"str1"s, "str2"s}, ids.data(), -1);
ASSERT_EQ(missing, (size_t)0);
ASSERT_EQ(ids, std::vector<int>({0, 1}));
}
}

static std::shared_ptr<StringDictionary> create_and_fill_dictionary() {
const DictRef dict_ref(-1, 1);
std::shared_ptr<StringDictionary> string_dict =
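The new case is a standard googletest TEST, so once the test binary for StringDictionaryTest.cpp is built (its name depends on the build setup), it can be run on its own with --gtest_filter=NestedStringDictionary.GetBulk.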
