diff --git a/include/knowhere/comp/brute_force.h b/include/knowhere/comp/brute_force.h index d3a11fe49..7b089bd05 100644 --- a/include/knowhere/comp/brute_force.h +++ b/include/knowhere/comp/brute_force.h @@ -51,7 +51,7 @@ class BruteForce { template static expected> AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset); + const BitsetView& bitset, bool use_knowhere_search_pool = true); }; } // namespace knowhere diff --git a/include/knowhere/expected.h b/include/knowhere/expected.h index 0c46df4cd..f21f57896 100644 --- a/include/knowhere/expected.h +++ b/include/knowhere/expected.h @@ -49,6 +49,7 @@ enum class Status { internal_error = 27, invalid_serialized_index_type = 28, sparse_inner_error = 29, + brute_force_inner_error = 30, }; inline std::string @@ -104,6 +105,8 @@ Status2String(knowhere::Status status) { return "the serialized index type is not recognized"; case knowhere::Status::sparse_inner_error: return "sparse index inner error"; + case knowhere::Status::brute_force_inner_error: + return "brute_force index inner error"; default: return "unexpected status"; } diff --git a/include/knowhere/index/index.h b/include/knowhere/index/index.h index 4f3f21604..7bbb22ec4 100644 --- a/include/knowhere/index/index.h +++ b/include/knowhere/index/index.h @@ -161,7 +161,8 @@ class Index { Search(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const; expected> - AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const; + AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset, + bool use_knowhere_search_pool = true) const; expected RangeSearch(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const; diff --git a/include/knowhere/index/index_node.h b/include/knowhere/index/index_node.h index 3200c624d..5b24d525f 100644 --- a/include/knowhere/index/index_node.h +++ b/include/knowhere/index/index_node.h @@ -146,7 +146,8 @@ class IndexNode : public Object { using IteratorPtr = std::shared_ptr; virtual expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const { + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool = true) const { return expected>>::Err( Status::not_implemented, "annIterator not supported for current index type"); } @@ -202,7 +203,10 @@ class IndexNode : public Object { return GenResultDataSet(nq, std::move(range_search_result)); } - auto its_or = AnnIterator(dataset, std::move(cfg), bitset); + // The range_search function has utilized the search_pool to concurrently handle various queries. + // To prevent potential deadlocks, the iterator for a single query no longer requires additional thread + // control over the next() call. + auto its_or = AnnIterator(dataset, std::move(cfg), bitset, false); if (!its_or.has_value()) { return expected::Err(its_or.error(), "RangeSearch failed due to AnnIterator failure: " + its_or.what()); @@ -290,13 +294,17 @@ class IndexNode : public Object { futs.reserve(nq); if (retain_iterator_order) { for (size_t i = 0; i < nq; i++) { - futs.emplace_back( - ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { task_with_ordered_iterator(idx); })); + futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { + ThreadPool::ScopedSearchOmpSetter setter(1); + task_with_ordered_iterator(idx); + })); } } else { for (size_t i = 0; i < nq; i++) { - futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push( - [&, idx = i]() { task_with_unordered_iterator(idx); })); + futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { + ThreadPool::ScopedSearchOmpSetter setter(1); + task_with_unordered_iterator(idx); + })); } } WaitAllSuccess(futs); @@ -437,6 +445,10 @@ class IndexNode : public Object { // with quantization, override `raw_distance`. // Internally, this structure uses the same priority queue class, but may multiply all // incoming distances to (-1) value in order to turn max priority queue into a min one. +// If use_knowhere_search_pool is True (the default), the iterator->Next() will be scheduled by the +// knowhere_search_thread_pool. +// If False, will Not involve thread scheduling internally, so please take caution. +template class IndexIterator : public IndexNode::iterator { public: IndexIterator(bool larger_is_closer, float refine_ratio = 0.0f, bool retain_iterator_order = false) @@ -449,7 +461,7 @@ class IndexIterator : public IndexNode::iterator { std::pair Next() override { if (!initialized_) { - throw std::runtime_error("Next should not be called before initialization"); + initialize(); } auto& q = refined_res_.empty() ? res_ : refined_res_; if (q.empty()) { @@ -457,28 +469,46 @@ class IndexIterator : public IndexNode::iterator { } auto ret = q.top(); q.pop(); - UpdateNext(); - if (retain_iterator_order_) { - while (HasNext()) { - auto& q = refined_res_.empty() ? res_ : refined_res_; - auto next_ret = q.top(); - // with the help of `sign_`, both `res_` and `refine_res` are min-heap. - // such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`. - // just make sure that the next value is greater than or equal to the current value. - if (next_ret.val >= ret.val) { - break; + + auto update_next_func = [&]() { + UpdateNext(); + if (retain_iterator_order_) { + while (HasNext()) { + auto& q = refined_res_.empty() ? res_ : refined_res_; + auto next_ret = q.top(); + // with the help of `sign_`, both `res_` and `refine_res` are min-heap. + // such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`. + // just make sure that the next value is greater than or equal to the current value. + if (next_ret.val >= ret.val) { + break; + } + q.pop(); + UpdateNext(); } - q.pop(); - UpdateNext(); } + }; + if constexpr (use_knowhere_search_pool) { +#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) + std::vector> futs; + futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() { + ThreadPool::ScopedSearchOmpSetter setter(1); + update_next_func(); + })); + WaitAllSuccess(futs); +#else + update_next_func(); +#endif + } else { + update_next_func(); } + return std::make_pair(ret.id, ret.val * sign_); } [[nodiscard]] bool HasNext() override { if (!initialized_) { - throw std::runtime_error("HasNext should not be called before initialization"); + initialize(); } return !res_.empty() || !refined_res_.empty(); } @@ -544,42 +574,90 @@ class IndexIterator : public IndexNode::iterator { }; // An iterator implementation that accepts a list of distances and ids and returns them in order. +template class PrecomputedDistanceIterator : public IndexNode::iterator { public: - PrecomputedDistanceIterator(std::vector&& distances_ids, bool larger_is_closer) - : larger_is_closer_(larger_is_closer), results_(std::move(distances_ids)) { - sort_size_ = get_sort_size(results_.size()); - sort_next(); - } - - // Construct an iterator from a list of distances with index being id, filtering out zero distances. - PrecomputedDistanceIterator(const std::vector& distances, bool larger_is_closer) - : larger_is_closer_(larger_is_closer) { - // 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non zero - // dims out of 30000 total dims) sharing at least 1 common non-zero dimension. - results_.reserve(distances.size() * 0.3); - for (size_t i = 0; i < distances.size(); i++) { - if (distances[i] != 0) { - results_.emplace_back((int64_t)i, distances[i]); - } - } - sort_size_ = get_sort_size(results_.size()); - sort_next(); + PrecomputedDistanceIterator(Compute_Dist_Func compute_dist_func, bool larger_is_closer) + : compute_dist_func_(compute_dist_func), larger_is_closer_(larger_is_closer) { } std::pair Next() override { - sort_next(); + if (!initialized_) { + initialize(); + } + if constexpr (use_knowhere_search_pool) { +#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) + std::vector> futs; + futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() { + ThreadPool::ScopedSearchOmpSetter setter(1); + sort_next(); + })); + WaitAllSuccess(futs); +#else + sort_next(); +#endif + } else { + sort_next(); + } auto& result = results_[next_++]; return std::make_pair(result.id, result.val); } [[nodiscard]] bool HasNext() override { + if (!initialized_) { + initialize(); + } return next_ < results_.size() && results_[next_].id != -1; } + void + initialize() { + if (initialized_) { + throw std::runtime_error("initialize should not be called twice"); + } + if constexpr (use_knowhere_search_pool) { +#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) + std::vector> futs; + futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() { + ThreadPool::ScopedSearchOmpSetter setter(1); + compute_all_dist_ids(); + })); + WaitAllSuccess(futs); +#else + compute_all_dist_ids(); +#endif + } else { + compute_all_dist_ids(); + } + sort_next(); + initialized_ = true; + } + private: + void + compute_all_dist_ids() { + using ReturnType = std::invoke_result_t; + if constexpr (std::is_same_v>) { + results_ = compute_dist_func_(); + } else if constexpr (std::is_same_v>) { + // From a list of distances with index being id, filtering out zero distances. + std::vector dists = compute_dist_func_(); + // 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non + // zero dims out of 30000 total dims) sharing at least 1 common non-zero dimension. + results_.reserve(dists.size() * 0.3); + for (size_t i = 0; i < dists.size(); i++) { + if (dists[i] != 0) { + results_.emplace_back((int64_t)i, dists[i]); + } + } + } else { + throw std::runtime_error("unknown compute_dist_func"); + } + sort_size_ = get_sort_size(results_.size()); + } + static inline size_t get_sort_size(size_t rows) { return std::max((size_t)50000, rows / 10); @@ -602,8 +680,10 @@ class PrecomputedDistanceIterator : public IndexNode::iterator { sorted_ = current_end; } - const bool larger_is_closer_; + Compute_Dist_Func compute_dist_func_; + const bool larger_is_closer_; + bool initialized_ = false; std::vector results_; size_t next_ = 0; size_t sorted_ = 0; diff --git a/include/knowhere/index/index_node_data_mock_wrapper.h b/include/knowhere/index/index_node_data_mock_wrapper.h index 0d03c9bcb..3e19bde72 100644 --- a/include/knowhere/index/index_node_data_mock_wrapper.h +++ b/include/knowhere/index/index_node_data_mock_wrapper.h @@ -42,7 +42,8 @@ class IndexNodeDataMockWrapper : public IndexNode { RangeSearch(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override; expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override; + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override; expected GetVectorByIds(const DataSetPtr dataset) const override; diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index d0c5a1750..2fc2dc2aa 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -748,13 +748,9 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d template expected> BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset) { - auto xb = base_dataset->GetTensor(); + const BitsetView& bitset, bool use_knowhere_search_pool) { auto nb = base_dataset->GetRows(); auto dim = base_dataset->GetDim(); - auto xb_id_offset = base_dataset->GetTensorBeginId(); - - auto xq = query_dataset->GetTensor(); auto nq = query_dataset->GetRows(); BruteForceConfig cfg; std::string msg; @@ -786,84 +782,88 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da #endif faiss::MetricType faiss_metric_type = result.value(); bool is_cosine = IsMetricType(metric_str, metric::COSINE); - - std::unique_ptr norms = is_cosine ? GetVecNorms(base_dataset) : nullptr; - auto pool = ThreadPool::GetGlobalSearchThreadPool(); + auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine; auto vec = std::vector(nq, nullptr); - std::vector> futs; - futs.reserve(nq); - - for (int i = 0; i < nq; ++i) { - futs.emplace_back(pool->push([&, index = i] { - ThreadPool::ScopedSearchOmpSetter setter(1); - - BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); - [[maybe_unused]] faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine; - auto max_dis = larger_is_closer ? std::numeric_limits::lowest() : std::numeric_limits::max(); - std::vector distances_ids(nb, {-1, max_dis}); - [[maybe_unused]] auto cur_query = (const DataType*)xq + dim * index; - switch (faiss_metric_type) { - case faiss::METRIC_L2: { - if constexpr (std::is_same_v) { - faiss::all_L2sqr(cur_query, (const float*)xb, dim, 1, nb, distances_ids, nullptr, id_selector); - } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { - faiss::half_precision_floating_point_all_L2sqr(cur_query, (const DataType*)xb, dim, 1, nb, - distances_ids, nullptr, id_selector); - } else { - LOG_KNOWHERE_ERROR_ << "Metric L2 only used in floating point data"; - return Status::faiss_inner_error; - } - break; - } - case faiss::METRIC_INNER_PRODUCT: { - if (is_cosine) { + std::shared_ptr norms = GetVecNorms(base_dataset); + + try { + for (int i = 0; i < nq; ++i) { + auto compute_dist_func = [=]() -> std::vector { + auto xb = base_dataset->GetTensor(); + auto xq = query_dataset->GetTensor(); + auto xb_id_offset = base_dataset->GetTensorBeginId(); + BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); + [[maybe_unused]] faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; + auto max_dis = + larger_is_closer ? std::numeric_limits::lowest() : std::numeric_limits::max(); + std::vector distances_ids(nb, {-1, max_dis}); + [[maybe_unused]] auto cur_query = (const DataType*)xq + dim * i; + switch (faiss_metric_type) { + case faiss::METRIC_L2: { if constexpr (std::is_same_v) { - auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::all_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, - distances_ids, id_selector); + faiss::all_L2sqr(cur_query, (const float*)xb, dim, 1, nb, distances_ids, nullptr, + id_selector); } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { - // normalize query vector may cause presision loss, so div query norms in apply function - - faiss::half_precision_floating_point_all_cosine(cur_query, (const DataType*)xb, norms.get(), - dim, 1, nb, distances_ids, id_selector); + faiss::half_precision_floating_point_all_L2sqr(cur_query, (const DataType*)xb, dim, 1, nb, + distances_ids, nullptr, id_selector); } else { - LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; - return Status::faiss_inner_error; + LOG_KNOWHERE_ERROR_ << "Metric L2 only used in floating point data"; + throw std::runtime_error("Metric L2 only used in floating point data"); } - } else { - if constexpr (std::is_same_v) { - faiss::all_inner_product(cur_query, (const float*)xb, dim, 1, nb, distances_ids, - id_selector); - } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { - faiss::half_precision_floating_point_all_inner_product(cur_query, (const DataType*)xb, dim, - 1, nb, distances_ids, id_selector); + break; + } + case faiss::METRIC_INNER_PRODUCT: { + if (is_cosine) { + if constexpr (std::is_same_v) { + auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); + faiss::all_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, + distances_ids, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + // normalize query vector may cause presision loss, so div query norms in apply function + + faiss::half_precision_floating_point_all_cosine(cur_query, (const DataType*)xb, + norms.get(), dim, 1, nb, distances_ids, + id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; + throw std::runtime_error("Metric Cosine only used in floating point data"); + } } else { - LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; - return Status::faiss_inner_error; + if constexpr (std::is_same_v) { + faiss::all_inner_product(cur_query, (const float*)xb, dim, 1, nb, distances_ids, + id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_all_inner_product( + cur_query, (const DataType*)xb, dim, 1, nb, distances_ids, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; + throw std::runtime_error("Metric Cosine only used in floating point data"); + } } + break; + } + default: { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value(); + throw std::runtime_error("Invalid metric type"); } - break; - } - default: { - LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value(); - return Status::invalid_metric_type; } - } - if (xb_id_offset != 0) { - for (auto& distances_id : distances_ids) { - distances_id.id = distances_id.id == -1 ? -1 : distances_id.id + xb_id_offset; + if (xb_id_offset != 0) { + for (auto& distances_id : distances_ids) { + distances_id.id = distances_id.id == -1 ? -1 : distances_id.id + xb_id_offset; + } } + return distances_ids; + }; + if (use_knowhere_search_pool) { + vec[i] = std::make_shared>( + compute_dist_func, larger_is_closer); + } else { + vec[i] = std::make_shared>( + compute_dist_func, larger_is_closer); } - vec[index] = std::make_shared(std::move(distances_ids), larger_is_closer); - - return Status::success; - })); - } - - auto ret = WaitAllSuccess(futs); - if (ret != Status::success) { - return expected>::Err(ret, "failed to brute force search for iterator"); + } + } catch (const std::exception& e) { + return expected>::Err(Status::brute_force_inner_error, e.what()); } #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) @@ -880,12 +880,9 @@ template <> expected> BruteForce::AnnIterator>(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset) { - auto base = static_cast*>(base_dataset->GetTensor()); + const BitsetView& bitset, bool use_knowhere_search_pool) { auto rows = base_dataset->GetRows(); auto xb_id_offset = base_dataset->GetTensorBeginId(); - - auto xq = static_cast*>(query_dataset->GetTensor()); auto nq = query_dataset->GetRows(); BruteForceConfig cfg; @@ -925,37 +922,47 @@ BruteForce::AnnIterator>(const DataSetPtr bas } auto computer = computer_or.value(); - auto pool = ThreadPool::GetGlobalSearchThreadPool(); auto vec = std::vector(nq, nullptr); - std::vector> futs; - futs.reserve(nq); - for (int64_t i = 0; i < nq; ++i) { - futs.emplace_back(pool->push([&, index = i] { - const auto& row = xq[index]; - std::vector distances_ids; - if (row.size() > 0) { - for (int64_t j = 0; j < rows; ++j) { - auto xb_id = j + xb_id_offset; - if (!bitset.empty() && bitset.test(xb_id)) { - continue; - } - float row_sum = 0; - if (is_bm25) { - for (size_t k = 0; k < base[j].size(); ++k) { - auto [d, v] = base[j][k]; - row_sum += v; + try { + for (int64_t i = 0; i < nq; ++i) { + auto compute_dist_func = [=]() -> std::vector { + auto xq = static_cast*>(query_dataset->GetTensor()); + auto base = static_cast*>(base_dataset->GetTensor()); + const auto& row = xq[i]; + std::vector distances_ids; + if (row.size() > 0) { + for (int64_t j = 0; j < rows; ++j) { + auto xb_id = j + xb_id_offset; + if (!bitset.empty() && bitset.test(xb_id)) { + continue; + } + float row_sum = 0; + if (is_bm25) { + for (size_t k = 0; k < base[j].size(); ++k) { + auto [d, v] = base[j][k]; + row_sum += v; + } + } + auto dist = row.dot(base[j], computer, row_sum); + if (dist > 0) { + distances_ids.emplace_back(xb_id, dist); } - } - auto dist = row.dot(base[j], computer, row_sum); - if (dist > 0) { - distances_ids.emplace_back(xb_id, dist); } } + return distances_ids; + }; + + if (use_knowhere_search_pool) { + vec[i] = std::make_shared>( + compute_dist_func, true); + } else { + vec[i] = std::make_shared>( + compute_dist_func, true); } - vec[index] = std::make_shared(std::move(distances_ids), true); - })); + } + } catch (const std::exception& e) { + return expected>::Err(Status::brute_force_inner_error, e.what()); } - WaitAllSuccess(futs); #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) // LCOV_EXCL_START @@ -1026,12 +1033,15 @@ knowhere::BruteForce::RangeSearch>(const know template knowhere::expected> knowhere::BruteForce::AnnIterator(const knowhere::DataSetPtr base_dataset, const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); + const knowhere::Json& config, const knowhere::BitsetView& bitset, + bool use_knowhere_search_pool = true); template knowhere::expected> knowhere::BruteForce::AnnIterator(const knowhere::DataSetPtr base_dataset, const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); + const knowhere::Json& config, const knowhere::BitsetView& bitset, + bool use_knowhere_search_pool = true); template knowhere::expected> knowhere::BruteForce::AnnIterator(const knowhere::DataSetPtr base_dataset, const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); + const knowhere::Json& config, const knowhere::BitsetView& bitset, + bool use_knowhere_search_pool = true); diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index 0987f0317..54db80654 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -151,14 +151,16 @@ class DiskANNIndexNode : public IndexNode { } expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override; + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override; private: - class iterator : public IndexIterator { + template + class iterator : public IndexIterator { public: iterator(const bool transform, const DataType* query_data, const uint64_t lsearch, const uint64_t beam_width, const float filter_ratio, const knowhere::BitsetView& bitset, diskann::PQFlashIndex* index) - : IndexIterator(transform), + : IndexIterator(transform), index_(index), transform_(transform), workspace_(index_->getIteratorWorkspace(query_data, lsearch, beam_width, filter_ratio, bitset)) { @@ -552,18 +554,17 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, std::shared_ptr template expected> -DiskANNIndexNode::AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, - const BitsetView& bitset) const { +DiskANNIndexNode::AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const { if (!is_prepared_.load() || !pq_flash_index_) { LOG_KNOWHERE_ERROR_ << "Failed to load diskann."; - return expected>>::Err(Status::empty_index, - "DiskANN not loaded"); + return expected>::Err(Status::empty_index, "DiskANN not loaded"); } auto search_conf = static_cast(*cfg); if (!CheckMetric(search_conf.metric_type.value())) { - return expected>>::Err(Status::invalid_metric_type, - "unsupported metric type"); + return expected>::Err(Status::invalid_metric_type, + "unsupported metric type"); } constexpr uint64_t k_lsearch_iterator = 32; @@ -575,25 +576,25 @@ DiskANNIndexNode::AnnIterator(const DataSetPtr dataset, std::unique_pt auto dim = dataset->GetDim(); auto xq = dataset->GetTensor(); - std::vector> futs; - futs.reserve(nq); auto vec = std::vector(nq, nullptr); auto metric = search_conf.metric_type.value(); bool transform = metric != knowhere::metric::L2; - for (int i = 0; i < nq; i++) { - futs.emplace_back(search_pool_->push([&, id = i]() { - auto single_query = (DataType*)xq + id * dim; - auto it = std::make_shared(transform, single_query, lsearch, beamwidth, filter_ratio, bitset, - pq_flash_index_.get()); - it->initialize(); - vec[id] = it; - })); - } - - if (TryDiskANNCall([&]() { WaitAllSuccess(futs); }) != Status::success) { - return expected>>::Err(Status::diskann_inner_error, - "some ann-iterator failed"); + try { + for (int i = 0; i < nq; i++) { + auto single_query = (DataType*)xq + i * dim; + if (use_knowhere_search_pool) { + auto it = std::make_shared>(transform, single_query, lsearch, beamwidth, filter_ratio, + bitset, pq_flash_index_.get()); + vec[i] = it; + } else { + auto it = std::make_shared>(transform, single_query, lsearch, beamwidth, filter_ratio, + bitset, pq_flash_index_.get()); + vec[i] = it; + } + } + } catch (const std::exception& e) { + return expected>::Err(Status::diskann_inner_error, e.what()); } return vec; diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 28f0307a8..625adcaee 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -530,13 +530,13 @@ struct FaissHnswIteratorWorkspace { }; // Contains an iterator logic -class FaissHnswIterator : public IndexIterator { +template +class FaissHnswIterator : public IndexIterator { public: FaissHnswIterator(const std::shared_ptr& index_in, std::unique_ptr&& query_in, const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer, const float refine_ratio = 0.5f) - : IndexIterator(larger_is_closer, refine_ratio), index{index_in} { - // + : IndexIterator(larger_is_closer, refine_ratio), index{index_in} { workspace.accumulated_alpha = (bitset_in.count() >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold)) ? std::numeric_limits::max() @@ -1219,7 +1219,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { public: // expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override { if (index == nullptr) { LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; return expected>::Err(Status::empty_index, "index not loaded"); @@ -1245,49 +1246,48 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto ef = hnsw_cfg.ef.value_or(kIteratorSeedEf); try { - std::vector> futs; - futs.reserve(n_queries); for (int64_t i = 0; i < n_queries; i++) { - futs.emplace_back(search_pool->push([&, idx = i] { - // The query data is always cloned - std::unique_ptr cur_query = std::make_unique(dim); - - switch (data_format) { - case DataFormatEnum::fp32: - std::copy_n(reinterpret_cast(data) + idx * dim, dim, cur_query.get()); - break; - case DataFormatEnum::fp16: - case DataFormatEnum::bf16: - case DataFormatEnum::int8: - convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim); - break; - default: - // invalid one. Should not be triggered, bcz input parameters are validated - throw; - } - - const bool should_use_refine = (dynamic_cast(index.get()) != nullptr); - - const float iterator_refine_ratio = - should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0; + // The query data is always cloned + std::unique_ptr cur_query = std::make_unique(dim); + + switch (data_format) { + case DataFormatEnum::fp32: + std::copy_n(reinterpret_cast(data) + i * dim, dim, cur_query.get()); + break; + case DataFormatEnum::fp16: + case DataFormatEnum::bf16: + case DataFormatEnum::int8: + convert_rows_to_fp32(data, cur_query.get(), data_format, i, 1, dim); + break; + default: + // invalid one. Should not be triggered, bcz input parameters are validated + throw; + } - // create an iterator and initialize it - auto it = - std::make_shared(index, std::move(cur_query), bitset, ef, larger_is_closer, - // // refine is not needed for flat - // hnsw_cfg.iterator_refine_ratio.value_or(0.5f) - iterator_refine_ratio); + const bool should_use_refine = (dynamic_cast(index.get()) != nullptr); - it->initialize(); + const float iterator_refine_ratio = + should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0; + // create an iterator and initialize it + if (use_knowhere_search_pool) { + auto it = std::make_shared>(index, std::move(cur_query), bitset, ef, + larger_is_closer, + // // refine is not needed for flat + // hnsw_cfg.iterator_refine_ratio.value_or(0.5f) + iterator_refine_ratio); // store - vec[idx] = it; - })); + vec[i] = it; + } else { + auto it = std::make_shared>( + index, std::move(cur_query), bitset, ef, larger_is_closer, + // // refine is not needed for flat + // hnsw_cfg.iterator_refine_ratio.value_or(0.5f) + iterator_refine_ratio); + // store + vec[i] = it; + } } - - // wait for the completion - WaitAllSuccess(futs); - } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected>::Err(Status::faiss_inner_error, e.what()); @@ -1548,11 +1548,12 @@ class HNSWIndexNodeWithFallback : public IndexNode { } expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override { if (use_base_index) { - return base_index->AnnIterator(dataset, std::move(cfg), bitset); + return base_index->AnnIterator(dataset, std::move(cfg), bitset, use_knowhere_search_pool); } else { - return fallback_search_index->AnnIterator(dataset, std::move(cfg), bitset); + return fallback_search_index->AnnIterator(dataset, std::move(cfg), bitset, use_knowhere_search_pool); } } diff --git a/src/index/hnsw/hnsw.h b/src/index/hnsw/hnsw.h index 687561fcf..48d257c06 100644 --- a/src/index/hnsw/hnsw.h +++ b/src/index/hnsw/hnsw.h @@ -260,15 +260,17 @@ class HnswIndexNode : public IndexNode { } private: - class iterator : public IndexIterator { + template + class iterator : public IndexIterator { public: iterator(const hnswlib::HierarchicalNSW* index, const char* query, const bool transform, const BitsetView& bitset, const size_t ef = kIteratorSeedEf, const float refine_ratio = 0.5f) - : IndexIterator(transform, (hnswlib::HierarchicalNSW::sq_enabled && - hnswlib::HierarchicalNSW::has_raw_data) - ? refine_ratio - : 0.0f), + : IndexIterator( + transform, (hnswlib::HierarchicalNSW::sq_enabled && + hnswlib::HierarchicalNSW::has_raw_data) + ? refine_ratio + : 0.0f), index_(index), transform_(transform), workspace_(index_->getIteratorWorkspace(query, ef, bitset)) { @@ -303,7 +305,8 @@ class HnswIndexNode : public IndexNode { public: expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool = true) const override { if (!index_) { LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; return expected>::Err(Status::empty_index, "index not loaded"); @@ -317,19 +320,22 @@ class HnswIndexNode : public IndexNode { bool transform = (index_->metric_type_ == hnswlib::Metric::INNER_PRODUCT || index_->metric_type_ == hnswlib::Metric::COSINE); auto vec = std::vector(nq, nullptr); - std::vector> futs; - futs.reserve(nq); - for (int i = 0; i < nq; ++i) { - futs.emplace_back(search_pool_->push([&, i]() { + try { + for (int i = 0; i < nq; ++i) { auto single_query = (const char*)xq + i * index_->data_size_; - auto it = std::make_shared(this->index_, single_query, transform, bitset, ef, - hnsw_cfg.iterator_refine_ratio.value()); - it->initialize(); - vec[i] = it; - })); + if (use_knowhere_search_pool) { + auto it = std::make_shared>(this->index_, single_query, transform, bitset, ef, + hnsw_cfg.iterator_refine_ratio.value()); + vec[i] = it; + } else { + auto it = std::make_shared>(this->index_, single_query, transform, bitset, ef, + hnsw_cfg.iterator_refine_ratio.value()); + vec[i] = it; + } + } + } catch (const std::exception& e) { + return expected>::Err(Status::hnsw_inner_error, e.what()); } - // wait for initial search(in top layers and search for ef in base layer) to finish - WaitAllSuccess(futs); return vec; } diff --git a/src/index/index.cc b/src/index/index.cc index 2e6f7e421..d6da77f63 100644 --- a/src/index/index.cc +++ b/src/index/index.cc @@ -168,7 +168,8 @@ Index::Search(const DataSetPtr dataset, const Json& json, const BitsetView& b template inline expected>> -Index::AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset_) const { +Index::AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset_, + bool use_knowhere_search_pool) const { auto cfg = this->node->CreateConfig(); std::string msg; Status status = LoadConfig(cfg.get(), json, knowhere::ITERATOR, "Iterator", &msg); @@ -191,12 +192,12 @@ Index::AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetVi #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) // note that this time includes only the initial search phase of iterator. TimeRecorder rc("AnnIterator"); - auto res = this->node->AnnIterator(dataset, std::move(cfg), bitset); + auto res = this->node->AnnIterator(dataset, std::move(cfg), bitset, use_knowhere_search_pool); auto time = rc.ElapseFromBegin("done"); time *= 0.001; // convert to ms knowhere_search_latency.Observe(time); #else - auto res = this->node->AnnIterator(dataset, std::move(cfg), bitset); + auto res = this->node->AnnIterator(dataset, std::move(cfg), bitset, use_knowhere_search_pool); #endif return res; } diff --git a/src/index/index_node_data_mock_wrapper.cc b/src/index/index_node_data_mock_wrapper.cc index 7c9627071..c1b4f2ebc 100644 --- a/src/index/index_node_data_mock_wrapper.cc +++ b/src/index/index_node_data_mock_wrapper.cc @@ -60,9 +60,9 @@ IndexNodeDataMockWrapper::RangeSearch(const DataSetPtr dataset, std::u template expected> IndexNodeDataMockWrapper::AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, - const BitsetView& bitset) const { + const BitsetView& bitset, bool use_knowhere_search_pool) const { auto ds_ptr = ConvertFromDataTypeIfNeeded(dataset); - return index_node_->AnnIterator(ds_ptr, std::move(cfg), bitset); + return index_node_->AnnIterator(ds_ptr, std::move(cfg), bitset, use_knowhere_search_pool); } template diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index eab652362..957065e26 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -75,7 +75,8 @@ class IvfIndexNode : public IndexNode { expected RangeSearch(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override; expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override; + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override; expected GetVectorByIds(const DataSetPtr dataset) const override; @@ -315,11 +316,14 @@ class IvfIndexNode : public IndexNode { // iterator will own the copied_norm_query // TODO: iterator should copy and own query data. // TODO: If SCANN support Iterator, raw_distance() function should be override. - class iterator : public IndexIterator { + template + class iterator : public IndexIterator { public: iterator(const IndexType* index, std::unique_ptr&& copied_query, const BitsetView& bitset, size_t nprobe, bool larger_is_closer, const float refine_ratio = 0.5f) - : IndexIterator(larger_is_closer, refine_ratio), index_(index), copied_query_(std::move(copied_query)) { + : IndexIterator(larger_is_closer, refine_ratio), + index_(index), + copied_query_(std::move(copied_query)) { if (!bitset.empty()) { bw_idselector_ = std::make_unique(bitset); ivf_search_params_.sel = bw_idselector_.get(); @@ -334,7 +338,7 @@ class IvfIndexNode : public IndexNode { protected: void next_batch(std::function&)> batch_handler) override { - index_->getIteratorNextBatch(workspace_.get(), res_.size()); + index_->getIteratorNextBatch(workspace_.get(), this->res_.size()); batch_handler(workspace_->dists); workspace_->dists.clear(); } @@ -342,7 +346,7 @@ class IvfIndexNode : public IndexNode { float raw_distance(int64_t id) override { if constexpr (std::is_same_v) { - if (refine_) { + if (this->refine_) { return workspace_->dis_refine->operator()(id); } else { throw std::runtime_error("raw_distance should not be called if refine == false"); @@ -922,7 +926,7 @@ IvfIndexNode::RangeSearch(const DataSetPtr dataset, std::un template expected> IvfIndexNode::AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, - const BitsetView& bitset) const { + const BitsetView& bitset, bool use_knowhere_search_pool) const { if (!index_) { LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; return expected>::Err(Status::empty_index, "index not loaded"); @@ -961,32 +965,29 @@ IvfIndexNode::AnnIterator(const DataSetPtr dataset, std::un } } try { - std::vector> futs; - futs.reserve(rows); for (int i = 0; i < rows; ++i) { - futs.emplace_back(search_pool_->push([&, index = i] { - auto cur_query = (const float*)data + index * dim; - // if cosine, need normalize - std::unique_ptr copied_query = nullptr; - if (is_cosine) { - copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - } else { - copied_query = std::make_unique(dim); - std::copy_n(cur_query, dim, copied_query.get()); - } + auto cur_query = (const float*)data + i * dim; + // if cosine, need normalize + std::unique_ptr copied_query = nullptr; + if (is_cosine) { + copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); + } else { + copied_query = std::make_unique(dim); + std::copy_n(cur_query, dim, copied_query.get()); + } + if (use_knowhere_search_pool) { // iterator only own the copied_query. - auto it = std::make_shared(index_.get(), std::move(copied_query), bitset, nprobe, - larger_is_closer, iterator_refine_ratio); - it->initialize(); - vec[index] = it; - })); + auto it = std::make_shared>(index_.get(), std::move(copied_query), bitset, nprobe, + larger_is_closer, iterator_refine_ratio); + vec[i] = it; + } else { + auto it = std::make_shared>(index_.get(), std::move(copied_query), bitset, nprobe, + larger_is_closer, iterator_refine_ratio); + vec[i] = it; + } } - // wait for the completion - // initial search - scan at least (nprobe/nlist)% codes - WaitAllSuccess(futs); - } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected>::Err(Status::faiss_inner_error, e.what()); diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc index 7e2e3e4b5..6fbd5d64f 100644 --- a/src/index/sparse/sparse_index_node.cc +++ b/src/index/sparse/sparse_index_node.cc @@ -123,12 +123,13 @@ class SparseInvertedIndexNode : public IndexNode { } private: - class RefineIterator : public IndexIterator { + template + class RefineIterator : public IndexIterator { public: RefineIterator(const sparse::BaseInvertedIndex* index, sparse::SparseRow&& query, - std::shared_ptr precomputed_it, + std::shared_ptr> precomputed_it, const sparse::DocValueComputer& computer, const float refine_ratio = 0.5f) - : IndexIterator(true, refine_ratio), + : IndexIterator(true, refine_ratio), index_(index), query_(std::move(query)), computer_(computer), @@ -158,14 +159,15 @@ class SparseInvertedIndexNode : public IndexNode { const sparse::BaseInvertedIndex* index_; sparse::SparseRow query_; const sparse::DocValueComputer computer_; - std::shared_ptr precomputed_it_; + std::shared_ptr> precomputed_it_; bool first_return_ = true; }; public: // TODO: for now inverted index and wand use the same impl for AnnIterator. [[nodiscard]] expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr config, const BitsetView& bitset) const override { + AnnIterator(const DataSetPtr dataset, std::unique_ptr config, const BitsetView& bitset, + bool use_knowhere_search_pool) const override { if (!index_) { LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; return expected>>::Err(Status::empty_index, @@ -186,23 +188,39 @@ class SparseInvertedIndexNode : public IndexNode { const bool approximated = index_->IsApproximated() || drop_ratio_search > 0; auto vec = std::vector>(nq, nullptr); - std::vector> futs; - futs.reserve(nq); - for (int i = 0; i < nq; ++i) { - futs.emplace_back(search_pool_->push([&, i]() { - auto it = std::make_shared( - index_->GetAllDistances(queries[i], drop_ratio_search, bitset, computer), true); - if (!approximated || queries[i].size() == 0) { - vec[i] = it; + try { + for (int i = 0; i < nq; ++i) { + auto compute_dist_func = [=]() -> std::vector { + auto queries = static_cast*>(dataset->GetTensor()); + return index_->GetAllDistances(queries[i], drop_ratio_search, bitset, computer); + }; + if (use_knowhere_search_pool) { + if (!approximated || queries[i].size() == 0) { + auto it = std::make_shared>( + compute_dist_func, true); + vec[i] = it; + } else { + sparse::SparseRow query_copy(queries[i]); + auto it = std::make_shared>( + compute_dist_func, true); + vec[i] = std::make_shared>( + index_, std::move(query_copy), it, computer); + } } else { - sparse::SparseRow query_copy(queries[i]); - auto refine_it = std::make_shared(index_, std::move(query_copy), it, computer); - refine_it->initialize(); - vec[i] = std::move(refine_it); + auto it = std::make_shared>( + compute_dist_func, true); + if (!approximated || queries[i].size() == 0) { + vec[i] = it; + } else { + sparse::SparseRow query_copy(queries[i]); + vec[i] = std::make_shared>( + index_, std::move(query_copy), it, computer); + } } - })); + } + } catch (const std::exception& e) { + return expected>::Err(Status::sparse_inner_error, e.what()); } - WaitAllSuccess(futs); return vec; } @@ -421,14 +439,16 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode { } expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override { ReadPermission permission(*this); // Always uses PrecomputedDistanceIterator for SparseInvertedIndexNodeCC: // If we want to use RefineIterator, it needs to get another ReadPermission when calling // index_->GetRawDistance(). If an Add task is added in between, there will be a deadlock. auto config = static_cast(*cfg); config.drop_ratio_search = 0.0f; - return SparseInvertedIndexNode::AnnIterator(dataset, std::move(cfg), bitset); + return SparseInvertedIndexNode::AnnIterator(dataset, std::move(cfg), bitset, + use_knowhere_search_pool); } expected diff --git a/tests/ut/test_index_node.cc b/tests/ut/test_index_node.cc index 553d71543..4084d338f 100644 --- a/tests/ut/test_index_node.cc +++ b/tests/ut/test_index_node.cc @@ -57,7 +57,8 @@ class BaseFlatIndexNode : public IndexNode { } expected> - AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override { LOG_KNOWHERE_INFO_ << "BaseFlatIndexNode::AnnIterator()"; return expected>::Err(Status::not_implemented, "BaseFlatIndexNode::AnnIterator() not implemented");