-
Notifications
You must be signed in to change notification settings - Fork 83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use search_pool to control iterator->Next() #1008
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -146,7 +146,8 @@ class IndexNode : public Object { | |
using IteratorPtr = std::shared_ptr<iterator>; | ||
|
||
virtual expected<std::vector<IteratorPtr>> | ||
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const { | ||
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset, | ||
bool use_knowhere_search_pool = true) const { | ||
return expected<std::vector<std::shared_ptr<iterator>>>::Err( | ||
Status::not_implemented, "annIterator not supported for current index type"); | ||
} | ||
|
@@ -202,7 +203,10 @@ class IndexNode : public Object { | |
return GenResultDataSet(nq, std::move(range_search_result)); | ||
} | ||
|
||
auto its_or = AnnIterator(dataset, std::move(cfg), bitset); | ||
// RangeSearch already uses the search pool to process the individual queries concurrently. | |
// To prevent potential deadlocks, the iterator created for a single query must not schedule its | |
// Next() calls on the search pool again, so use_knowhere_search_pool is passed as false here. | |
auto its_or = AnnIterator(dataset, std::move(cfg), bitset, false); | ||
if (!its_or.has_value()) { | ||
return expected<DataSetPtr>::Err(its_or.error(), | ||
"RangeSearch failed due to AnnIterator failure: " + its_or.what()); | ||
|
@@ -290,13 +294,17 @@ class IndexNode : public Object { | |
futs.reserve(nq); | ||
if (retain_iterator_order) { | ||
for (size_t i = 0; i < nq; i++) { | ||
futs.emplace_back( | ||
ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { task_with_ordered_iterator(idx); })); | ||
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { | ||
ThreadPool::ScopedSearchOmpSetter setter(1); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we rewrite the push to include this OMP setter instead of calling it everywhere? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @cqy123456 Is this feasible? Are there any scenarios where we do not need the OMP setter? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OmpSetter is only used in faiss indexes. |
||
task_with_ordered_iterator(idx); | ||
})); | ||
} | ||
} else { | ||
for (size_t i = 0; i < nq; i++) { | ||
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push( | ||
[&, idx = i]() { task_with_unordered_iterator(idx); })); | ||
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { | ||
ThreadPool::ScopedSearchOmpSetter setter(1); | ||
task_with_unordered_iterator(idx); | ||
})); | ||
} | ||
} | ||
WaitAllSuccess(futs); | ||
|
@@ -437,48 +445,71 @@ class IndexNode : public Object { | |
// with quantization, override `raw_distance`. | ||
// Internally, this structure uses the same priority queue class, but may multiply all | ||
// incoming distances to (-1) value in order to turn max priority queue into a min one. | ||
// If use_knowhere_search_pool is True (the default), the iterator->Next() will be scheduled by the | ||
// knowhere_search_thread_pool. | ||
// If false, Next() does not involve any thread scheduling internally, so use it with caution. | |
class IndexIterator : public IndexNode::iterator { | ||
public: | ||
IndexIterator(bool larger_is_closer, float refine_ratio = 0.0f, bool retain_iterator_order = false) | ||
IndexIterator(bool larger_is_closer, bool use_knowhere_search_pool = true, float refine_ratio = 0.0f, | ||
bool retain_iterator_order = false) | ||
: refine_ratio_(refine_ratio), | ||
refine_(refine_ratio != 0.0f), | ||
retain_iterator_order_(retain_iterator_order), | ||
use_knowhere_search_pool_(use_knowhere_search_pool), | ||
sign_(larger_is_closer ? -1 : 1) { | ||
} | ||
|
||
std::pair<int64_t, float> | ||
Next() override { | ||
if (!initialized_) { | ||
throw std::runtime_error("Next should not be called before initialization"); | ||
initialize(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Curious why we need this. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Throwing an error is a bit too strict? |
||
} | ||
auto& q = refined_res_.empty() ? res_ : refined_res_; | ||
if (q.empty()) { | ||
throw std::runtime_error("No more elements"); | ||
} | ||
auto ret = q.top(); | ||
q.pop(); | ||
UpdateNext(); | ||
if (retain_iterator_order_) { | ||
while (HasNext()) { | ||
auto& q = refined_res_.empty() ? res_ : refined_res_; | ||
auto next_ret = q.top(); | ||
// with the help of `sign_`, both `res_` and `refine_res` are min-heap. | ||
// such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`. | ||
// just make sure that the next value is greater than or equal to the current value. | ||
if (next_ret.val >= ret.val) { | ||
break; | ||
|
||
auto update_next_func = [&]() { | ||
UpdateNext(); | ||
if (retain_iterator_order_) { | ||
while (HasNext()) { | ||
auto& q = refined_res_.empty() ? res_ : refined_res_; | ||
auto next_ret = q.top(); | ||
// with the help of `sign_`, both `res_` and `refine_res` are min-heap. | ||
// such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`. | ||
// just make sure that the next value is greater than or equal to the current value. | ||
if (next_ret.val >= ret.val) { | ||
break; | ||
} | ||
q.pop(); | ||
UpdateNext(); | ||
} | ||
q.pop(); | ||
UpdateNext(); | ||
} | ||
}; | ||
if (use_knowhere_search_pool_) { | ||
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) | ||
std::vector<folly::Future<folly::Unit>> futs; | ||
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() { | ||
ThreadPool::ScopedSearchOmpSetter setter(1); | ||
update_next_func(); | ||
})); | ||
WaitAllSuccess(futs); | ||
#else | ||
update_next_func(); | ||
#endif | ||
} else { | ||
update_next_func(); | ||
} | ||
|
||
return std::make_pair(ret.id, ret.val * sign_); | ||
} | ||
|
||
[[nodiscard]] bool | ||
HasNext() override { | ||
if (!initialized_) { | ||
throw std::runtime_error("HasNext should not be called before initialization"); | ||
initialize(); | ||
} | ||
return !res_.empty() || !refined_res_.empty(); | ||
} | ||
|
@@ -540,45 +571,79 @@ class IndexIterator : public IndexNode::iterator { | |
|
||
bool initialized_ = false; | ||
bool retain_iterator_order_ = false; | ||
bool use_knowhere_search_pool_ = true; | ||
const int64_t sign_; | ||
}; | ||
|
||
// An iterator implementation that accepts a list of distances and ids and returns them in order. | ||
// An iterator implementation that accepts a function to get distances and ids list and returns them in order. | ||
// We do not directly accept a distance list as input. The main reason for this is to minimize the initialization time | ||
// for all types of iterators in the `ANNIterator` interface, moving heavy computations to the first '->Next()' call. | ||
// This way, the iterator initialization does not need to perform any concurrent acceleration, and the search pool | ||
// only needs to handle the heavy work of `->Next()` | ||
class PrecomputedDistanceIterator : public IndexNode::iterator { | ||
public: | ||
PrecomputedDistanceIterator(std::vector<DistId>&& distances_ids, bool larger_is_closer) | ||
: larger_is_closer_(larger_is_closer), results_(std::move(distances_ids)) { | ||
sort_size_ = get_sort_size(results_.size()); | ||
sort_next(); | ||
} | ||
|
||
// Construct an iterator from a list of distances with index being id, filtering out zero distances. | ||
PrecomputedDistanceIterator(const std::vector<float>& distances, bool larger_is_closer) | ||
: larger_is_closer_(larger_is_closer) { | ||
// 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non zero | ||
// dims out of 30000 total dims) sharing at least 1 common non-zero dimension. | ||
results_.reserve(distances.size() * 0.3); | ||
for (size_t i = 0; i < distances.size(); i++) { | ||
if (distances[i] != 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we do this? I think this There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Submitted a new commit to reduce misunderstanding. |
||
results_.emplace_back((int64_t)i, distances[i]); | ||
} | ||
} | ||
sort_size_ = get_sort_size(results_.size()); | ||
sort_next(); | ||
PrecomputedDistanceIterator(std::function<std::vector<DistId>()> compute_dist_func, bool larger_is_closer, | ||
bool use_knowhere_search_pool = true) | ||
: compute_dist_func_(compute_dist_func), | ||
larger_is_closer_(larger_is_closer), | ||
use_knowhere_search_pool_(use_knowhere_search_pool) { | ||
} | ||
|
||
std::pair<int64_t, float> | ||
Next() override { | ||
sort_next(); | ||
if (!initialized_) { | ||
initialize(); | ||
} | ||
if (use_knowhere_search_pool_) { | ||
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) | ||
std::vector<folly::Future<folly::Unit>> futs; | ||
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() { | ||
ThreadPool::ScopedSearchOmpSetter setter(1); | ||
sort_next(); | ||
})); | ||
WaitAllSuccess(futs); | ||
#else | ||
sort_next(); | ||
#endif | ||
} else { | ||
sort_next(); | ||
} | ||
auto& result = results_[next_++]; | ||
return std::make_pair(result.id, result.val); | ||
} | ||
|
||
[[nodiscard]] bool | ||
HasNext() override { | ||
if (!initialized_) { | ||
initialize(); | ||
} | ||
return next_ < results_.size() && results_[next_].id != -1; | ||
} | ||
|
||
void | ||
initialize() { | ||
if (initialized_) { | ||
throw std::runtime_error("initialize should not be called twice"); | ||
} | ||
if (use_knowhere_search_pool_) { | ||
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) | ||
std::vector<folly::Future<folly::Unit>> futs; | ||
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() { | ||
ThreadPool::ScopedSearchOmpSetter setter(1); | ||
results_ = compute_dist_func_(); | ||
})); | ||
WaitAllSuccess(futs); | ||
#else | ||
results_ = compute_dist_func_(); | ||
#endif | ||
} else { | ||
results_ = compute_dist_func_(); | ||
} | ||
sort_size_ = get_sort_size(results_.size()); | ||
sort_next(); | ||
initialized_ = true; | ||
} | ||
|
||
private: | ||
static inline size_t | ||
get_sort_size(size_t rows) { | ||
|
@@ -602,8 +667,11 @@ class PrecomputedDistanceIterator : public IndexNode::iterator { | |
|
||
sorted_ = current_end; | ||
} | ||
const bool larger_is_closer_; | ||
|
||
std::function<std::vector<DistId>()> compute_dist_func_; | ||
const bool larger_is_closer_; | ||
bool use_knowhere_search_pool_ = true; | ||
bool initialized_ = false; | ||
std::vector<DistId> results_; | ||
size_t next_ = 0; | ||
size_t sorted_ = 0; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a comment about why this parameter is
false
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added.