diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h index 06dd57edf66..d3cbea74195 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h @@ -182,18 +182,18 @@ class AggregateFunctionGroupUniqArrayGeneric { // We have to copy the keys to our arena. assert(arena != nullptr); - cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), *arena}, it, inserted); + cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), arena}, it, inserted); } } void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - ColumnArray & arr_to = assert_cast<ColumnArray &>(to); + auto & arr_to = assert_cast<ColumnArray &>(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & data_to = arr_to.getData(); auto & set = this->data(place).value; - offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + set.size()); + offsets_to.push_back((offsets_to.empty() ? 0 : offsets_to.back()) + set.size()); for (auto & elem : set) deserializeAndInsert(elem.getValue(), data_to); diff --git a/dbms/src/AggregateFunctions/KeyHolderHelpers.h b/dbms/src/AggregateFunctions/KeyHolderHelpers.h index 6677866f0d3..b8a4ee0def3 100644 --- a/dbms/src/AggregateFunctions/KeyHolderHelpers.h +++ b/dbms/src/AggregateFunctions/KeyHolderHelpers.h @@ -24,7 +24,7 @@ inline auto getKeyHolder(const IColumn & column, size_t row_num, Arena & arena) { if constexpr (is_plain_column) { - return ArenaKeyHolder{column.getDataAt(row_num), arena}; + return ArenaKeyHolder{column.getDataAt(row_num), &arena}; } else { diff --git a/dbms/src/Common/ColumnsHashing.h b/dbms/src/Common/ColumnsHashing.h index 398d6605e60..23dd30ecc44 100644 --- a/dbms/src/Common/ColumnsHashing.h +++ b/dbms/src/Common/ColumnsHashing.h @@ -48,15 +48,20 @@ struct HashMethodOneNumber using Self = HashMethodOneNumber<Value, Mapped, FieldType, use_cache>; using Base = columns_hashing_impl::HashMethodBase<Self, Value, Mapped, use_cache>; + static constexpr bool is_serialized_key = false; + const FieldType * vec; + const size_t total_rows; /// If the keys are of fixed length, key_sizes contains their lengths; it is empty otherwise. HashMethodOneNumber(const ColumnRawPtrs & key_columns, const Sizes & /*key_sizes*/, const TiDB::TiDBCollators &) + : total_rows(key_columns[0]->size()) { vec = &static_cast<const ColumnVector<FieldType> *>(key_columns[0])->getData()[0]; } explicit HashMethodOneNumber(const IColumn * column) + : total_rows(column->size()) { vec = &static_cast<const ColumnVector<FieldType> *>(column)->getData()[0]; } @@ -86,58 +91,65 @@ struct HashMethodOneNumber /// For the case when there is one string key.
-template +template struct HashMethodString - : public columns_hashing_impl:: - HashMethodBase, Value, Mapped, use_cache> + : public columns_hashing_impl::HashMethodBase, Value, Mapped, use_cache> { - using Self = HashMethodString; + using Self = HashMethodString; using Base = columns_hashing_impl::HashMethodBase; + static constexpr bool is_serialized_key = false; + const IColumn::Offset * offsets; const UInt8 * chars; TiDB::TiDBCollatorPtr collator = nullptr; + const size_t total_rows; HashMethodString( const ColumnRawPtrs & key_columns, const Sizes & /*key_sizes*/, const TiDB::TiDBCollators & collators) + : total_rows(key_columns[0]->size()) { const IColumn & column = *key_columns[0]; const auto & column_string = assert_cast(column); offsets = column_string.getOffsets().data(); chars = column_string.getChars().data(); if (!collators.empty()) - { - if constexpr (!place_string_to_arena) - throw Exception("String with collator must be placed on arena.", ErrorCodes::LOGICAL_ERROR); collator = collators[0]; - } } - ALWAYS_INLINE inline auto getKeyHolder( + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolder( ssize_t row, [[maybe_unused]] Arena * pool, - std::vector & sort_key_containers) const + [[maybe_unused]] std::vector & sort_key_containers) const { - auto last_offset = row == 0 ? 0 : offsets[row - 1]; - // Remove last zero byte. - StringRef key(chars + last_offset, offsets[row] - last_offset - 1); + auto key = getKey(row); + if (likely(collator)) + key = collator->sortKey(key.data, key.size, sort_key_containers[0]); - if constexpr (place_string_to_arena) - { - if (likely(collator)) - key = collator->sortKey(key.data, key.size, sort_key_containers[0]); - return ArenaKeyHolder{key, *pool}; - } - else - { - return key; - } + return ArenaKeyHolder{key, pool}; + } + + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolder(ssize_t row, Arena * pool, Arena * sort_key_pool) const + { + auto key = getKey(row); + if (likely(collator)) + key = collator->sortKey(key.data, key.size, *sort_key_pool); + + return ArenaKeyHolder{key, pool}; } protected: friend class columns_hashing_impl::HashMethodBase; + +private: + ALWAYS_INLINE inline StringRef getKey(size_t row) const + { + auto last_offset = row == 0 ? 0 : offsets[row - 1]; + // Remove last zero byte. + return StringRef(chars + last_offset, offsets[row] - last_offset - 1); + } }; template @@ -147,10 +159,14 @@ struct HashMethodStringBin using Self = HashMethodStringBin; using Base = columns_hashing_impl::HashMethodBase; + static constexpr bool is_serialized_key = false; + const IColumn::Offset * offsets; const UInt8 * chars; + const size_t total_rows; HashMethodStringBin(const ColumnRawPtrs & key_columns, const Sizes & /*key_sizes*/, const TiDB::TiDBCollators &) + : total_rows(key_columns[0]->size()) { const IColumn & column = *key_columns[0]; const auto & column_string = assert_cast(column); @@ -159,11 +175,16 @@ struct HashMethodStringBin } ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, std::vector &) const + { + return getKeyHolder(row, pool, nullptr); + } + + ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, Arena *) const { auto last_offset = row == 0 ? 
0 : offsets[row - 1]; StringRef key(chars + last_offset, offsets[row] - last_offset - 1); key = BinCollatorSortKey(key.data, key.size); - return ArenaKeyHolder{key, *pool}; + return ArenaKeyHolder{key, pool}; } protected: @@ -344,12 +365,16 @@ struct HashMethodFastPathTwoKeysSerialized using Self = HashMethodFastPathTwoKeysSerialized; using Base = columns_hashing_impl::HashMethodBase; + static constexpr bool is_serialized_key = true; + Key1Desc key_1_desc; Key2Desc key_2_desc; + const size_t total_rows; HashMethodFastPathTwoKeysSerialized(const ColumnRawPtrs & key_columns, const Sizes &, const TiDB::TiDBCollators &) : key_1_desc(key_columns[0]) , key_2_desc(key_columns[1]) + , total_rows(key_columns[0]->size()) {} ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, std::vector &) const @@ -370,25 +395,26 @@ struct HashMethodFastPathTwoKeysSerialized /// For the case when there is one fixed-length string key. -template +template struct HashMethodFixedString - : public columns_hashing_impl::HashMethodBase< - HashMethodFixedString, - Value, - Mapped, - use_cache> + : public columns_hashing_impl:: + HashMethodBase, Value, Mapped, use_cache> { - using Self = HashMethodFixedString; + using Self = HashMethodFixedString; using Base = columns_hashing_impl::HashMethodBase; + static constexpr bool is_serialized_key = false; + size_t n; const ColumnFixedString::Chars_t * chars; TiDB::TiDBCollatorPtr collator = nullptr; + const size_t total_rows; HashMethodFixedString( const ColumnRawPtrs & key_columns, const Sizes & /*key_sizes*/, const TiDB::TiDBCollators & collators) + : total_rows(key_columns[0]->size()) { const IColumn & column = *key_columns[0]; const auto & column_string = assert_cast(column); @@ -398,26 +424,25 @@ struct HashMethodFixedString collator = collators[0]; } - ALWAYS_INLINE inline auto getKeyHolder( + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolder( size_t row, - [[maybe_unused]] Arena * pool, + Arena * pool, std::vector & sort_key_containers) const { StringRef key(&(*chars)[row * n], n); - if (collator) - { key = collator->sortKeyFastPath(key.data, key.size, sort_key_containers[0]); - } - if constexpr (place_string_to_arena) - { - return ArenaKeyHolder{key, *pool}; - } - else - { - return key; - } + return ArenaKeyHolder{key, pool}; + } + + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolder(size_t row, Arena * pool, Arena * sort_key_pool) const + { + StringRef key(&(*chars)[row * n], n); + if (collator) + key = collator->sortKeyFastPath(key.data, key.size, *sort_key_pool); + + return ArenaKeyHolder{key, pool}; } protected: @@ -438,10 +463,12 @@ struct HashMethodKeysFixed using BaseHashed = columns_hashing_impl::HashMethodBase; using Base = columns_hashing_impl::BaseStateKeysFixed; + static constexpr bool is_serialized_key = false; static constexpr bool has_nullable_keys = has_nullable_keys_; Sizes key_sizes; size_t keys_size; + const size_t total_rows; /// SSSE3 shuffle method can be used. Shuffle masks will be calculated and stored here. 
#if defined(__SSSE3__) && !defined(MEMORY_SANITIZER) @@ -467,6 +494,7 @@ struct HashMethodKeysFixed : Base(key_columns) , key_sizes(std::move(key_sizes_)) , keys_size(key_columns.size()) + , total_rows(key_columns[0]->size()) { if (usePreparedKeys(key_sizes)) { @@ -593,9 +621,12 @@ struct HashMethodSerialized using Self = HashMethodSerialized<Value, Mapped>; using Base = columns_hashing_impl::HashMethodBase<Self, Value, Mapped, false>; + static constexpr bool is_serialized_key = true; + ColumnRawPtrs key_columns; size_t keys_size; TiDB::TiDBCollators collators; + const size_t total_rows; HashMethodSerialized( const ColumnRawPtrs & key_columns_, @@ -604,6 +635,7 @@ struct HashMethodSerialized : key_columns(key_columns_) , keys_size(key_columns_.size()) , collators(collators_) + , total_rows(key_columns_[0]->size()) {} ALWAYS_INLINE inline SerializedKeyHolder getKeyHolder( @@ -629,12 +661,16 @@ struct HashMethodHashed using Self = HashMethodHashed<Value, Mapped, use_cache>; using Base = columns_hashing_impl::HashMethodBase<Self, Value, Mapped, use_cache>; + static constexpr bool is_serialized_key = false; + ColumnRawPtrs key_columns; TiDB::TiDBCollators collators; + const size_t total_rows; HashMethodHashed(ColumnRawPtrs key_columns_, const Sizes &, const TiDB::TiDBCollators & collators_) : key_columns(std::move(key_columns_)) , collators(collators_) + , total_rows(key_columns[0]->size()) {} ALWAYS_INLINE inline Key getKeyHolder(size_t row, Arena *, std::vector<String> & sort_key_containers) const diff --git a/dbms/src/Common/ColumnsHashingImpl.h b/dbms/src/Common/ColumnsHashingImpl.h index d4f4143015d..fcbfc4bc358 100644 --- a/dbms/src/Common/ColumnsHashingImpl.h +++ b/dbms/src/Common/ColumnsHashingImpl.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -127,7 +128,17 @@ class HashMethodBase using FindResult = FindResultImpl<Mapped>; static constexpr bool has_mapped = !std::is_same<Mapped, void>::value; using Cache = LastElementCache<Value, consecutive_keys_optimization>; + static constexpr size_t prefetch_step = 16; + template <typename Map> + static ALWAYS_INLINE inline void prefetch(Map & map, size_t idx, const std::vector<size_t> & hashvals) + { + const auto prefetch_idx = idx + prefetch_step; + if likely (prefetch_idx < hashvals.size()) + map.prefetch(hashvals[prefetch_idx]); + } + + // Emplace a key without a precomputed hash value; this overload does not support prefetching. template <typename Data> ALWAYS_INLINE inline EmplaceResult emplaceKey( Data & data, @@ -150,14 +161,95 @@ return findKeyImpl(keyHolderGetKey(key_holder), data); } + // Emplace a key using a precomputed hash value; prefetching can be enabled via the template parameter.
+ template <bool enable_prefetch, typename Data> + ALWAYS_INLINE inline EmplaceResult emplaceKey( + Data & data, + size_t row, + Arena & pool, + std::vector<String> & sort_key_containers, + const std::vector<size_t> & hashvals) + { + auto key_holder = static_cast<Derived &>(*this).getKeyHolder(row, &pool, sort_key_containers); + if constexpr (enable_prefetch) + { + assert(hashvals.size() == static_cast<Derived &>(*this).total_rows); + prefetch(data, row, hashvals); + return emplaceImpl(key_holder, data, hashvals[row]); + } + else + { + return emplaceImpl(key_holder, data); + } + } + + template <bool enable_prefetch, typename Data> + ALWAYS_INLINE inline FindResult findKey( + Data & data, + size_t row, + Arena & pool, + std::vector<String> & sort_key_containers, + const std::vector<size_t> & hashvals) + { + auto key_holder = static_cast<Derived &>(*this).getKeyHolder(row, &pool, sort_key_containers); + if constexpr (enable_prefetch) + { + assert(hashvals.size() == static_cast<Derived &>(*this).total_rows); + prefetch(data, row, hashvals); + return findKeyImpl(keyHolderGetKey(key_holder), data, hashvals[row]); + } + else + { + return findKeyImpl(keyHolderGetKey(key_holder), data); + } + } + + template <size_t SubMapIndex, bool enable_prefetch, typename Data> + ALWAYS_INLINE inline EmplaceResult emplaceStringKey( + Data & data, + size_t idx, + std::vector<ArenaKeyHolder> & datas, + const std::vector<size_t> & hashvals) + { + // For spill, hashvals.size() will be less than or equal to total_rows, + // because only the remaining rows that were not inserted into the HashMap are handled here. + assert(hashvals.size() <= static_cast<Derived &>(*this).total_rows); + + auto & submap = StringHashTableSubMapSelector<SubMapIndex, Data::is_two_level, std::decay_t<Data>>::getSubMap( + hashvals[idx], + data); + if constexpr (enable_prefetch) + prefetch(submap, idx, hashvals); + + return emplaceImpl(datas[idx], submap, hashvals[idx]); + } + + template <size_t SubMapIndex, bool enable_prefetch, typename Data> + ALWAYS_INLINE inline FindResult findStringKey( + Data & data, + size_t idx, + std::vector<ArenaKeyHolder> & datas, + const std::vector<size_t> & hashvals) + { + assert(hashvals.size() <= static_cast<Derived &>(*this).total_rows); + + auto & submap = StringHashTableSubMapSelector<SubMapIndex, Data::is_two_level, std::decay_t<Data>>::getSubMap( + hashvals[idx], + data); + if constexpr (enable_prefetch) + prefetch(submap, idx, hashvals); + + return findKeyImpl(keyHolderGetKey(datas[idx]), submap, hashvals[idx]); + } + template <typename Data> ALWAYS_INLINE inline size_t getHash( const Data & data, size_t row, Arena & pool, - std::vector<String> & sort_key_containers) + std::vector<String> & sort_key_containers) const { - auto key_holder = static_cast<Derived &>(*this).getKeyHolder(row, &pool, sort_key_containers); + auto key_holder = static_cast<const Derived &>(*this).getKeyHolder(row, &pool, sort_key_containers); return data.hash(keyHolderGetKey(key_holder)); } @@ -179,99 +271,129 @@ class HashMethodBase } } +#define DEFINE_EMPLACE_IMPL_BEGIN \ + if constexpr (Cache::consecutive_keys_optimization) \ + { \ + if (cache.found && cache.check(keyHolderGetKey(key_holder))) \ + { \ + if constexpr (has_mapped) \ + return EmplaceResult(cache.value.second, cache.value.second, false); \ + else \ + return EmplaceResult(false); \ + } \ + } \ + typename Data::LookupResult it; \ + bool inserted = false; + +#define DEFINE_EMPLACE_IMPL_END \ + [[maybe_unused]] Mapped * cached = nullptr; \ + if constexpr (has_mapped) \ + cached = &it->getMapped(); \ + \ + if (inserted) \ + { \ + if constexpr (has_mapped) \ + { \ + new (&it->getMapped()) Mapped(); \ + } \ + } \ + \ + if constexpr (consecutive_keys_optimization) \ + { \ + cache.found = true; \ + cache.empty = false; \ + \ + if constexpr (has_mapped) \ + { \ + cache.value.first = it->getKey(); \ + cache.value.second = it->getMapped(); \ + cached = &cache.value.second; \ + } \ + else \ + { \ + cache.value = it->getKey(); \ + } \ + } \ + \ + if constexpr (has_mapped) \ + return
EmplaceResult(it->getMapped(), *cached, inserted); \ + else \ + return EmplaceResult(inserted); + + // This method is performance critical, so there are two emplaceImpl to make sure caller can use the one they need. template - ALWAYS_INLINE inline EmplaceResult emplaceImpl(KeyHolder & key_holder, Data & data) + ALWAYS_INLINE inline EmplaceResult emplaceImpl(KeyHolder & key_holder, Data & data, size_t hashval) { - if constexpr (Cache::consecutive_keys_optimization) - { - if (cache.found && cache.check(keyHolderGetKey(key_holder))) - { - if constexpr (has_mapped) - return EmplaceResult(cache.value.second, cache.value.second, false); - else - return EmplaceResult(false); - } - } + DEFINE_EMPLACE_IMPL_BEGIN + data.emplace(key_holder, it, inserted, hashval); + DEFINE_EMPLACE_IMPL_END + } - typename Data::LookupResult it; - bool inserted = false; + template + ALWAYS_INLINE inline EmplaceResult emplaceImpl(KeyHolder & key_holder, Data & data) + { + DEFINE_EMPLACE_IMPL_BEGIN data.emplace(key_holder, it, inserted); - - [[maybe_unused]] Mapped * cached = nullptr; - if constexpr (has_mapped) - cached = &it->getMapped(); - - if (inserted) - { - if constexpr (has_mapped) - { - new (&it->getMapped()) Mapped(); - } - } - - if constexpr (consecutive_keys_optimization) - { - cache.found = true; - cache.empty = false; - - if constexpr (has_mapped) - { - cache.value.first = it->getKey(); - cache.value.second = it->getMapped(); - cached = &cache.value.second; - } - else - { - cache.value = it->getKey(); - } - } - - if constexpr (has_mapped) - return EmplaceResult(it->getMapped(), *cached, inserted); - else - return EmplaceResult(inserted); + DEFINE_EMPLACE_IMPL_END } +#undef DEFINE_EMPLACE_IMPL_BEGIN +#undef DEFINE_EMPLACE_IMPL_END + +#define DEFINE_FIND_IMPL_BEGIN \ + if constexpr (Cache::consecutive_keys_optimization) \ + { \ + if (cache.check(key)) \ + { \ + if constexpr (has_mapped) \ + return FindResult(&cache.value.second, cache.found); \ + else \ + return FindResult(cache.found); \ + } \ + } \ + typename Data::LookupResult it; + +#define DEFINE_FIND_IMPL_END \ + if constexpr (consecutive_keys_optimization) \ + { \ + cache.found = it != nullptr; \ + cache.empty = false; \ + \ + if constexpr (has_mapped) \ + { \ + cache.value.first = key; \ + if (it) \ + { \ + cache.value.second = it->getMapped(); \ + } \ + } \ + else \ + { \ + cache.value = key; \ + } \ + } \ + \ + if constexpr (has_mapped) \ + return FindResult(it ? &it->getMapped() : nullptr, it != nullptr); \ + else \ + return FindResult(it != nullptr); template - ALWAYS_INLINE inline FindResult findKeyImpl(Key key, Data & data) + ALWAYS_INLINE inline FindResult findKeyImpl(Key & key, Data & data) { - if constexpr (Cache::consecutive_keys_optimization) - { - if (cache.check(key)) - { - if constexpr (has_mapped) - return FindResult(&cache.value.second, cache.found); - else - return FindResult(cache.found); - } - } - - auto it = data.find(key); - - if constexpr (consecutive_keys_optimization) - { - cache.found = it != nullptr; - cache.empty = false; - - if constexpr (has_mapped) - { - cache.value.first = key; - if (it) - { - cache.value.second = it->getMapped(); - } - } - else - { - cache.value = key; - } - } + DEFINE_FIND_IMPL_BEGIN + it = data.find(key); + DEFINE_FIND_IMPL_END + } - if constexpr (has_mapped) - return FindResult(it ? 
&it->getMapped() : nullptr, it != nullptr); - else - return FindResult(it != nullptr); + template + ALWAYS_INLINE inline FindResult findKeyImpl(Key & key, Data & data, size_t hashval) + { + DEFINE_FIND_IMPL_BEGIN + it = data.find(key, hashval); + DEFINE_FIND_IMPL_END } +#undef DEFINE_FIND_IMPL_BEGIN +#undef DEFINE_FIND_IMPL_END }; diff --git a/dbms/src/Common/FailPoint.cpp b/dbms/src/Common/FailPoint.cpp index be684cf3751..49f7f97f5fc 100644 --- a/dbms/src/Common/FailPoint.cpp +++ b/dbms/src/Common/FailPoint.cpp @@ -114,6 +114,7 @@ namespace DB M(force_set_parallel_prehandle_threshold) \ M(force_raise_prehandle_exception) \ M(force_agg_on_partial_block) \ + M(force_agg_prefetch) \ M(force_set_fap_candidate_store_id) \ M(force_not_clean_fap_on_destroy) \ M(force_fap_worker_throw) \ diff --git a/dbms/src/Common/HashTable/FixedHashTable.h b/dbms/src/Common/HashTable/FixedHashTable.h index 259e90684fc..8b0b721aa8c 100644 --- a/dbms/src/Common/HashTable/FixedHashTable.h +++ b/dbms/src/Common/HashTable/FixedHashTable.h @@ -212,7 +212,6 @@ class FixedHashTable typename cell_type::CellExt cell; }; - public: using key_type = Key; using mapped_type = typename Cell::mapped_type; @@ -222,6 +221,8 @@ class FixedHashTable using LookupResult = Cell *; using ConstLookupResult = const Cell *; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = false; size_t hash(const Key & x) const { return x; } @@ -352,6 +353,8 @@ class FixedHashTable iterator end() { return iterator(this, buf ? buf + NUM_CELLS : buf); } + inline void prefetch(size_t) {} + /// The last parameter is unused but exists for compatibility with HashTable interface. void ALWAYS_INLINE emplace(const Key & x, LookupResult & it, bool & inserted, size_t /* hash */ = 0) { diff --git a/dbms/src/Common/HashTable/Hash.h b/dbms/src/Common/HashTable/Hash.h index b4f5d2c0a04..207919a347e 100644 --- a/dbms/src/Common/HashTable/Hash.h +++ b/dbms/src/Common/HashTable/Hash.h @@ -130,8 +130,8 @@ inline DB::UInt64 wideIntHashCRC32(const T & x, DB::UInt64 updated_value) return updated_value; } static_assert( - DB::IsDecimal< - T> || is_boost_number_v || std::is_same_v || std::is_same_v || std::is_same_v); + DB::IsDecimal || is_boost_number_v || std::is_same_v || std::is_same_v + || std::is_same_v); __builtin_unreachable(); } @@ -244,8 +244,8 @@ inline size_t defaultHash64(const std::enable_if_t, T> & key return boost::multiprecision::hash_value(key); } static_assert( - is_boost_number_v< - T> || std::is_same_v || std::is_same_v || std::is_same_v); + is_boost_number_v || std::is_same_v || std::is_same_v + || std::is_same_v); __builtin_unreachable(); } @@ -297,20 +297,26 @@ inline size_t hashCRC32(const std::enable_if_t, T> & key) template struct HashCRC32; -#define DEFINE_HASH(T) \ - template <> \ - struct HashCRC32 \ - { \ - static_assert(is_fit_register); \ - size_t operator()(T key) const { return hashCRC32(key); } \ +#define DEFINE_HASH(T) \ + template <> \ + struct HashCRC32 \ + { \ + static_assert(is_fit_register); \ + size_t operator()(T key) const \ + { \ + return hashCRC32(key); \ + } \ }; -#define DEFINE_HASH_WIDE(T) \ - template <> \ - struct HashCRC32 \ - { \ - static_assert(!is_fit_register); \ - size_t operator()(const T & key) const { return hashCRC32(key); } \ +#define DEFINE_HASH_WIDE(T) \ + template <> \ + struct HashCRC32 \ + { \ + static_assert(!is_fit_register); \ + size_t operator()(const T & key) const \ + { \ + return hashCRC32(key); \ + } \ }; DEFINE_HASH(DB::UInt8) @@ -416,3 +422,128 @@ struct 
IntHash32, void>> } } }; + +inline uint64_t umul128(uint64_t v, uint64_t kmul, uint64_t * high) +{ + DB::Int128 res = static_cast(v) * static_cast(kmul); + *high = static_cast(res >> 64); + return static_cast(res); +} + +template +inline void hash_combine(uint64_t & seed, const T & val) +{ + // from: https://github.com/HowardHinnant/hash_append/issues/7#issuecomment-629414712 + seed ^= std::hash{}(val) + 0x9e3779b97f4a7c15LLU + (seed << 12) + (seed >> 4); +} + +inline uint64_t hash_int128(uint64_t seed, const DB::Int128 & v) +{ + auto low = static_cast(v); + auto high = static_cast(v >> 64); + hash_combine(seed, low); + hash_combine(seed, high); + return seed; +} + +inline uint64_t hash_uint128(uint64_t seed, const DB::UInt128 & v) +{ + hash_combine(seed, v.low); + hash_combine(seed, v.high); + return seed; +} + +inline uint64_t hash_int256(uint64_t seed, const DB::Int256 & v) +{ + const auto & backend_value = v.backend(); + for (size_t i = 0; i < backend_value.size(); ++i) + { + hash_combine(seed, backend_value.limbs()[i]); + } + return seed; +} + +inline uint64_t hash_uint256(uint64_t seed, const DB::UInt256 & v) +{ + hash_combine(seed, v.a); + hash_combine(seed, v.b); + hash_combine(seed, v.c); + hash_combine(seed, v.d); + return seed; +} + +template +struct HashWithMixSeedHelper +{ + static inline size_t operator()(size_t); +}; + +template <> +struct HashWithMixSeedHelper<4> +{ + static inline size_t operator()(size_t v) + { + // from: https://github.com/aappleby/smhasher/blob/0ff96f7835817a27d0487325b6c16033e2992eb5/src/MurmurHash3.cpp#L102 + static constexpr uint64_t kmul = 0xcc9e2d51UL; + uint64_t mul = v * kmul; + return static_cast(mul ^ (mul >> 32u)); + } +}; + +template <> +struct HashWithMixSeedHelper<8> +{ + static inline size_t operator()(size_t v) + { + // from: https://github.com/martinus/robin-hood-hashing/blob/b21730713f4b5296bec411917c46919f7b38b178/src/include/robin_hood.h#L735 + static constexpr uint64_t kmul = 0xde5fb9d2630458e9ULL; + uint64_t high = 0; + uint64_t low = umul128(v, kmul, &high); + return static_cast(high + low); + } +}; + +template +struct HashWithMixSeed +{ + static size_t operator()(const T & v) + { + return HashWithMixSeedHelper::operator()(std::hash()(v)); + } +}; + +template <> +struct HashWithMixSeed +{ + static size_t operator()(const DB::Int128 & v) + { + return HashWithMixSeedHelper::operator()(hash_int128(0, v)); + } +}; + +template <> +struct HashWithMixSeed +{ + static inline size_t operator()(const DB::UInt128 & v) + { + return HashWithMixSeedHelper::operator()(hash_uint128(0, v)); + } +}; + +template <> +struct HashWithMixSeed +{ + static inline size_t operator()(const DB::Int256 & v) + { + return HashWithMixSeedHelper::operator()(hash_int256(0, v)); + } +}; + +template <> +struct HashWithMixSeed +{ + static inline size_t operator()(const DB::UInt256 & v) + { + return HashWithMixSeedHelper::operator()(hash_uint256(0, v)); + } +}; diff --git a/dbms/src/Common/HashTable/HashMap.h b/dbms/src/Common/HashTable/HashMap.h index fba8e957ac4..3e7f2b7226a 100644 --- a/dbms/src/Common/HashTable/HashMap.h +++ b/dbms/src/Common/HashTable/HashMap.h @@ -59,7 +59,7 @@ PairNoInit, std::decay_t> makePairNoInit(First && fi } -template +template struct HashMapCell { using Mapped = TMapped; @@ -96,7 +96,11 @@ struct HashMapCell } void setHash(size_t /*hash_value*/) {} - size_t getHash(const Hash & hash) const { return hash(value.first); } + template + size_t getHash(const THash & hash) const + { + return hash(value.first); + } bool isZero(const State & state) 
const { return isZero(value.first, state); } static bool isZero(const Key & key, const State & /*state*/) { return ZeroTraits::check(key); } @@ -171,28 +175,28 @@ struct HashMapCell namespace std { -template -struct tuple_size> : std::integral_constant +template +struct tuple_size> : std::integral_constant { }; -template -struct tuple_element<0, HashMapCell> +template +struct tuple_element<0, HashMapCell> { using type = Key; }; -template -struct tuple_element<1, HashMapCell> +template +struct tuple_element<1, HashMapCell> { using type = TMapped; }; } // namespace std -template -struct HashMapCellWithSavedHash : public HashMapCell +template +struct HashMapCellWithSavedHash : public HashMapCell { - using Base = HashMapCell; + using Base = HashMapCell; size_t saved_hash; @@ -209,7 +213,11 @@ struct HashMapCellWithSavedHash : public HashMapCell } void setHash(size_t hash_value) { saved_hash = hash_value; } - size_t getHash(const Hash & /*hash_function*/) const { return saved_hash; } + template + size_t getHash(const THash & /*hash_function*/) const + { + return saved_hash; + } }; @@ -318,19 +326,19 @@ class HashMapTable : public HashTable -struct tuple_size> : std::integral_constant +template +struct tuple_size> : std::integral_constant { }; -template -struct tuple_element<0, HashMapCellWithSavedHash> +template +struct tuple_element<0, HashMapCellWithSavedHash> { using type = Key; }; -template -struct tuple_element<1, HashMapCellWithSavedHash> +template +struct tuple_element<1, HashMapCellWithSavedHash> { using type = TMapped; }; @@ -343,7 +351,7 @@ template < typename Hash = DefaultHash, typename Grower = HashTableGrower<>, typename Allocator = HashTableAllocator> -using HashMap = HashMapTable, Hash, Grower, Allocator>; +using HashMap = HashMapTable, Hash, Grower, Allocator>; template < @@ -352,16 +360,15 @@ template < typename Hash = DefaultHash, typename Grower = HashTableGrower<>, typename Allocator = HashTableAllocator> -using HashMapWithSavedHash = HashMapTable, Hash, Grower, Allocator>; +using HashMapWithSavedHash = HashMapTable, Hash, Grower, Allocator>; template using HashMapWithStackMemory = HashMapTable< Key, - HashMapCellWithSavedHash, + HashMapCellWithSavedHash, Hash, HashTableGrower, - HashTableAllocatorWithStackMemory< - (1ULL << initial_size_degree) * sizeof(HashMapCellWithSavedHash)>>; + HashTableAllocatorWithStackMemory<(1ULL << initial_size_degree) * sizeof(HashMapCellWithSavedHash)>>; /// ConcurrentHashTable is the base class, it contains a vector of HashTableWithLock, ConcurrentHashMapTable is a derived /// class from ConcurrentHashTable, it makes hash table to be a hash map, and ConcurrentHashMap/ConcurrentHashMapWithSavedHash @@ -417,7 +424,7 @@ template < typename Hash = DefaultHash, typename Grower = HashTableGrower<>, typename Allocator = HashTableAllocator> -using ConcurrentHashMap = ConcurrentHashMapTable, Hash, Grower, Allocator>; +using ConcurrentHashMap = ConcurrentHashMapTable, Hash, Grower, Allocator>; template < @@ -427,4 +434,4 @@ template < typename Grower = HashTableGrower<>, typename Allocator = HashTableAllocator> using ConcurrentHashMapWithSavedHash - = ConcurrentHashMapTable, Hash, Grower, Allocator>; + = ConcurrentHashMapTable, Hash, Grower, Allocator>; diff --git a/dbms/src/Common/HashTable/HashTable.h b/dbms/src/Common/HashTable/HashTable.h index a4f0fe3be03..046d3ba37cf 100644 --- a/dbms/src/Common/HashTable/HashTable.h +++ b/dbms/src/Common/HashTable/HashTable.h @@ -390,7 +390,7 @@ struct AllocatorBufferDeleter template class HashTable : private 
boost::noncopyable - , protected HashType + , public HashType , protected AllocatorType , protected CellType::State , protected ZeroValueStorage /// empty base optimization @@ -402,6 +402,9 @@ class HashTable using Grower = GrowerType; using Allocator = AllocatorType; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = false; + protected: friend class const_iterator; friend class iterator; @@ -851,6 +854,11 @@ class HashTable iterator end() { return iterator(this, buf ? buf + grower.bufSize() : buf); } + void ALWAYS_INLINE prefetch(size_t hashval) const + { + const size_t place_value = grower.place(hashval); + __builtin_prefetch(static_cast(&buf[place_value])); + } protected: const_iterator iteratorTo(const Cell * ptr) const { return const_iterator(this, ptr); } diff --git a/dbms/src/Common/HashTable/HashTableKeyHolder.h b/dbms/src/Common/HashTable/HashTableKeyHolder.h index 01b06dce87d..dd8a4b53376 100644 --- a/dbms/src/Common/HashTable/HashTableKeyHolder.h +++ b/dbms/src/Common/HashTable/HashTableKeyHolder.h @@ -91,8 +91,8 @@ namespace DB */ struct ArenaKeyHolder { - StringRef key; - Arena & pool; + StringRef key{}; + Arena * pool = nullptr; }; } // namespace DB @@ -111,14 +111,14 @@ inline void ALWAYS_INLINE keyHolderPersistKey(DB::ArenaKeyHolder & holder) { // Hash table shouldn't ask us to persist a zero key assert(holder.key.size > 0); - holder.key.data = holder.pool.insert(holder.key.data, holder.key.size); + holder.key.data = holder.pool->insert(holder.key.data, holder.key.size); } inline void ALWAYS_INLINE keyHolderPersistKey(DB::ArenaKeyHolder && holder) { // Hash table shouldn't ask us to persist a zero key assert(holder.key.size > 0); - holder.key.data = holder.pool.insert(holder.key.data, holder.key.size); + holder.key.data = holder.pool->insert(holder.key.data, holder.key.size); } inline void ALWAYS_INLINE keyHolderDiscardKey(DB::ArenaKeyHolder &) {} diff --git a/dbms/src/Common/HashTable/SmallTable.h b/dbms/src/Common/HashTable/SmallTable.h index fa40b479430..1292a4205da 100644 --- a/dbms/src/Common/HashTable/SmallTable.h +++ b/dbms/src/Common/HashTable/SmallTable.h @@ -85,6 +85,9 @@ class SmallTable using value_type = typename Cell::value_type; using cell_type = Cell; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = false; + class Reader final : private Cell::State { public: @@ -296,6 +299,7 @@ class SmallTable iterator ALWAYS_INLINE find(Key x) { return iteratorTo(findCell(x)); } const_iterator ALWAYS_INLINE find(Key x) const { return iteratorTo(findCell(x)); } + void ALWAYS_INLINE prefetch(size_t) {} void write(DB::WriteBuffer & wb) const { diff --git a/dbms/src/Common/HashTable/StringHashMap.h b/dbms/src/Common/HashTable/StringHashMap.h index 6f7e668e1d9..4e05183d691 100644 --- a/dbms/src/Common/HashTable/StringHashMap.h +++ b/dbms/src/Common/HashTable/StringHashMap.h @@ -19,9 +19,9 @@ #include template -struct StringHashMapCell : public HashMapCell +struct StringHashMapCell : public HashMapCell { - using Base = HashMapCell; + using Base = HashMapCell; using value_type = typename Base::value_type; using Base::Base; static constexpr bool need_zero_value_storage = false; @@ -32,10 +32,9 @@ struct StringHashMapCell : public HashMapCell -struct StringHashMapCell - : public HashMapCell +struct StringHashMapCell : public HashMapCell { - using Base = HashMapCell; + using Base = HashMapCell; using value_type = typename Base::value_type; using Base::Base; static constexpr bool need_zero_value_storage 
= false; @@ -53,10 +52,9 @@ struct StringHashMapCell }; template -struct StringHashMapCell - : public HashMapCell +struct StringHashMapCell : public HashMapCell { - using Base = HashMapCell; + using Base = HashMapCell; using value_type = typename Base::value_type; using Base::Base; static constexpr bool need_zero_value_storage = false; @@ -74,10 +72,9 @@ struct StringHashMapCell }; template -struct StringHashMapCell - : public HashMapCellWithSavedHash +struct StringHashMapCell : public HashMapCellWithSavedHash { - using Base = HashMapCellWithSavedHash; + using Base = HashMapCellWithSavedHash; using value_type = typename Base::value_type; using Base::Base; static constexpr bool need_zero_value_storage = false; @@ -87,42 +84,42 @@ struct StringHashMapCell static const StringRef & getKey(const value_type & value_) { return value_.first; } }; -template +template struct StringHashMapSubMaps { using T0 = StringHashTableEmpty>; using T1 = HashMapTable< StringKey8, StringHashMapCell, - StringHashTableHash, + typename HashSelector::StringKey8Hash, StringHashTableGrower<>, Allocator>; using T2 = HashMapTable< StringKey16, StringHashMapCell, - StringHashTableHash, + typename HashSelector::StringKey16Hash, StringHashTableGrower<>, Allocator>; using T3 = HashMapTable< StringKey24, StringHashMapCell, - StringHashTableHash, + typename HashSelector::StringKey24Hash, StringHashTableGrower<>, Allocator>; using Ts = HashMapTable< StringRef, StringHashMapCell, - StringHashTableHash, + typename HashSelector::StringStrHash, StringHashTableGrower<>, Allocator>; }; -template -class StringHashMap : public StringHashTable> +template +class StringHashMap : public StringHashTable> { public: using Key = StringRef; - using Base = StringHashTable>; + using Base = StringHashTable>; using Self = StringHashMap; using LookupResult = typename Base::LookupResult; diff --git a/dbms/src/Common/HashTable/StringHashTable.h b/dbms/src/Common/HashTable/StringHashTable.h index aa4825f171a..8ab2f5764f6 100644 --- a/dbms/src/Common/HashTable/StringHashTable.h +++ b/dbms/src/Common/HashTable/StringHashTable.h @@ -16,11 +16,16 @@ #include #include +#include +#include +#include #include #include - +struct StringKey0 +{ +}; using StringKey8 = UInt64; using StringKey16 = DB::UInt128; struct StringKey24 @@ -48,23 +53,53 @@ inline StringRef ALWAYS_INLINE toStringRef(const StringKey24 & n) return {reinterpret_cast(&n), 24ul - (__builtin_clzll(n.c) >> 3)}; } -struct StringHashTableHash +inline size_t hash_string_key_24(uint64_t seed, const StringKey24 & v) { + hash_combine(seed, v.a); + hash_combine(seed, v.b); + hash_combine(seed, v.c); + return seed; +} + +template <> +struct HashWithMixSeed +{ + static inline size_t operator()(const StringKey24 & v) + { + return HashWithMixSeedHelper::operator()(hash_string_key_24(0, v)); + } +}; + +struct StringKey0Hash +{ + static size_t ALWAYS_INLINE operator()(StringKey0) { return 0; } +}; + #if defined(__SSE4_2__) - size_t ALWAYS_INLINE operator()(StringKey8 key) const +struct StringKey8Hash +{ + static size_t ALWAYS_INLINE operator()(StringKey8 key) { size_t res = -1ULL; res = _mm_crc32_u64(res, key); return res; } - size_t ALWAYS_INLINE operator()(const StringKey16 & key) const +}; + +struct StringKey16Hash +{ + static size_t ALWAYS_INLINE operator()(const StringKey16 & key) { size_t res = -1ULL; res = _mm_crc32_u64(res, key.low); res = _mm_crc32_u64(res, key.high); return res; } - size_t ALWAYS_INLINE operator()(const StringKey24 & key) const +}; + +struct StringKey24Hash +{ + static size_t 
ALWAYS_INLINE operator()(const StringKey24 & key) { size_t res = -1ULL; res = _mm_crc32_u64(res, key.a); @@ -72,21 +107,60 @@ struct StringHashTableHash res = _mm_crc32_u64(res, key.c); return res; } +}; #else - size_t ALWAYS_INLINE operator()(StringKey8 key) const +struct StringKey8Hash +{ + static size_t ALWAYS_INLINE operator()(StringKey8 key) { return CityHash_v1_0_2::CityHash64(reinterpret_cast(&key), 8); } - size_t ALWAYS_INLINE operator()(const StringKey16 & key) const +}; + +struct StringKey16Hash +{ + static size_t ALWAYS_INLINE operator()(const StringKey16 & key) { return CityHash_v1_0_2::CityHash64(reinterpret_cast(&key), 16); } - size_t ALWAYS_INLINE operator()(const StringKey24 & key) const +}; + +struct StringKey24Hash +{ + static size_t ALWAYS_INLINE operator()(const StringKey24 & key) { return CityHash_v1_0_2::CityHash64(reinterpret_cast(&key), 24); } +}; #endif - size_t ALWAYS_INLINE operator()(StringRef key) const { return StringRefHash()(key); } +struct StringStrHash +{ + static size_t ALWAYS_INLINE operator()(StringRef key) { return StringRefHash()(key); } +}; + +template +struct StringHashTableHashSelector; + +template <> +struct StringHashTableHashSelector +{ + using StringKey0Hash = StringKey0Hash; + using StringKey8Hash = HashWithMixSeed; + using StringKey16Hash = HashWithMixSeed; + using StringKey24Hash = HashWithMixSeed; + + + using StringStrHash = StringStrHash; +}; + +template <> +struct StringHashTableHashSelector +{ + using StringKey0Hash = StringKey0Hash; + using StringKey8Hash = StringKey8Hash; + using StringKey16Hash = StringKey16Hash; + using StringKey24Hash = StringKey24Hash; + using StringStrHash = StringStrHash; }; template @@ -98,6 +172,8 @@ struct StringHashTableEmpty //-V730 std::aligned_storage_t zero_value_storage; /// Storage of element with zero key. public: + using Hash = StringKey0Hash; + bool hasZero() const { return has_zero; } void setHasZero() @@ -150,6 +226,7 @@ struct StringHashTableEmpty //-V730 return hasZero() ? zeroValue() : nullptr; } + ALWAYS_INLINE inline void prefetch(size_t) {} void write(DB::WriteBuffer & wb) const { zeroValue()->write(wb); } void writeText(DB::WriteBuffer & wb) const { zeroValue()->writeText(wb); } void read(DB::ReadBuffer & rb) { zeroValue()->read(rb); } @@ -157,6 +234,7 @@ struct StringHashTableEmpty //-V730 size_t size() const { return hasZero() ? 1 : 0; } bool empty() const { return !hasZero(); } size_t getBufferSizeInBytes() const { return sizeof(Cell); } + size_t getBufferSizeInCells() const { return 1; } void setResizeCallback(const ResizeCallback &) {} size_t getCollisions() const { return 0; } }; @@ -190,25 +268,121 @@ struct StringHashTableLookupResult friend bool operator!=(const std::nullptr_t &, const StringHashTableLookupResult & b) { return b.mapped_ptr; } }; -template +template +static auto +#if defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) + NO_INLINE NO_SANITIZE_ADDRESS NO_SANITIZE_THREAD +#else + ALWAYS_INLINE +#endif + dispatchStringHashTable( + size_t row, + KeyHolder && key_holder, + Func0 && func0, + Func8 && func8, + Func16 && func16, + Func24 && func24, + FuncStr && func_str) +{ + const StringRef & x = keyHolderGetKey(key_holder); + const size_t sz = x.size; + if (sz == 0) + { + return func0(StringKey0{}, row); + } + + if (x.data[sz - 1] == 0) + { + // Strings with trailing zeros are not representable as fixed-size + // string keys. Put them to the generic table. 
+ return func_str(key_holder, row); + } + + const char * p = x.data; + // pending bits that needs to be shifted out + const char s = (-sz & 7) * 8; + union + { + StringKey8 k8; + StringKey16 k16; + StringKey24 k24; + UInt64 n[3]; + }; + switch ((sz - 1) >> 3) + { + case 0: // 1..8 bytes + { + // first half page + if ((reinterpret_cast(p) & 2048) == 0) + { + memcpy(&n[0], p, 8); + if constexpr (DB::isLittleEndian()) + n[0] &= (-1ULL >> s); + else + n[0] &= (-1ULL << s); + } + else + { + const char * lp = x.data + x.size - 8; + memcpy(&n[0], lp, 8); + if constexpr (DB::isLittleEndian()) + n[0] >>= s; + else + n[0] <<= s; + } + return func8(k8, row); + } + case 1: // 9..16 bytes + { + memcpy(&n[0], p, 8); + const char * lp = x.data + x.size - 8; + memcpy(&n[1], lp, 8); + if constexpr (DB::isLittleEndian()) + n[1] >>= s; + else + n[1] <<= s; + return func16(k16, row); + } + case 2: // 17..24 bytes + { + memcpy(&n[0], p, 16); + const char * lp = x.data + x.size - 8; + memcpy(&n[2], lp, 8); + if constexpr (DB::isLittleEndian()) + n[2] >>= s; + else + n[2] <<= s; + return func24(k24, row); + } + default: // >= 25 bytes + { + return func_str(key_holder, row); + } + } +} + +template class StringHashTable : private boost::noncopyable { protected: static constexpr size_t NUM_MAPS = 5; + using Self = StringHashTable; + // Map for storing empty string - using T0 = typename SubMaps::T0; + using T0 = typename TSubMaps::T0; // Short strings are stored as numbers - using T1 = typename SubMaps::T1; - using T2 = typename SubMaps::T2; - using T3 = typename SubMaps::T3; + using T1 = typename TSubMaps::T1; + using T2 = typename TSubMaps::T2; + using T3 = typename TSubMaps::T3; // Long strings are stored as StringRef along with saved hash - using Ts = typename SubMaps::Ts; - using Self = StringHashTable; + using Ts = typename TSubMaps::Ts; template friend class TwoLevelStringHashTable; + template + friend struct StringHashTableSubMapSelector; T0 m0; T1 m1; @@ -222,10 +396,14 @@ class StringHashTable : private boost::noncopyable using mapped_type = typename Ts::mapped_type; using value_type = typename Ts::value_type; using cell_type = typename Ts::cell_type; + using SubMaps = TSubMaps; using LookupResult = StringHashTableLookupResult; using ConstLookupResult = StringHashTableLookupResult; + static constexpr bool is_string_hash_map = true; + static constexpr bool is_two_level = false; + StringHashTable() = default; explicit StringHashTable(size_t reserve_for_num_elements) @@ -257,7 +435,6 @@ class StringHashTable : private boost::noncopyable #endif dispatch(Self & self, KeyHolder && key_holder, Func && func) { - StringHashTableHash hash; const StringRef & x = keyHolderGetKey(key_holder); const size_t sz = x.size; if (sz == 0) @@ -270,7 +447,7 @@ class StringHashTable : private boost::noncopyable { // Strings with trailing zeros are not representable as fixed-size // string keys. Put them to the generic table. 
- return func(self.ms, std::forward(key_holder), hash(x)); + return func(self.ms, std::forward(key_holder), SubMaps::Ts::Hash::operator()(x)); } const char * p = x.data; @@ -306,7 +483,7 @@ class StringHashTable : private boost::noncopyable n[0] <<= s; } keyHolderDiscardKey(key_holder); - return func(self.m1, k8, hash(k8)); + return func(self.m1, k8, SubMaps::T1::Hash::operator()(k8)); } case 1: // 9..16 bytes { @@ -318,7 +495,7 @@ class StringHashTable : private boost::noncopyable else n[1] <<= s; keyHolderDiscardKey(key_holder); - return func(self.m2, k16, hash(k16)); + return func(self.m2, k16, SubMaps::T2::Hash::operator()(k16)); } case 2: // 17..24 bytes { @@ -330,11 +507,11 @@ class StringHashTable : private boost::noncopyable else n[2] <<= s; keyHolderDiscardKey(key_holder); - return func(self.m3, k24, hash(k24)); + return func(self.m3, k24, SubMaps::T3::Hash::operator()(k24)); } default: // >= 25 bytes { - return func(self.ms, std::forward(key_holder), hash(x)); + return func(self.ms, std::forward(key_holder), SubMaps::Ts::Hash::operator()(x)); } } } @@ -434,6 +611,11 @@ class StringHashTable : private boost::noncopyable bool empty() const { return m0.empty() && m1.empty() && m2.empty() && m3.empty() && ms.empty(); } + size_t getBufferSizeInCells() const + { + return m0.getBufferSizeInCells() + m1.getBufferSizeInCells() + m2.getBufferSizeInCells() + + m3.getBufferSizeInCells() + ms.getBufferSizeInCells(); + } size_t getBufferSizeInBytes() const { return m0.getBufferSizeInBytes() + m1.getBufferSizeInBytes() + m2.getBufferSizeInBytes() @@ -458,3 +640,46 @@ class StringHashTable : private boost::noncopyable ms.clearAndShrink(); } }; + +template +struct StringHashTableSubMapSelector; + +template +struct StringHashTableSubMapSelector<0, false, Data> +{ + using Hash = typename Data::SubMaps::T0::Hash; + + static typename Data::T0 & getSubMap(size_t, Data & data) { return data.m0; } +}; + +template +struct StringHashTableSubMapSelector<1, false, Data> +{ + using Hash = typename Data::SubMaps::T1::Hash; + + static typename Data::T1 & getSubMap(size_t, Data & data) { return data.m1; } +}; + +template +struct StringHashTableSubMapSelector<2, false, Data> +{ + using Hash = typename Data::SubMaps::T2::Hash; + + static typename Data::T2 & getSubMap(size_t, Data & data) { return data.m2; } +}; + +template +struct StringHashTableSubMapSelector<3, false, Data> +{ + using Hash = typename Data::SubMaps::T3::Hash; + + static typename Data::T3 & getSubMap(size_t, Data & data) { return data.m3; } +}; + +template +struct StringHashTableSubMapSelector<4, false, Data> +{ + using Hash = typename Data::SubMaps::Ts::Hash; + + static typename Data::Ts & getSubMap(size_t, Data & data) { return data.ms; } +}; diff --git a/dbms/src/Common/HashTable/TwoLevelHashMap.h b/dbms/src/Common/HashTable/TwoLevelHashMap.h index e79532f4733..a7da17e67e5 100644 --- a/dbms/src/Common/HashTable/TwoLevelHashMap.h +++ b/dbms/src/Common/HashTable/TwoLevelHashMap.h @@ -63,7 +63,7 @@ template < typename Grower = TwoLevelHashTableGrower<>, typename Allocator = HashTableAllocator, template typename ImplTable = HashMapTable> -using TwoLevelHashMap = TwoLevelHashMapTable, Hash, Grower, Allocator, ImplTable>; +using TwoLevelHashMap = TwoLevelHashMapTable, Hash, Grower, Allocator, ImplTable>; template < @@ -74,4 +74,4 @@ template < typename Allocator = HashTableAllocator, template typename ImplTable = HashMapTable> using TwoLevelHashMapWithSavedHash - = TwoLevelHashMapTable, Hash, Grower, Allocator, ImplTable>; + = TwoLevelHashMapTable, 
Hash, Grower, Allocator, ImplTable>; diff --git a/dbms/src/Common/HashTable/TwoLevelHashTable.h b/dbms/src/Common/HashTable/TwoLevelHashTable.h index 6778cd4a3e8..da140938c0e 100644 --- a/dbms/src/Common/HashTable/TwoLevelHashTable.h +++ b/dbms/src/Common/HashTable/TwoLevelHashTable.h @@ -60,6 +60,9 @@ class TwoLevelHashTable : private boost::noncopyable static constexpr size_t NUM_BUCKETS = 1ULL << BITS_FOR_BUCKET; static constexpr size_t MAX_BUCKET = NUM_BUCKETS - 1; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = true; + size_t hash(const Key & x) const { return Hash::operator()(x); } /// NOTE Bad for hash tables with more than 2^32 cells. @@ -112,9 +115,9 @@ class TwoLevelHashTable : private boost::noncopyable /// Copy the data from another (normal) hash table. It should have the same hash function. template - explicit TwoLevelHashTable(const Source & src) + explicit TwoLevelHashTable(Source & src) { - typename Source::const_iterator it = src.begin(); + typename Source::iterator it = src.begin(); /// It is assumed that the zero key (stored separately) is first in iteration order. if (it != src.end() && it.getPtr()->isZero(src)) @@ -125,10 +128,21 @@ class TwoLevelHashTable : private boost::noncopyable for (; it != src.end(); ++it) { - const Cell * cell = it.getPtr(); - size_t hash_value = cell->getHash(src); - size_t buck = getBucketFromHash(hash_value); - impls[buck].insertUniqueNonZero(cell, hash_value); + if constexpr (std::is_same_v) + { + const Cell * cell = it.getPtr(); + size_t hash_value = cell->getHash(src); + size_t buck = getBucketFromHash(hash_value); + impls[buck].insertUniqueNonZero(cell, hash_value); + } + else + { + auto * cell = it.getPtr(); + size_t hash_value = Hash::operator()(cell->getKey()); + cell->setHash(hash_value); + size_t buck = getBucketFromHash(hash_value); + impls[buck].insertUniqueNonZero(cell, hash_value); + } } } @@ -285,6 +299,12 @@ class TwoLevelHashTable : private boost::noncopyable impls[buck].emplace(key_holder, it, inserted, hash_value); } + void ALWAYS_INLINE prefetch(size_t hashval) const + { + size_t buck = getBucketFromHash(hashval); + impls[buck].prefetch(hashval); + } + LookupResult ALWAYS_INLINE find(Key x, size_t hash_value) { size_t buck = getBucketFromHash(hash_value); @@ -352,6 +372,13 @@ class TwoLevelHashTable : private boost::noncopyable return true; } + size_t getBufferSizeInCells() const + { + size_t res = 0; + for (const auto & impl : impls) + res += impl.getBufferSizeInCells(); + return res; + } size_t getBufferSizeInBytes() const { size_t res = 0; diff --git a/dbms/src/Common/HashTable/TwoLevelStringHashMap.h b/dbms/src/Common/HashTable/TwoLevelStringHashMap.h index 750e7efc415..7f38246b610 100644 --- a/dbms/src/Common/HashTable/TwoLevelStringHashMap.h +++ b/dbms/src/Common/HashTable/TwoLevelStringHashMap.h @@ -19,15 +19,20 @@ template < typename TMapped, + typename HashSelector, typename Allocator = HashTableAllocator, template typename ImplTable = StringHashMap> class TwoLevelStringHashMap - : public TwoLevelStringHashTable, ImplTable> + : public TwoLevelStringHashTable< + StringHashMapSubMaps, + ImplTable> { public: using Key = StringRef; using Self = TwoLevelStringHashMap; - using Base = TwoLevelStringHashTable, StringHashMap>; + using Base = TwoLevelStringHashTable< + StringHashMapSubMaps, + StringHashMap>; using LookupResult = typename Base::LookupResult; using Base::Base; diff --git a/dbms/src/Common/HashTable/TwoLevelStringHashTable.h 
b/dbms/src/Common/HashTable/TwoLevelStringHashTable.h index 5bdb24a3d13..9547f14f380 100644 --- a/dbms/src/Common/HashTable/TwoLevelStringHashTable.h +++ b/dbms/src/Common/HashTable/TwoLevelStringHashTable.h @@ -16,7 +16,7 @@ #include -template , size_t BITS_FOR_BUCKET = 8> +template , size_t BITS_FOR_BUCKET = 8> class TwoLevelStringHashTable : private boost::noncopyable { protected: @@ -26,10 +26,14 @@ class TwoLevelStringHashTable : private boost::noncopyable public: using Key = StringRef; using Impl = ImplTable; + using SubMaps = TSubMaps; static constexpr size_t NUM_BUCKETS = 1ULL << BITS_FOR_BUCKET; static constexpr size_t MAX_BUCKET = NUM_BUCKETS - 1; + static constexpr bool is_string_hash_map = true; + static constexpr bool is_two_level = true; + // TODO: currently hashing contains redundant computations when doing distributed or external aggregations size_t hash(const Key & x) const { @@ -62,35 +66,34 @@ class TwoLevelStringHashTable : private boost::noncopyable TwoLevelStringHashTable() = default; template - explicit TwoLevelStringHashTable(const Source & src) + explicit TwoLevelStringHashTable(Source & src) { if (src.m0.hasZero()) impls[0].m0.setHasZero(*src.m0.zeroValue()); - for (auto & v : src.m1) - { - size_t hash_value = v.getHash(src.m1); - size_t buck = getBucketFromHash(hash_value); - impls[buck].m1.insertUniqueNonZero(&v, hash_value); - } - for (auto & v : src.m2) - { - size_t hash_value = v.getHash(src.m2); - size_t buck = getBucketFromHash(hash_value); - impls[buck].m2.insertUniqueNonZero(&v, hash_value); - } - for (auto & v : src.m3) - { - size_t hash_value = v.getHash(src.m3); - size_t buck = getBucketFromHash(hash_value); - impls[buck].m3.insertUniqueNonZero(&v, hash_value); - } - for (auto & v : src.ms) - { - size_t hash_value = v.getHash(src.ms); - size_t buck = getBucketFromHash(hash_value); - impls[buck].ms.insertUniqueNonZero(&v, hash_value); - } +#define M(SUBMAP) \ + for (auto & v : src.m##SUBMAP) \ + { \ + if constexpr (std::is_same_v) \ + { \ + const size_t hash_value = v.getHash(src.m##SUBMAP); \ + size_t buck = getBucketFromHash(hash_value); \ + impls[buck].m##SUBMAP.insertUniqueNonZero(&v, hash_value); \ + } \ + else \ + { \ + const size_t hash_value = SubMaps::T##SUBMAP::Hash::operator()(v.getKey(v.getValue())); \ + v.setHash(hash_value); \ + size_t buck = getBucketFromHash(hash_value); \ + impls[buck].m##SUBMAP.insertUniqueNonZero(&v, hash_value); \ + } \ + } + + M(1) + M(2) + M(3) + M(s) +#undef M } // This function is mostly the same as StringHashTable::dispatch, but with @@ -104,7 +107,6 @@ class TwoLevelStringHashTable : private boost::noncopyable #endif dispatch(Self & self, KeyHolder && key_holder, Func && func) { - StringHashTableHash hash; const StringRef & x = keyHolderGetKey(key_holder); const size_t sz = x.size; if (sz == 0) @@ -117,7 +119,7 @@ class TwoLevelStringHashTable : private boost::noncopyable { // Strings with trailing zeros are not representable as fixed-size // string keys. Put them to the generic table. 
- auto res = hash(x); + auto res = SubMaps::Ts::Hash::operator()(x); auto buck = getBucketFromHash(res); return func(self.impls[buck].ms, std::forward(key_holder), res); } @@ -154,7 +156,7 @@ class TwoLevelStringHashTable : private boost::noncopyable else n[0] <<= s; } - auto res = hash(k8); + auto res = SubMaps::T1::Hash::operator()(k8); auto buck = getBucketFromHash(res); keyHolderDiscardKey(key_holder); return func(self.impls[buck].m1, k8, res); @@ -168,7 +170,7 @@ class TwoLevelStringHashTable : private boost::noncopyable n[1] >>= s; else n[1] <<= s; - auto res = hash(k16); + auto res = SubMaps::T2::Hash::operator()(k16); auto buck = getBucketFromHash(res); keyHolderDiscardKey(key_holder); return func(self.impls[buck].m2, k16, res); @@ -182,14 +184,14 @@ class TwoLevelStringHashTable : private boost::noncopyable n[2] >>= s; else n[2] <<= s; - auto res = hash(k24); + auto res = SubMaps::T3::Hash::operator()(k24); auto buck = getBucketFromHash(res); keyHolderDiscardKey(key_holder); return func(self.impls[buck].m3, k24, res); } default: { - auto res = hash(x); + auto res = SubMaps::Ts::Hash::operator()(x); auto buck = getBucketFromHash(res); return func(self.impls[buck].ms, std::forward(key_holder), res); } @@ -202,9 +204,9 @@ class TwoLevelStringHashTable : private boost::noncopyable dispatch(*this, key_holder, typename Impl::EmplaceCallable{it, inserted}); } - LookupResult ALWAYS_INLINE find(const Key x) { return dispatch(*this, x, typename Impl::FindCallable{}); } + LookupResult ALWAYS_INLINE find(const Key & x) { return dispatch(*this, x, typename Impl::FindCallable{}); } - ConstLookupResult ALWAYS_INLINE find(const Key x) const + ConstLookupResult ALWAYS_INLINE find(const Key & x) const { return dispatch(*this, x, typename Impl::FindCallable{}); } @@ -259,6 +261,13 @@ class TwoLevelStringHashTable : private boost::noncopyable return true; } + size_t getBufferSizeInCells() const + { + size_t res = 0; + for (const auto & impl : impls) + res += impl.getBufferSizeInCells(); + return res; + } size_t getBufferSizeInBytes() const { size_t res = 0; @@ -268,3 +277,63 @@ class TwoLevelStringHashTable : private boost::noncopyable return res; } }; + +template +struct StringHashTableSubMapSelector<0, true, Data> +{ + using Hash = typename Data::SubMaps::T0::Hash; + + static typename Data::Impl::T0 & getSubMap(size_t hashval, Data & data) + { + const auto bucket = Data::getBucketFromHash(hashval); + return data.impls[bucket].m0; + } +}; + +template +struct StringHashTableSubMapSelector<1, true, Data> +{ + using Hash = typename Data::SubMaps::T1::Hash; + + static typename Data::Impl::T1 & getSubMap(size_t hashval, Data & data) + { + const auto bucket = Data::getBucketFromHash(hashval); + return data.impls[bucket].m1; + } +}; + +template +struct StringHashTableSubMapSelector<2, true, Data> +{ + using Hash = typename Data::SubMaps::T2::Hash; + + static typename Data::Impl::T2 & getSubMap(size_t hashval, Data & data) + { + const auto bucket = Data::getBucketFromHash(hashval); + return data.impls[bucket].m2; + } +}; + +template +struct StringHashTableSubMapSelector<3, true, Data> +{ + using Hash = typename Data::SubMaps::T3::Hash; + + static typename Data::Impl::T3 & getSubMap(size_t hashval, Data & data) + { + const auto bucket = Data::getBucketFromHash(hashval); + return data.impls[bucket].m3; + } +}; + +template +struct StringHashTableSubMapSelector<4, true, Data> +{ + using Hash = typename Data::SubMaps::Ts::Hash; + + static typename Data::Impl::Ts & getSubMap(size_t hashval, Data & data) + { + 
const auto bucket = Data::getBucketFromHash(hashval); + return data.impls[bucket].ms; + } +}; diff --git a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp index 7193f24eddb..3a79025f244 100644 --- a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp +++ b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp @@ -24,6 +24,7 @@ namespace DB namespace FailPoints { extern const char force_agg_on_partial_block[]; +extern const char force_agg_prefetch[]; extern const char force_agg_two_level_hash_table_before_merge[]; } // namespace FailPoints namespace tests @@ -238,16 +239,22 @@ class AggExecutorTestRunner : public ExecutorTest ColumnWithUInt64 col_pr{1, 2, 0, 3290124, 968933, 3125, 31236, 4327, 80000}; }; -#define WRAP_FOR_AGG_PARTIAL_BLOCK_START \ - std::vector partial_blocks{true, false}; \ - for (auto partial_block : partial_blocks) \ - { \ - if (partial_block) \ - FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ - else \ - FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); +#define WRAP_FOR_AGG_FAILPOINTS_START \ + std::vector enables{true, false}; \ + for (auto enable : enables) \ + { \ + if (enable) \ + { \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); \ + } \ + else \ + { \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); \ + } -#define WRAP_FOR_AGG_PARTIAL_BLOCK_END } +#define WRAP_FOR_AGG_FAILPOINTS_END } /// Guarantee the correctness of group by TEST_F(AggExecutorTestRunner, GroupBy) @@ -363,9 +370,9 @@ try FailPointHelper::enableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); else FailPointHelper::disableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -429,9 +436,9 @@ try FailPointHelper::enableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); else FailPointHelper::disableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -464,9 +471,9 @@ try for (size_t i = 0; i < test_num; ++i) { request = buildDAGRequest(std::make_pair(db_name, table_name), agg_funcs[i], group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } /// Min function tests @@ -485,9 +492,9 @@ try for (size_t i = 0; i < test_num; ++i) { request = buildDAGRequest(std::make_pair(db_name, table_name), agg_funcs[i], group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH @@ -545,9 +552,9 @@ try { request = buildDAGRequest(std::make_pair(db_name, table_name), {agg_funcs[i]}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - 
WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH @@ -615,9 +622,9 @@ try {agg_func}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } { @@ -629,9 +636,9 @@ try {agg_func}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } for (auto collation_id : {0, static_cast(TiDB::ITiDBCollator::BINARY)}) @@ -668,9 +675,9 @@ try {agg_func}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -683,9 +690,9 @@ try executeAndAssertColumnsEqual(request, {{toNullableVec({"banana"})}}); request = context.scan("aggnull_test", "t1").aggregation({}, {col("s1")}).build(context); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {{toNullableVec("s1", {{}, "banana"})}}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } CATCH @@ -697,9 +704,9 @@ try = {toNullableVec({3}), toNullableVec({1}), toVec({6})}; auto test_single_function = [&](size_t index) { auto request = context.scan("test_db", "test_table").aggregation({functions[index]}, {}).build(context); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {functions_result[index]}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END }; for (size_t i = 0; i < functions.size(); ++i) test_single_function(i); @@ -720,9 +727,9 @@ try results.push_back(functions_result[k]); auto request = context.scan("test_db", "test_table").aggregation(funcs, {}).build(context); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, results); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END funcs.pop_back(); results.pop_back(); @@ -758,9 +765,9 @@ try context.context->setSetting( "group_by_two_level_threshold", Field(static_cast(two_level_threshold))); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -791,7 +798,7 @@ try "group_by_two_level_threshold", Field(static_cast(two_level_threshold))); context.context->setSetting("max_block_size", Field(static_cast(block_size))); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); size_t actual_row = 0; for (auto & block : blocks) @@ -800,7 +807,7 @@ try actual_row += block.rows(); } ASSERT_EQ(actual_row, expect_rows[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -914,7 +921,7 @@ try "group_by_two_level_threshold", Field(static_cast(two_level_threshold))); context.context->setSetting("max_block_size", Field(static_cast(block_size))); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) { @@ -939,7 +946,7 @@ try vstackBlocks(std::move(blocks)).getColumnsWithTypeAndName(), false)); } - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ 
-967,18 +974,18 @@ try request = context.receive("empty_recv", 5).aggregation({Max(col("s1"))}, {col("s2")}, 5).build(context); { - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } request = context.scan("test_db", "empty_table") .aggregation({Count(lit(Field(static_cast(1))))}, {}) .build(context); { - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {toVec({0})}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH @@ -1035,6 +1042,24 @@ try toVec("col_tinyint", col_data_tinyint), }); + std::vector max_block_sizes{1, 2, DEFAULT_BLOCK_SIZE}; + std::vector two_level_thresholds{0, 1}; + + context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(0))); +#define WRAP_FOR_AGG_STRING_TEST_BEGIN \ + for (const auto & max_block_size : max_block_sizes) \ + { \ + for (const auto & two_level_threshold : two_level_thresholds) \ + { \ + context.context->setSetting( \ + "group_by_two_level_threshold", \ + Field(static_cast(two_level_threshold))); \ + context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); +#define WRAP_FOR_AGG_STRING_TEST_END \ + } \ + } + + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); { // case-1: select count(1), col_tinyint from t group by col_int, col_tinyint // agg method: keys64(AggregationMethodKeysFixed) @@ -1049,7 +1074,9 @@ try toNullableVec("first_row(col_tinyint)", ColumnWithNullableInt8{0, 1, 2, 3}), toVec("col_int", ColumnWithInt32{0, 1, 2, 3}), toVec("col_tinyint", ColumnWithInt8{0, 1, 2, 3})}; + WRAP_FOR_AGG_STRING_TEST_BEGIN executeAndAssertColumnsEqual(request, expected); + WRAP_FOR_AGG_STRING_TEST_END } { @@ -1065,7 +1092,9 @@ try = {toVec("count(1)", ColumnWithUInt64{rows_per_type, rows_per_type, rows_per_type, rows_per_type}), toNullableVec("first_row(col_int)", ColumnWithNullableInt32{0, 1, 2, 3}), toVec("col_int", ColumnWithInt32{0, 1, 2, 3})}; + WRAP_FOR_AGG_STRING_TEST_BEGIN executeAndAssertColumnsEqual(request, expected); + WRAP_FOR_AGG_STRING_TEST_END } { @@ -1099,7 +1128,9 @@ try toNullableVec("first_row(col_string_with_collator)", ColumnWithNullableString{"a", "b", "c", "d"}), toVec("col_string_with_collator", ColumnWithString{"a", "b", "c", "d"}), }; + WRAP_FOR_AGG_STRING_TEST_BEGIN executeAndAssertColumnsEqual(request, expected); + WRAP_FOR_AGG_STRING_TEST_END } { @@ -1116,7 +1147,9 @@ try toVec("count(1)", ColumnWithUInt64{rows_per_type, rows_per_type, rows_per_type, rows_per_type}), toVec("first_row(col_string_with_collator)", ColumnWithString{"a", "b", "c", "d"}), }; + WRAP_FOR_AGG_STRING_TEST_BEGIN executeAndAssertColumnsEqual(request, expected); + WRAP_FOR_AGG_STRING_TEST_END } // case-5: none @@ -1138,7 +1171,9 @@ try toVec("col_int", ColumnWithInt32{0, 1, 2, 3}), toVec("col_string_no_collator", ColumnWithString{"a", "b", "c", "d"}), }; + WRAP_FOR_AGG_STRING_TEST_BEGIN executeAndAssertColumnsEqual(request, expected); + WRAP_FOR_AGG_STRING_TEST_END } { @@ -1155,8 +1190,13 @@ try toNullableVec("first_row(col_string_with_collator)", ColumnWithNullableString{"a", "b", "c", "d"}), toVec("col_string_with_collator", ColumnWithString{"a", "b", "c", "d"}), toVec("col_int", ColumnWithInt32{0, 1, 2, 3})}; + WRAP_FOR_AGG_STRING_TEST_BEGIN executeAndAssertColumnsEqual(request, expected); + WRAP_FOR_AGG_STRING_TEST_END } + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); 
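An aside on the enable/disable pairs used throughout these tests (the `#undef`s for the string-test wrappers follow below): every `enableFailPoint` must be matched by a `disableFailPoint` on all paths, including early test failures. A hypothetical RAII guard, sketched here only for illustration (`FailPointScope` is not part of this patch), would make the pairing automatic:

```cpp
// Hypothetical helper, not part of this patch; it only assumes the
// FailPointHelper::enableFailPoint/disableFailPoint calls used above.
struct FailPointScope
{
    explicit FailPointScope(const char * name_)
        : name(name_)
    {
        FailPointHelper::enableFailPoint(name);
    }
    ~FailPointScope() { FailPointHelper::disableFailPoint(name); }

    const char * name;
};

// Usage sketch: the failpoint stays active exactly for the scope's lifetime,
// even if an assertion fails or an exception unwinds the stack.
// FailPointScope prefetch_scope(FailPoints::force_agg_prefetch);
```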
+#undef WRAP_FOR_AGG_STRING_TEST_BEGIN +#undef WRAP_FOR_AGG_STRING_TEST_END } CATCH @@ -1187,13 +1227,9 @@ try context .addExchangeReceiver("exchange_receiver_1_concurrency", column_infos, column_data, 1, partition_column_infos); - context - .addExchangeReceiver("exchange_receiver_3_concurrency", column_infos, column_data, 3, partition_column_infos); - context - .addExchangeReceiver("exchange_receiver_5_concurrency", column_infos, column_data, 5, partition_column_infos); context .addExchangeReceiver("exchange_receiver_10_concurrency", column_infos, column_data, 10, partition_column_infos); - std::vector exchange_receiver_concurrency = {1, 3, 5, 10}; + std::vector exchange_receiver_concurrency = {1, 10}; auto gen_request = [&](size_t exchange_concurrency) { return context @@ -1205,15 +1241,15 @@ try auto baseline = executeStreams(gen_request(1), 1); for (size_t exchange_concurrency : exchange_receiver_concurrency) { - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(gen_request(exchange_concurrency), baseline); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_START -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_END +#undef WRAP_FOR_AGG_FAILPOINTS_START +#undef WRAP_FOR_AGG_FAILPOINTS_END } // namespace tests } // namespace DB diff --git a/dbms/src/Flash/tests/gtest_compute_server.cpp b/dbms/src/Flash/tests/gtest_compute_server.cpp index 69b2242df3d..3c4020db45e 100644 --- a/dbms/src/Flash/tests/gtest_compute_server.cpp +++ b/dbms/src/Flash/tests/gtest_compute_server.cpp @@ -39,6 +39,7 @@ extern const char exception_before_mpp_root_task_run[]; extern const char exception_during_mpp_non_root_task_run[]; extern const char exception_during_mpp_root_task_run[]; extern const char exception_during_query_run[]; +extern const char force_agg_prefetch[]; } // namespace FailPoints namespace tests @@ -1369,6 +1370,7 @@ try FailPoints::exception_during_mpp_non_root_task_run, FailPoints::exception_during_mpp_root_task_run, FailPoints::exception_during_query_run, + FailPoints::force_agg_prefetch, }; size_t query_index = 0; for (const auto & failpoint : failpoint_names) @@ -1843,6 +1845,7 @@ try auto_pass_through_test_data.nullable_high_ndv_tbl_name, auto_pass_through_test_data.nullable_medium_ndv_tbl_name, }; + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); for (const auto & tbl_name : workloads) { const String db_name = auto_pass_through_test_data.db_name; @@ -1868,6 +1871,7 @@ try res_no_pass_through); WRAP_FOR_SERVER_TEST_END } + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); } CATCH diff --git a/dbms/src/Flash/tests/gtest_spill_aggregation.cpp b/dbms/src/Flash/tests/gtest_spill_aggregation.cpp index b19aaf03c4c..583e6e038fa 100644 --- a/dbms/src/Flash/tests/gtest_spill_aggregation.cpp +++ b/dbms/src/Flash/tests/gtest_spill_aggregation.cpp @@ -23,6 +23,7 @@ namespace FailPoints { extern const char force_agg_on_partial_block[]; extern const char force_thread_0_no_agg_spill[]; +extern const char force_agg_prefetch[]; } // namespace FailPoints namespace tests @@ -37,16 +38,22 @@ class SpillAggregationTestRunner : public DB::tests::ExecutorTest } }; -#define WRAP_FOR_AGG_PARTIAL_BLOCK_START \ - std::vector partial_blocks{true, false}; \ - for (auto partial_block : partial_blocks) \ - { \ - if (partial_block) \ - FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ - else \ - FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); 
+#define WRAP_FOR_AGG_FAILPOINTS_START \ + std::vector enables{true, false}; \ + for (auto enable : enables) \ + { \ + if (enable) \ + { \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); \ + } \ + else \ + { \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); \ + } -#define WRAP_FOR_AGG_PARTIAL_BLOCK_END } +#define WRAP_FOR_AGG_FAILPOINTS_END } #define WRAP_FOR_AGG_THREAD_0_NO_SPILL_START \ for (auto thread_0_no_spill : {true, false}) \ @@ -114,13 +121,13 @@ try context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(1))); /// don't use `executeAndAssertColumnsEqual` since it takes too long to run /// test single thread aggregation - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START WRAP_FOR_AGG_THREAD_0_NO_SPILL_START ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, 1)); /// test parallel aggregation ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); WRAP_FOR_AGG_THREAD_0_NO_SPILL_END - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END /// enable spill and use small max_cached_data_bytes_in_spiller context.context->setSetting("max_cached_data_bytes_in_spiller", Field(static_cast(total_data_size / 200))); /// test single thread aggregation @@ -262,7 +269,7 @@ try Field(static_cast(max_bytes_before_external_agg))); context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); WRAP_FOR_SPILL_TEST_BEGIN - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START WRAP_FOR_AGG_THREAD_0_NO_SPILL_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) @@ -289,7 +296,7 @@ try false)); } WRAP_FOR_AGG_THREAD_0_NO_SPILL_END - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END WRAP_FOR_SPILL_TEST_END } } @@ -369,6 +376,7 @@ try { for (const auto & agg_func : agg_funcs) { + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); context.setCollation(collator_id); const auto * current_collator = TiDB::ITiDBCollator::getCollator(collator_id); ASSERT_TRUE(current_collator != nullptr); @@ -417,7 +425,7 @@ try Field(static_cast(max_bytes_before_external_agg))); context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); WRAP_FOR_SPILL_TEST_BEGIN - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START WRAP_FOR_AGG_THREAD_0_NO_SPILL_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) @@ -444,7 +452,7 @@ try false)); } WRAP_FOR_AGG_THREAD_0_NO_SPILL_END - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END WRAP_FOR_SPILL_TEST_END } } @@ -518,9 +526,9 @@ try /// don't use `executeAndAssertColumnsEqual` since it takes too long to run auto request = gen_request(exchange_concurrency); WRAP_FOR_SPILL_TEST_BEGIN - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START ASSERT_COLUMNS_EQ_UR(baseline, executeStreams(request, exchange_concurrency)); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END WRAP_FOR_SPILL_TEST_END } } @@ -528,8 +536,8 @@ CATCH #undef WRAP_FOR_SPILL_TEST_BEGIN #undef WRAP_FOR_SPILL_TEST_END -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_START -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_END +#undef WRAP_FOR_AGG_FAILPOINTS_START +#undef WRAP_FOR_AGG_FAILPOINTS_END } // namespace tests } // namespace DB diff --git 
a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index f25c22717e8..1190121ddb4 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -43,6 +43,7 @@ extern const char random_aggregate_create_state_failpoint[]; extern const char random_aggregate_merge_failpoint[]; extern const char force_agg_on_partial_block[]; extern const char random_fail_in_resize_callback[]; +extern const char force_agg_prefetch[]; } // namespace FailPoints #define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME @@ -665,7 +666,102 @@ void NO_INLINE Aggregator::executeImpl( { typename Method::State state(agg_process_info.key_columns, key_sizes, collators); - executeImplBatch(method, state, aggregates_pool, agg_process_info); + // start_row != 0 and a non-empty stringHashTableRecoveryInfo cannot both be true at the same time. + RUNTIME_CHECK(!(agg_process_info.start_row != 0 && !agg_process_info.stringHashTableRecoveryInfoEmpty())); + +#ifndef NDEBUG + bool disable_prefetch = (method.data.getBufferSizeInCells() < 8192); + fiu_do_on(FailPoints::force_agg_prefetch, { disable_prefetch = false; }); +#else + const bool disable_prefetch = (method.data.getBufferSizeInCells() < 8192); +#endif + + // key_serialized and key_string(StringHashMap) need column-wise handling for prefetch. + // Because: + // 1. StringHashMap(key_string) is composed of 5 submaps, so prefetch needs to be done for each specific submap. + // 2. getKeyHolder of key_serialized has to copy real data into Arena. + // It means it is better to call getKeyHolder for all Columns once and then use it both for getHash() and emplaceKey(). + // 3. For other group by key(key_int8/16/32/...), it's ok to use row-wise handling even if prefetch is enabled. + // But getHashVals() still needs to be column-wise. + if constexpr (Method::State::is_serialized_key) + { + // TODO: batch serialize method for Columns is still under development. + // if (!disable_prefetch) + // executeImplSerializedKeyByCol(); + // else + // executeImplByRow(method, state, aggregates_pool, agg_process_info); + executeImplByRow(method, state, aggregates_pool, agg_process_info); + } + else if constexpr (Method::Data::is_string_hash_map) + { + // If agg_process_info.start_row != 0, it means the computation process of the current block was interrupted by a resize exception in executeImplByRow. + // For clarity and simplicity of implementation, the processing functions for column-wise and row-wise methods handle the entire block independently. + // A block will not be processed first by the row-wise method and then by the column-wise method, or vice-versa.
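To make the prefetch reasoning above concrete before the dispatch code continues below, here is a minimal, self-contained sketch of the hash-then-prefetch pattern the prefetch-enabled paths rely on. The `HashMapLike` interface (`hash`, `slotAddress`, `emplace`) is a hypothetical stand-in, not TiFlash's API; only the two-pass structure (hash column-wise, then prefetch a fixed distance ahead) mirrors the patch:

```cpp
#include <cstddef>
#include <vector>

// Sketch only: HashMapLike is an assumed interface, not a real TiFlash type.
template <typename HashMapLike, typename Key>
void emplaceBatchWithPrefetch(HashMapLike & map, const std::vector<Key> & keys)
{
    static constexpr size_t prefetch_distance = 16;

    // Pass 1 (column-wise): compute every hash up front, like getHashVals() below.
    std::vector<size_t> hashvals(keys.size());
    for (size_t i = 0; i < keys.size(); ++i)
        hashvals[i] = map.hash(keys[i]);

    // Pass 2: prefetch the bucket a fixed distance ahead, then emplace with the
    // precomputed hash, so the bucket's cache line is likely resident on arrival.
    for (size_t i = 0; i < keys.size(); ++i)
    {
        if (i + prefetch_distance < keys.size())
            __builtin_prefetch(map.slotAddress(hashvals[i + prefetch_distance]));
        map.emplace(keys[i], hashvals[i]);
    }
}
```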
+ if (!disable_prefetch && likely(agg_process_info.start_row == 0)) + executeImplStringHashMapByCol( + method, + state, + aggregates_pool, + agg_process_info); + else + executeImplByRow(method, state, aggregates_pool, agg_process_info); + } + else + { + if (disable_prefetch) + executeImplByRow(method, state, aggregates_pool, agg_process_info); + else + executeImplByRow(method, state, aggregates_pool, agg_process_info); + } +} + +template +void getHashVals( + size_t start_row, + size_t end_row, + const Data & data, + const State & state, + std::vector & sort_key_containers, + Arena * pool, + std::vector & hashvals) +{ + hashvals.resize(state.total_rows); + for (size_t i = start_row; i < end_row; ++i) + { + hashvals[i] = state.getHash(data, i, *pool, sort_key_containers); + } +} + +template +std::optional::ResultType> Aggregator::emplaceOrFindKey( + Method & method, + typename Method::State & state, + size_t index, + Arena & aggregates_pool, + std::vector & sort_key_containers, + const std::vector & hashvals) const +{ + try + { + if constexpr (only_lookup) + return state.template findKey( + method.data, + index, + aggregates_pool, + sort_key_containers, + hashvals); + else + return state.template emplaceKey( + method.data, + index, + aggregates_pool, + sort_key_containers, + hashvals); + } + catch (ResizeException &) + { + return {}; + } } template @@ -689,22 +785,117 @@ std::optional::Res } } -template -ALWAYS_INLINE void Aggregator::executeImplBatch( +// This is only used by executeImplStringHashMapByCol. +// It will choose the specific submap of StringHashMap and then do emplace/find. +// StringKeyType can be StringRef/StringKey8/StringKey16/StringKey24/ArenaKeyHolder. +template < + size_t SubMapIndex, + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool zero_agg_func_size, + typename Data, + typename State, + typename StringKeyType> +size_t Aggregator::emplaceOrFindStringKey( + Data & data, + State & state, + const std::vector & key_infos, + std::vector & key_datas, + Arena & aggregates_pool, + std::vector & places, + AggProcessInfo & agg_process_info) const +{ + static_assert(!(collect_hit_rate && only_lookup)); + assert(key_infos.size() == key_datas.size()); + + using Hash = typename StringHashTableSubMapSelector>::Hash; + std::vector hashvals(key_infos.size(), 0); + for (size_t i = 0; i < key_infos.size(); ++i) + hashvals[i] = Hash::operator()(keyHolderGetKey(key_datas[i])); + + // alloc 0 bytes is useful when agg func size is zero.
+ AggregateDataPtr agg_state = aggregates_pool.alloc(0); + for (size_t i = 0; i < key_infos.size(); ++i) + { + try + { + if constexpr (only_lookup) + { + auto find_result + = state.template findStringKey(data, i, key_datas, hashvals); + if (find_result.isFound()) + { + agg_state = find_result.getMapped(); + } + else + { + agg_process_info.not_found_rows.push_back(key_infos[i]); + } + } + else + { + auto emplace_result + = state.template emplaceStringKey(data, i, key_datas, hashvals); + if (emplace_result.isInserted()) + { + if constexpr (zero_agg_func_size) + { + emplace_result.setMapped(agg_state); + } + else + { + emplace_result.setMapped(nullptr); + + agg_state + = aggregates_pool.alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); + createAggregateStates(agg_state); + + emplace_result.setMapped(agg_state); + } + } + else + { + if constexpr (!zero_agg_func_size) + agg_state = emplace_result.getMapped(); + + if constexpr (collect_hit_rate) + ++agg_process_info.hit_row_cnt; + } + if constexpr (!zero_agg_func_size) + places[i] = agg_state; + } + } + catch (ResizeException &) + { + return i; + } + } + return key_infos.size(); +} + +template +ALWAYS_INLINE void Aggregator::executeImplByRow( Method & method, typename Method::State & state, Arena * aggregates_pool, AggProcessInfo & agg_process_info) const { + LOG_TRACE(log, "executeImplByRow"); // collect_hit_rate and only_lookup cannot be true at the same time. static_assert(!(collect_hit_rate && only_lookup)); + // If agg_process_info.stringHashTableRecoveryInfoEmpty() is false, it means the current block was + // handled by executeImplStringHashMapByCol(column-wise) before, and a resize exception happened. + // This situation is unexpected because, for the sake of clarity, we assume that a block will be **fully** processed + // either column-wise or row-wise and cannot be split for processing. + RUNTIME_CHECK(agg_process_info.stringHashTableRecoveryInfoEmpty()); std::vector sort_key_containers; sort_key_containers.resize(params.keys_size, ""); - size_t agg_size = agg_process_info.end_row - agg_process_info.start_row; + size_t rows = agg_process_info.end_row - agg_process_info.start_row; fiu_do_on(FailPoints::force_agg_on_partial_block, { - if (agg_size > 0 && agg_process_info.start_row == 0) - agg_size = std::max(agg_size / 2, 1); + if (rows > 0 && agg_process_info.start_row == 0) + rows = std::max(rows / 2, 1); }); /// Optimization for special case when there are no aggregate functions. @@ -712,38 +903,66 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( { /// For all rows.
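A note on the fast path the comment above introduces: when there are no aggregate functions, the mapped value is only an occupancy marker and is never dereferenced, so a single shared allocation can back every key; that is what `alloc(0)` provides in the code below. A reduced sketch of the idea with standard containers (the function and types are illustrative stand-ins, not TiFlash's):

```cpp
#include <string>
#include <unordered_map>
#include <vector>

// Reduced sketch: one shared dummy pointer marks every occupied slot.
// The static dummy stands in for aggregates_pool->alloc(0) below.
inline void markKeys(
    std::unordered_map<std::string, const void *> & map,
    const std::vector<std::string> & keys)
{
    static const char dummy = 0;
    for (const auto & key : keys)
        map.try_emplace(key, &dummy); // insert-if-absent; the value is never read
}
```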
AggregateDataPtr place = aggregates_pool->alloc(0); - for (size_t i = 0; i < agg_size; ++i) +#define HANDLE_AGG_EMPLACE_RESULT \ + if likely (emplace_result_hold.has_value()) \ + { \ + if constexpr (collect_hit_rate) \ + { \ + ++agg_process_info.hit_row_cnt; \ + } \ + \ + if constexpr (only_lookup) \ + { \ + if (!emplace_result_hold.value().isFound()) \ + agg_process_info.not_found_rows.push_back(i); \ + } \ + else \ + { \ + emplace_result_hold.value().setMapped(place); \ + } \ + processed_rows = i; \ + } \ + else \ + { \ + LOG_INFO(log, "HashTable resize threw ResizeException since the data is already marked for spill"); \ + break; \ + } + + std::vector hashvals; + std::optional processed_rows; + if constexpr (enable_prefetch) { - auto emplace_result_hold = emplaceOrFindKey( - method, - state, + getHashVals( agg_process_info.start_row, - *aggregates_pool, - sort_key_containers); - if likely (emplace_result_hold.has_value()) + agg_process_info.end_row, + method.data, + state, + sort_key_containers, + aggregates_pool, + hashvals); + } + + for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + rows; ++i) + { + if constexpr (enable_prefetch) { - if constexpr (collect_hit_rate) - { - ++agg_process_info.hit_row_cnt; - } + auto emplace_result_hold + = emplaceOrFindKey(method, state, i, *aggregates_pool, sort_key_containers, hashvals); - if constexpr (only_lookup) - { - if (!emplace_result_hold.value().isFound()) - agg_process_info.not_found_rows.push_back(i); - } - else - { - emplace_result_hold.value().setMapped(place); - } - ++agg_process_info.start_row; + HANDLE_AGG_EMPLACE_RESULT } else { - LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill"); - break; + auto emplace_result_hold + = emplaceOrFindKey(method, state, i, *aggregates_pool, sort_key_containers); + + HANDLE_AGG_EMPLACE_RESULT } } + + if likely (processed_rows) + agg_process_info.start_row = *processed_rows + 1; +#undef HANDLE_AGG_EMPLACE_RESULT return; } @@ -755,7 +974,7 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( { inst->batch_that->addBatchLookupTable8( agg_process_info.start_row, - agg_size, + rows, reinterpret_cast(method.data.data()), inst->state_offset, [&](AggregateDataPtr & aggregate_data) { @@ -767,12 +986,12 @@ inst->batch_arguments, aggregates_pool); } - agg_process_info.start_row += agg_size; + agg_process_info.start_row += rows; // For key8, assume all rows are hit. No need to do state switch for auto pass through hashagg. // Because HashMap of key8 is basically a vector of size 256. if constexpr (collect_hit_rate) - agg_process_info.hit_row_cnt = agg_size; + agg_process_info.hit_row_cnt = rows; // Because all rows are hit, so state will not switch to Selective. if constexpr (only_lookup) @@ -781,60 +1000,84 @@ } /// Generic case.
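One structural note before the generic case below: `processed_rows` records the last row that actually made it into the hash table, so when a ResizeException stops the loop, `start_row` advances only past completed rows and the same block can be resumed later. A schematic of that resumption contract (`insertOne()` is a hypothetical stand-in for `emplaceOrFindKey()`; returning false models the ResizeException):

```cpp
#include <cstddef>
#include <optional>

// Hypothetical stand-in: returns false when the table refuses to grow,
// modeling the ResizeException caught by emplaceOrFindKey() above.
bool insertOne(size_t row);

// Returns true once the whole range has been handled; otherwise start_row
// is left at the first unprocessed row so the caller can spill and retry.
bool processRange(size_t & start_row, size_t end_row)
{
    std::optional<size_t> processed_rows;
    for (size_t i = start_row; i < end_row; ++i)
    {
        if (!insertOne(i))
            break;          // stop without losing completed work
        processed_rows = i; // last successfully handled row
    }
    if (processed_rows)
        start_row = *processed_rows + 1; // resume point for the next attempt
    return start_row == end_row;
}
```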
- - std::unique_ptr places(new AggregateDataPtr[agg_size]); + std::unique_ptr places(new AggregateDataPtr[rows]); std::optional processed_rows; - for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i) +#define HANDLE_AGG_EMPLACE_RESULT \ + if unlikely (!emplace_result_holder.has_value()) \ + { \ + LOG_INFO(log, "HashTable resize threw ResizeException since the data is already marked for spill"); \ + break; \ + } \ + \ + auto & emplace_result = emplace_result_holder.value(); \ + \ + if constexpr (only_lookup) \ + { \ + if (emplace_result.isFound()) \ + { \ + aggregate_data = emplace_result.getMapped(); \ + } \ + else \ + { \ + agg_process_info.not_found_rows.push_back(i); \ + } \ + } \ + else \ + { \ + if (emplace_result.isInserted()) \ + { \ + emplace_result.setMapped(nullptr); \ + \ + aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); \ + createAggregateStates(aggregate_data); \ + \ + emplace_result.setMapped(aggregate_data); \ + } \ + else \ + { \ + aggregate_data = emplace_result.getMapped(); \ + \ + if constexpr (collect_hit_rate) \ + ++agg_process_info.hit_row_cnt; \ + } \ + } \ + \ + places[i - agg_process_info.start_row] = aggregate_data; \ + processed_rows = i; + + std::vector hashvals; + if constexpr (enable_prefetch) + { + getHashVals( + agg_process_info.start_row, + agg_process_info.end_row, + method.data, + state, + sort_key_containers, + aggregates_pool, + hashvals); + } + + for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + rows; ++i) { AggregateDataPtr aggregate_data = nullptr; - - auto emplace_result_holder - = emplaceOrFindKey(method, state, i, *aggregates_pool, sort_key_containers); - if unlikely (!emplace_result_holder.has_value()) + if constexpr (enable_prefetch) { - LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill"); - break; - } + auto emplace_result_holder + = emplaceOrFindKey(method, state, i, *aggregates_pool, sort_key_containers, hashvals); - auto & emplace_result = emplace_result_holder.value(); - - if constexpr (only_lookup) - { - if (emplace_result.isFound()) - { - aggregate_data = emplace_result.getMapped(); - } - else - { - agg_process_info.not_found_rows.push_back(i); - } + HANDLE_AGG_EMPLACE_RESULT } else { - /// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
- emplace_result.setMapped(nullptr); - - aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); - createAggregateStates(aggregate_data); + auto emplace_result_holder + = emplaceOrFindKey(method, state, i, *aggregates_pool, sort_key_containers); - emplace_result.setMapped(aggregate_data); - } - else - { - aggregate_data = emplace_result.getMapped(); - - if constexpr (collect_hit_rate) - ++agg_process_info.hit_row_cnt; - } + HANDLE_AGG_EMPLACE_RESULT } - - places[i - agg_process_info.start_row] = aggregate_data; - processed_rows = i; } +#undef HANDLE_AGG_EMPLACE_RESULT if (processed_rows) { @@ -854,6 +1097,230 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( } } +#define M(SUBMAPINDEX) \ + template \ + ALWAYS_INLINE inline void setupExceptionRecoveryInfoForStringHashTable( \ + Aggregator::AggProcessInfo & agg_process_info, \ + size_t row, \ + const std::vector & key_infos, \ + const std::vector & key_datas, \ + std::integral_constant) \ + { \ + agg_process_info.submap_m##SUBMAPINDEX##_infos \ + = std::vector(key_infos.begin() + row, key_infos.end()); \ + agg_process_info.submap_m##SUBMAPINDEX##_datas \ + = std::vector(key_datas.begin() + row, key_datas.end()); \ + } + +M(0) +M(1) +M(2) +M(3) +M(4) + +#undef M + +// Prefetch/emplace each specific submap directly instead of accessing the StringHashMap interface, +// which is better for performance. +// NOTE: this function is column-wise, which means the sort key buffer cannot be reused. +// This buffer will not be released until this block is fully processed. +template +ALWAYS_INLINE void Aggregator::executeImplStringHashMapByCol( + Method & method, + typename Method::State & state, + Arena * aggregates_pool, + AggProcessInfo & agg_process_info) const +{ + LOG_TRACE(log, "executeImplStringHashMapByCol"); + // collect_hit_rate and only_lookup cannot be true at the same time. + static_assert(!(collect_hit_rate && only_lookup)); + static_assert(Method::Data::is_string_hash_map); + +#define M(SUBMAPINDEX) \ + RUNTIME_CHECK( \ + agg_process_info.submap_m##SUBMAPINDEX##_infos.size() \ + == agg_process_info.submap_m##SUBMAPINDEX##_datas.size()); + + M(0) + M(1) + M(2) + M(3) + M(4) +#undef M + + const size_t rows = agg_process_info.end_row - agg_process_info.start_row; + auto sort_key_pool = std::make_unique(); + std::vector sort_key_containers; + +#define M(INFO, DATA, KEYTYPE) \ + std::vector(INFO); \ + std::vector(DATA); + + M(key0_infos, key0_datas, StringKey0) + M(key8_infos, key8_datas, StringKey8) + M(key16_infos, key16_datas, StringKey16) + M(key24_infos, key24_datas, StringKey24) + M(key_str_infos, key_str_datas, ArenaKeyHolder) +#undef M + + // If no resize exception happened, this is a new Block. + // If a resize exception happened, start_row was already set to zero at the end of this function.
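As background for the bookkeeping above and the recovery logic below: StringHashTable routes each key by length into one of five submaps, which is why there are five parallel info/data vectors. A simplified stand-in for the routing rule (the patch itself goes through `dispatchStringHashTable()` with one callback per submap):

```cpp
#include <cstddef>

// Simplified stand-in for StringHashTable's length dispatch; submap indices
// follow the m0/m1/m2/m3/ms layout referenced throughout this hunk.
inline size_t submapIndexFor(size_t key_size)
{
    if (key_size == 0)
        return 0; // m0: the empty key
    if (key_size <= 8)
        return 1; // m1: StringKey8, one 64-bit word
    if (key_size <= 16)
        return 2; // m2: StringKey16, two words
    if (key_size <= 24)
        return 3; // m3: StringKey24, three words
    return 4;     // ms: arbitrary-length fallback (ArenaKeyHolder here)
}
```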
+ RUNTIME_CHECK_MSG( + agg_process_info.start_row == 0, + "unexpected agg_process_info.start_row: {}, end_row: {}", + agg_process_info.start_row, + agg_process_info.end_row); + + if likely (agg_process_info.stringHashTableRecoveryInfoEmpty()) + { + // sort_key_pool should already have been reset by AggProcessInfo::resetBlock() + RUNTIME_CHECK(!agg_process_info.sort_key_pool); + + const size_t reserve_size = rows / 4; + +#define M(INFO, DATA, SUBMAPINDEX, KEYTYPE) \ + (INFO).reserve(reserve_size); \ + (DATA).reserve(reserve_size); \ + auto dispatch_callback_key##SUBMAPINDEX \ + = [&INFO, &DATA](const KEYTYPE & key, size_t row) { /* NOLINT(bugprone-macro-parentheses) */ \ + (INFO).push_back(row); \ + (DATA).push_back(key); \ + }; + + M(key0_infos, key0_datas, 0, StringKey0) + M(key8_infos, key8_datas, 8, StringKey8) + M(key16_infos, key16_datas, 16, StringKey16) + M(key24_infos, key24_datas, 24, StringKey24) + // Argument type is ArenaKeyHolder instead of StringRef, + // because it will only be persisted when inserted into the HashTable. + M(key_str_infos, key_str_datas, str, ArenaKeyHolder) +#undef M + + for (size_t i = 0; i < rows; ++i) + { + // Use Arena for the collation sort key, because we are doing agg in a column-wise way. + // So a big arena is needed to store decoded keys, and we can avoid resizing std::string by using Arena. + auto key_holder = state.getKeyHolder(i, aggregates_pool, sort_key_pool.get()); + dispatchStringHashTable( + i, + key_holder, + dispatch_callback_key0, + dispatch_callback_key8, + dispatch_callback_key16, + dispatch_callback_key24, + dispatch_callback_keystr); + } + } + else + { +#define M(INFO, DATA, SUBMAPINDEX) \ + (INFO) = agg_process_info.submap_m##SUBMAPINDEX##_infos; \ + (DATA) = agg_process_info.submap_m##SUBMAPINDEX##_datas; + + M(key0_infos, key0_datas, 0) + M(key8_infos, key8_datas, 1) + M(key16_infos, key16_datas, 2) + M(key24_infos, key24_datas, 3) + M(key_str_infos, key_str_datas, 4) +#undef M + } + + std::vector key0_places(key0_infos.size(), nullptr); + std::vector key8_places(key8_infos.size(), nullptr); + std::vector key16_places(key16_infos.size(), nullptr); + std::vector key24_places(key24_infos.size(), nullptr); + std::vector key_str_places(key_str_infos.size(), nullptr); + + bool got_resize_exception = false; + size_t emplaced_index = 0; + bool zero_agg_func_size = (params.aggregates_size == 0); + +#define M(INDEX, INFO, DATA, PLACES) \ + if (!got_resize_exception && !(INFO).empty()) \ + { \ + if (zero_agg_func_size) \ + emplaced_index = emplaceOrFindStringKey( \ + method.data, \ + state, \ + (INFO), \ + (DATA), \ + *aggregates_pool, \ + (PLACES), \ + agg_process_info); \ + else \ + emplaced_index = emplaceOrFindStringKey( \ + method.data, \ + state, \ + (INFO), \ + (DATA), \ + *aggregates_pool, \ + (PLACES), \ + agg_process_info); \ + if unlikely (emplaced_index != (INFO).size()) \ + got_resize_exception = true; \ + } \ + else \ + { \ + emplaced_index = 0; \ + } \ + setupExceptionRecoveryInfoForStringHashTable( \ + agg_process_info, \ + emplaced_index, \ + (INFO), \ + (DATA), \ + std::integral_constant{}); + + M(0, key0_infos, key0_datas, key0_places) + M(1, key8_infos, key8_datas, key8_places) + M(2, key16_infos, key16_datas, key16_places) + M(3, key24_infos, key24_datas, key24_places) + M(4, key_str_infos, key_str_datas, key_str_places) +#undef M + + if (!zero_agg_func_size) + { + std::vector places(rows, nullptr); +#define M(INFO, PLACES) \ + for (size_t i = 0; i < (INFO).size(); ++i) \ + { \ + const auto row = (INFO)[i]; \ + places[row] = (PLACES)[i]; \
+ } + + M(key0_infos, key0_places) + M(key8_infos, key8_places) + M(key16_infos, key16_places) + M(key24_infos, key24_places) + M(key_str_infos, key_str_places) +#undef M + + for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that; + ++inst) + { + inst->batch_that->addBatch( + agg_process_info.start_row, + rows, + &places[0], + inst->state_offset, + inst->batch_arguments, + aggregates_pool); + } + } + + if unlikely (got_resize_exception) + { + RUNTIME_CHECK(!agg_process_info.stringHashTableRecoveryInfoEmpty()); + agg_process_info.sort_key_pool = std::move(sort_key_pool); + // For StringHashTable, start_row is meaningless; submap_mx_infos/submap_mx_datas are used instead. + // So set it to zero when got_resize_exception. + agg_process_info.start_row = 0; + } + else + { + agg_process_info.start_row = agg_process_info.end_row; + } +} + void NO_INLINE Aggregator::executeWithoutKeyImpl(AggregatedDataWithoutKey & res, AggProcessInfo & agg_process_info, Arena * arena) { @@ -876,7 +1343,6 @@ Aggregator::executeWithoutKeyImpl(AggregatedDataWithoutKey & res, AggProcessInfo agg_process_info.start_row += agg_size; } - void Aggregator::prepareAggregateInstructions( Columns columns, AggregateColumns & aggregate_columns, diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 381bfba8462..6099d6b1655 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -80,7 +80,7 @@ using AggregatedDataWithUInt16Key = FixedImplicitZeroHashMap>; using AggregatedDataWithUInt64Key = HashMap>; -using AggregatedDataWithShortStringKey = StringHashMap; +using AggregatedDataWithShortStringKey = StringHashMap>; using AggregatedDataWithStringKey = HashMapWithSavedHash; using AggregatedDataWithInt256Key = HashMap>; @@ -88,16 +88,17 @@ using AggregatedDataWithInt256Key = HashMap>; using AggregatedDataWithKeys256 = HashMap>; -using AggregatedDataWithUInt32KeyTwoLevel = TwoLevelHashMap>; -using AggregatedDataWithUInt64KeyTwoLevel = TwoLevelHashMap>; +using AggregatedDataWithUInt32KeyTwoLevel = TwoLevelHashMap>; +using AggregatedDataWithUInt64KeyTwoLevel = TwoLevelHashMap>; -using AggregatedDataWithInt256KeyTwoLevel = TwoLevelHashMap>; +using AggregatedDataWithInt256KeyTwoLevel = TwoLevelHashMap>; -using AggregatedDataWithShortStringKeyTwoLevel = TwoLevelStringHashMap; +using AggregatedDataWithShortStringKeyTwoLevel + = TwoLevelStringHashMap>; using AggregatedDataWithStringKeyTwoLevel = TwoLevelHashMapWithSavedHash; -using AggregatedDataWithKeys128TwoLevel = TwoLevelHashMap>; -using AggregatedDataWithKeys256TwoLevel = TwoLevelHashMap>; +using AggregatedDataWithKeys128TwoLevel = TwoLevelHashMap>; +using AggregatedDataWithKeys256TwoLevel = TwoLevelHashMap>; /** Variants with better hash function, using more than 32 bits for hash.
* Using for merging phase of external aggregation, where number of keys may be far greater than 4 billion, @@ -125,7 +126,7 @@ struct AggregationMethodOneNumber AggregationMethodOneNumber() = default; template - explicit AggregationMethodOneNumber(const Other & other) + explicit AggregationMethodOneNumber(Other & other) : data(other.data) {} @@ -179,7 +180,7 @@ struct AggregationMethodString AggregationMethodString() = default; template - explicit AggregationMethodString(const Other & other) + explicit AggregationMethodString(Other & other) : data(other.data) {} @@ -227,12 +228,11 @@ struct AggregationMethodStringNoCache AggregationMethodStringNoCache() = default; template - explicit AggregationMethodStringNoCache(const Other & other) + explicit AggregationMethodStringNoCache(Other & other) : data(other.data) {} - using State = ColumnsHashing:: - HashMethodString; + using State = ColumnsHashing::HashMethodString; template struct EmplaceOrFindKeyResult { @@ -276,7 +276,7 @@ struct AggregationMethodOneKeyStringNoCache AggregationMethodOneKeyStringNoCache() = default; template - explicit AggregationMethodOneKeyStringNoCache(const Other & other) + explicit AggregationMethodOneKeyStringNoCache(Other & other) : data(other.data) {} @@ -326,7 +326,7 @@ struct AggregationMethodMultiStringNoCache AggregationMethodMultiStringNoCache() = default; template - explicit AggregationMethodMultiStringNoCache(const Other & other) + explicit AggregationMethodMultiStringNoCache(Other & other) : data(other.data) {} @@ -356,7 +356,7 @@ struct AggregationMethodFastPathTwoKeysNoCache AggregationMethodFastPathTwoKeysNoCache() = default; template - explicit AggregationMethodFastPathTwoKeysNoCache(const Other & other) + explicit AggregationMethodFastPathTwoKeysNoCache(Other & other) : data(other.data) {} @@ -476,7 +476,7 @@ struct AggregationMethodFixedString AggregationMethodFixedString() = default; template - explicit AggregationMethodFixedString(const Other & other) + explicit AggregationMethodFixedString(Other & other) : data(other.data) {} @@ -524,11 +524,11 @@ struct AggregationMethodFixedStringNoCache AggregationMethodFixedStringNoCache() = default; template - explicit AggregationMethodFixedStringNoCache(const Other & other) + explicit AggregationMethodFixedStringNoCache(Other & other) : data(other.data) {} - using State = ColumnsHashing::HashMethodFixedString; + using State = ColumnsHashing::HashMethodFixedString; template struct EmplaceOrFindKeyResult { @@ -573,7 +573,7 @@ struct AggregationMethodKeysFixed AggregationMethodKeysFixed() = default; template - explicit AggregationMethodKeysFixed(const Other & other) + explicit AggregationMethodKeysFixed(Other & other) : data(other.data) {} @@ -680,7 +680,7 @@ struct AggregationMethodSerialized AggregationMethodSerialized() = default; template - explicit AggregationMethodSerialized(const Other & other) + explicit AggregationMethodSerialized(Other & other) : data(other.data) {} @@ -1319,11 +1319,32 @@ class Aggregator size_t hit_row_cnt = 0; std::vector not_found_rows; + // For StringHashTable, when a resize exception happens, the process will be interrupted. + // So we need these infos to continue.
+ std::vector submap_m0_infos{}; + std::vector submap_m1_infos{}; + std::vector submap_m2_infos{}; + std::vector submap_m3_infos{}; + std::vector submap_m4_infos{}; + std::vector submap_m0_datas{}; + std::vector submap_m1_datas{}; + std::vector submap_m2_datas{}; + std::vector submap_m3_datas{}; + std::vector submap_m4_datas{}; + std::unique_ptr sort_key_pool; + void prepareForAgg(); bool allBlockDataHandled() const { assert(start_row <= end_row); - return start_row == end_row || aggregator->isCancelled(); + // submap_mx_infos.size() and submap_mx_datas.size() are always equal. + // So checking submap_mx_infos is enough. + return (start_row == end_row && stringHashTableRecoveryInfoEmpty()) || aggregator->isCancelled(); + } + bool stringHashTableRecoveryInfoEmpty() const + { + return submap_m0_infos.empty() && submap_m1_infos.empty() && submap_m2_infos.empty() + && submap_m3_infos.empty() && submap_m4_infos.empty(); } void resetBlock(const Block & block_) { @@ -1337,6 +1358,8 @@ class Aggregator hit_row_cnt = 0; not_found_rows.clear(); not_found_rows.reserve(block_.rows() / 2); + + sort_key_pool.reset(); } }; @@ -1454,13 +1477,29 @@ class Aggregator AggProcessInfo & agg_process_info, TiDB::TiDBCollators & collators) const; - template - void executeImplBatch( + template + void executeImplByRow( + Method & method, + typename Method::State & state, + Arena * aggregates_pool, + AggProcessInfo & agg_process_info) const; + + template + void executeImplStringHashMapByCol( Method & method, typename Method::State & state, Arena * aggregates_pool, AggProcessInfo & agg_process_info) const; + template + std::optional::ResultType> emplaceOrFindKey( + Method & method, + typename Method::State & state, + size_t index, + Arena & aggregates_pool, + std::vector & sort_key_containers, + const std::vector & hashvals) const; + template std::optional::ResultType> emplaceOrFindKey( Method & method, typename Method::State & state, size_t index, Arena & aggregates_pool, std::vector & sort_key_containers) const; + template < + size_t SubMapIndex, + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool zero_agg_func_size, + typename Data, + typename State, + typename StringKeyType> + size_t emplaceOrFindStringKey( + Data & data, + State & state, + const std::vector & key_infos, + std::vector & key_datas, + Arena & aggregates_pool, + std::vector & places, + AggProcessInfo & agg_process_info) const; + /// For case when there are no keys (all aggregate into one row).
static void executeWithoutKeyImpl(AggregatedDataWithoutKey & res, AggProcessInfo & agg_process_info, Arena * arena); diff --git a/dbms/src/Interpreters/JoinPartition.cpp b/dbms/src/Interpreters/JoinPartition.cpp index a060878c4f7..294c72c19a3 100644 --- a/dbms/src/Interpreters/JoinPartition.cpp +++ b/dbms/src/Interpreters/JoinPartition.cpp @@ -412,7 +412,7 @@ struct KeyGetterForTypeImpl template struct KeyGetterForTypeImpl { - using Type = ColumnsHashing::HashMethodString; + using Type = ColumnsHashing::HashMethodString; }; template struct KeyGetterForTypeImpl @@ -427,7 +427,7 @@ struct KeyGetterForTypeImpl template struct KeyGetterForTypeImpl { - using Type = ColumnsHashing::HashMethodFixedString; + using Type = ColumnsHashing::HashMethodFixedString; }; template struct KeyGetterForTypeImpl @@ -652,18 +652,18 @@ void NO_INLINE insertBlockIntoMapsTypeCase( insert_indexes.emplace_back(insert_index); } -#define INSERT_TO_MAP(join_partition, segment_index) \ - auto & current_map = (join_partition)->getHashMap(); \ - for (auto & s_i : (segment_index)) \ - { \ - Inserter::insert( \ - current_map, \ - key_getter, \ - stored_block, \ - s_i, \ - pool, \ - sort_key_containers, \ - probe_cache_column_threshold); \ +#define INSERT_TO_MAP(join_partition, segment_index) \ + auto & current_map = (join_partition) -> getHashMap(); \ + for (auto & s_i : (segment_index)) \ + { \ + Inserter::insert( \ + current_map, \ + key_getter, \ + stored_block, \ + s_i, \ + pool, \ + sort_key_containers, \ + probe_cache_column_threshold); \ } #define INSERT_TO_NOT_INSERTED_MAP \ diff --git a/dbms/src/Interpreters/SetVariants.h b/dbms/src/Interpreters/SetVariants.h index a1591f8c13a..5c503240b7b 100644 --- a/dbms/src/Interpreters/SetVariants.h +++ b/dbms/src/Interpreters/SetVariants.h @@ -54,7 +54,7 @@ struct SetMethodString Data data; - using State = ColumnsHashing::HashMethodString; + using State = ColumnsHashing::HashMethodString; }; template @@ -77,7 +77,7 @@ struct SetMethodFixedString Data data; - using State = ColumnsHashing::HashMethodFixedString; + using State = ColumnsHashing::HashMethodFixedString; }; namespace set_impl diff --git a/dbms/src/TiDB/Collation/Collator.cpp b/dbms/src/TiDB/Collation/Collator.cpp index bf27400f8c4..4365f1f0988 100644 --- a/dbms/src/TiDB/Collation/Collator.cpp +++ b/dbms/src/TiDB/Collation/Collator.cpp @@ -192,6 +192,11 @@ class BinCollator final : public ITiDBCollator return DB::BinCollatorSortKey(s, length); } + StringRef sortKey(const char * s, size_t length, DB::Arena &) const override + { + return DB::BinCollatorSortKey(s, length); + } + StringRef sortKeyNoTrim(const char * s, size_t length, std::string &) const override { return convertForBinCollator(s, length, nullptr); @@ -273,11 +278,54 @@ class GeneralCICollator final : public ITiDBCollator return convertImpl(s, length, container, nullptr); } + StringRef sortKey(const char * s, size_t length, DB::Arena & pool) const override + { + return convertImpl(s, length, pool, nullptr); + } + StringRef sortKeyNoTrim(const char * s, size_t length, std::string & container) const override { return convertImpl(s, length, container, nullptr); } + template + StringRef convertImpl(const char * s, size_t length, DB::Arena & pool, std::vector * lens) const + { + std::string_view v; + + if constexpr (need_trim) + v = rtrim(s, length); + else + v = std::string_view(s, length); + + const size_t size = length * sizeof(WeightType); + auto * buffer = pool.alignedAlloc(size, 16); + + size_t offset = 0; + size_t total_size = 0; + size_t v_length = 
v.length(); + + if constexpr (need_len) + { + if (lens->capacity() < v_length) + lens->reserve(v_length); + lens->resize(0); + } + + while (offset < v_length) + { + auto c = decodeChar(s, offset); + auto sk = weight(c); + buffer[total_size++] = static_cast(sk >> 8); + buffer[total_size++] = static_cast(sk); + + if constexpr (need_len) + lens->push_back(2); + } + + return StringRef(buffer, total_size); + } + template StringRef convertImpl(const char * s, size_t length, std::string & container, std::vector * lens) const { @@ -479,11 +527,65 @@ class UCACICollator final : public ITiDBCollator return convertImpl(s, length, container, nullptr); } + StringRef sortKey(const char * s, size_t length, DB::Arena & pool) const override + { + return convertImpl(s, length, pool, nullptr); + } + StringRef sortKeyNoTrim(const char * s, size_t length, std::string & container) const override { return convertImpl(s, length, container, nullptr); } + // Use Arena to store the decoded string. Normally it's used by column-wise Agg/Join, + // because the column-wise process cannot reuse a string container. + template + StringRef convertImpl(const char * s, size_t length, DB::Arena & pool, std::vector * lens) const + { + std::string_view v; + + if constexpr (need_trim) + v = preprocess(s, length); + else + v = std::string_view(s, length); + + // every char has at most 8 uint16 weights. + const auto size = 8 * length * sizeof(uint16_t); + auto * buffer = pool.alignedAlloc(size, 16); + + size_t offset = 0; + size_t total_size = 0; + size_t v_length = v.length(); + + uint64_t first = 0, second = 0; + + if constexpr (need_len) + { + if (lens->capacity() < v_length) + lens->reserve(v_length); + lens->resize(0); + } + + while (offset < v_length) + { + weight(first, second, offset, v_length, s); + + if constexpr (need_len) + lens->push_back(total_size); + + writeResult(first, buffer, total_size); + writeResult(second, buffer, total_size); + + if constexpr (need_len) + { + size_t end_idx = lens->size() - 1; + (*lens)[end_idx] = total_size - (*lens)[end_idx]; + } + } + + return StringRef(buffer, total_size); + } + template StringRef convertImpl(const char * s, size_t length, std::string & container, std::vector * lens) const { @@ -550,6 +652,16 @@ class UCACICollator final : public ITiDBCollator } } + static inline void writeResult(uint64_t & w, char * buffer, size_t & total_size) + { + while (w != 0) + { + buffer[total_size++] = static_cast(w >> 8); + buffer[total_size++] = static_cast(w); + w >>= 16; + } + } + static inline bool regexEq(CharType a, CharType b) { return T::regexEq(a, b); } static inline void weight(uint64_t & first, uint64_t & second, size_t & offset, size_t length, const char * s) diff --git a/dbms/src/TiDB/Collation/Collator.h b/dbms/src/TiDB/Collation/Collator.h index 6bb87883ef1..08c017ba57d 100644 --- a/dbms/src/TiDB/Collation/Collator.h +++ b/dbms/src/TiDB/Collation/Collator.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -101,6 +102,7 @@ class ITiDBCollator = 0; virtual StringRef sortKeyNoTrim(const char * s, size_t length, std::string & container) const = 0; virtual StringRef sortKey(const char * s, size_t length, std::string & container) const = 0; + virtual StringRef sortKey(const char * s, size_t length, DB::Arena &) const = 0; virtual std::unique_ptr pattern() const = 0; int32_t getCollatorId() const { return collator_id; } CollatorType getCollatorType() const { return collator_type; } @@ -135,6 +137,14 @@ class ITiDBCollator } return sortKey(s, length, container); } + ALWAYS_INLINE
inline StringRef sortKeyFastPath(const char * s, size_t length, DB::Arena & pool) const + { + if (likely(isPaddingBinary())) + { + return DB::BinCollatorSortKey(s, length); + } + return sortKey(s, length, pool); + } protected: explicit ITiDBCollator(int32_t collator_id_); diff --git a/libs/libcommon/include/common/StringRef.h b/libs/libcommon/include/common/StringRef.h index a87b54a7670..bf1ab026a49 100644 --- a/libs/libcommon/include/common/StringRef.h +++ b/libs/libcommon/include/common/StringRef.h @@ -180,7 +180,7 @@ inline size_t hashLessThan8(const char * data, size_t size) struct CRC32Hash { - size_t operator()(StringRef x) const + static size_t operator()(const StringRef & x) { const char * pos = x.data; size_t size = x.size;
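A closing note on this last hunk: CRC32Hash consumes the string eight bytes at a time with the SSE4.2 CRC32 instruction, and making `operator()` static is what lets the submap selectors earlier in this patch call it as `SubMaps::Ts::Hash::operator()(x)` without an instance. A reduced sketch of the hashing scheme (x86-only; inputs shorter than 8 bytes go through `hashLessThan8` in the real code, which this sketch omits):

```cpp
#include <nmmintrin.h> // _mm_crc32_u64 (SSE4.2)

#include <cstdint>
#include <cstring>

// Sketch only: assumes size >= 8; the final window may overlap the previous
// word, which is harmless for hashing purposes.
inline size_t crc32HashAtLeast8(const char * data, size_t size)
{
    size_t res = static_cast<size_t>(-1);
    const char * end = data + size;
    while (data + 8 < end)
    {
        uint64_t word;
        std::memcpy(&word, data, 8); // unaligned-safe load
        res = _mm_crc32_u64(res, word);
        data += 8;
    }
    uint64_t tail;
    std::memcpy(&tail, end - 8, 8); // last (possibly overlapping) 8 bytes
    res = _mm_crc32_u64(res, tail);
    return res;
}
```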