Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use search_pool to control iterator->Next() #1008

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/knowhere/comp/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BruteForce {
template <typename DataType>
static expected<std::vector<IndexNode::IteratorPtr>>
AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset);
const BitsetView& bitset, bool use_knowhere_search_pool = true);
};

} // namespace knowhere
Expand Down
3 changes: 3 additions & 0 deletions include/knowhere/expected.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ enum class Status {
internal_error = 27,
invalid_serialized_index_type = 28,
sparse_inner_error = 29,
brute_force_inner_error = 30,
};

inline std::string
Expand Down Expand Up @@ -104,6 +105,8 @@ Status2String(knowhere::Status status) {
return "the serialized index type is not recognized";
case knowhere::Status::sparse_inner_error:
return "sparse index inner error";
case knowhere::Status::brute_force_inner_error:
return "brute_force inner error";
default:
return "unexpected status";
}
Expand Down
3 changes: 2 additions & 1 deletion include/knowhere/index/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class Index {
Search(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const;

expected<std::vector<IndexNode::IteratorPtr>>
AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const;
AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset,
bool use_knowhere_search_pool = true) const;

expected<DataSetPtr>
RangeSearch(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const;
Expand Down
154 changes: 111 additions & 43 deletions include/knowhere/index/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ class IndexNode : public Object {
using IteratorPtr = std::shared_ptr<iterator>;

virtual expected<std::vector<IteratorPtr>>
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const {
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset,
bool use_knowhere_search_pool = true) const {
return expected<std::vector<std::shared_ptr<iterator>>>::Err(
Status::not_implemented, "annIterator not supported for current index type");
}
Expand Down Expand Up @@ -202,7 +203,10 @@ class IndexNode : public Object {
return GenResultDataSet(nq, std::move(range_search_result));
}

auto its_or = AnnIterator(dataset, std::move(cfg), bitset);
// The range_search function has utilized the search_pool to concurrently handle various queries.
// To prevent potential deadlocks, the iterator for a single query no longer requires additional thread
// control over the next() call.
auto its_or = AnnIterator(dataset, std::move(cfg), bitset, false);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment about why this parameter is false

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added.

The range_search function has utilized the search_pool to concurrently handle various queries.
To prevent potential deadlocks, the iterator for a single query no longer requires additional thread control over the next() call.

if (!its_or.has_value()) {
return expected<DataSetPtr>::Err(its_or.error(),
"RangeSearch failed due to AnnIterator failure: " + its_or.what());
Expand Down Expand Up @@ -290,13 +294,17 @@ class IndexNode : public Object {
futs.reserve(nq);
if (retain_iterator_order) {
for (size_t i = 0; i < nq; i++) {
futs.emplace_back(
ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { task_with_ordered_iterator(idx); }));
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rewrite the push to include this omp setter instead of calling it everywhere?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cqy123456 Is this feasible? Are there any scenarios where we do not need the OmpSetter?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OmpSetter is only used in faiss indexes.

task_with_ordered_iterator(idx);
}));
}
} else {
for (size_t i = 0; i < nq; i++) {
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push(
[&, idx = i]() { task_with_unordered_iterator(idx); }));
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
task_with_unordered_iterator(idx);
}));
}
}
WaitAllSuccess(futs);
Expand Down Expand Up @@ -437,48 +445,71 @@ class IndexNode : public Object {
// with quantization, override `raw_distance`.
// Internally, this structure uses the same priority queue class, but may multiply all
// incoming distances to (-1) value in order to turn max priority queue into a min one.
// If use_knowhere_search_pool is true (the default), iterator->Next() will be scheduled by the
// knowhere_search_thread_pool.
// If false, Next() will not involve any thread scheduling internally, so please use it with caution.
class IndexIterator : public IndexNode::iterator {
public:
IndexIterator(bool larger_is_closer, float refine_ratio = 0.0f, bool retain_iterator_order = false)
IndexIterator(bool larger_is_closer, bool use_knowhere_search_pool = true, float refine_ratio = 0.0f,
bool retain_iterator_order = false)
: refine_ratio_(refine_ratio),
refine_(refine_ratio != 0.0f),
retain_iterator_order_(retain_iterator_order),
use_knowhere_search_pool_(use_knowhere_search_pool),
sign_(larger_is_closer ? -1 : 1) {
}

std::pair<int64_t, float>
Next() override {
if (!initialized_) {
throw std::runtime_error("Next should not be called before initialization");
initialize();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why we need this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throwing an error is a bit too strict?

}
auto& q = refined_res_.empty() ? res_ : refined_res_;
if (q.empty()) {
throw std::runtime_error("No more elements");
}
auto ret = q.top();
q.pop();
UpdateNext();
if (retain_iterator_order_) {
while (HasNext()) {
auto& q = refined_res_.empty() ? res_ : refined_res_;
auto next_ret = q.top();
// with the help of `sign_`, both `res_` and `refine_res` are min-heap.
// such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`.
// just make sure that the next value is greater than or equal to the current value.
if (next_ret.val >= ret.val) {
break;

auto update_next_func = [&]() {
UpdateNext();
if (retain_iterator_order_) {
while (HasNext()) {
auto& q = refined_res_.empty() ? res_ : refined_res_;
auto next_ret = q.top();
// with the help of `sign_`, both `res_` and `refine_res` are min-heap.
// such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`.
// just make sure that the next value is greater than or equal to the current value.
if (next_ret.val >= ret.val) {
break;
}
q.pop();
UpdateNext();
}
q.pop();
UpdateNext();
}
};
if (use_knowhere_search_pool_) {
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
update_next_func();
}));
WaitAllSuccess(futs);
#else
update_next_func();
#endif
} else {
update_next_func();
}

return std::make_pair(ret.id, ret.val * sign_);
}

[[nodiscard]] bool
HasNext() override {
if (!initialized_) {
throw std::runtime_error("HasNext should not be called before initialization");
initialize();
}
return !res_.empty() || !refined_res_.empty();
}
Expand Down Expand Up @@ -540,45 +571,79 @@ class IndexIterator : public IndexNode::iterator {

bool initialized_ = false;
bool retain_iterator_order_ = false;
bool use_knowhere_search_pool_ = true;
const int64_t sign_;
};

// An iterator implementation that accepts a list of distances and ids and returns them in order.
// An iterator implementation that accepts a function to get distances and ids list and returns them in order.
// We do not directly accept a distance list as input. The main reason for this is to minimize the initialization time
// for all types of iterators in the `ANNIterator` interface, moving heavy computations to the first '->Next()' call.
// This way, the iterator initialization does not need to perform any concurrent acceleration, and the search pool
// only needs to handle the heavy work of `->Next()`
class PrecomputedDistanceIterator : public IndexNode::iterator {
public:
PrecomputedDistanceIterator(std::vector<DistId>&& distances_ids, bool larger_is_closer)
: larger_is_closer_(larger_is_closer), results_(std::move(distances_ids)) {
sort_size_ = get_sort_size(results_.size());
sort_next();
}

// Construct an iterator from a list of distances with index being id, filtering out zero distances.
PrecomputedDistanceIterator(const std::vector<float>& distances, bool larger_is_closer)
: larger_is_closer_(larger_is_closer) {
// 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non zero
// dims out of 30000 total dims) sharing at least 1 common non-zero dimension.
results_.reserve(distances.size() * 0.3);
for (size_t i = 0; i < distances.size(); i++) {
if (distances[i] != 0) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this? I think this PrecomputedDistanceIterator is used for both dense/sparse BF.
@zhengbuqian

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

submitted a new commit to reduce misunderstanding.

results_.emplace_back((int64_t)i, distances[i]);
}
}
sort_size_ = get_sort_size(results_.size());
sort_next();
PrecomputedDistanceIterator(std::function<std::vector<DistId>()> compute_dist_func, bool larger_is_closer,
bool use_knowhere_search_pool = true)
: compute_dist_func_(compute_dist_func),
larger_is_closer_(larger_is_closer),
use_knowhere_search_pool_(use_knowhere_search_pool) {
}

std::pair<int64_t, float>
Next() override {
sort_next();
if (!initialized_) {
initialize();
}
if (use_knowhere_search_pool_) {
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
sort_next();
}));
WaitAllSuccess(futs);
#else
sort_next();
#endif
} else {
sort_next();
}
auto& result = results_[next_++];
return std::make_pair(result.id, result.val);
}

[[nodiscard]] bool
HasNext() override {
if (!initialized_) {
initialize();
}
return next_ < results_.size() && results_[next_].id != -1;
}

void
initialize() {
if (initialized_) {
throw std::runtime_error("initialize should not be called twice");
}
if (use_knowhere_search_pool_) {
#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
results_ = compute_dist_func_();
}));
WaitAllSuccess(futs);
#else
results_ = compute_dist_func_();
#endif
} else {
results_ = compute_dist_func_();
}
sort_size_ = get_sort_size(results_.size());
sort_next();
initialized_ = true;
}

private:
static inline size_t
get_sort_size(size_t rows) {
Expand All @@ -602,8 +667,11 @@ class PrecomputedDistanceIterator : public IndexNode::iterator {

sorted_ = current_end;
}
const bool larger_is_closer_;

std::function<std::vector<DistId>()> compute_dist_func_;
const bool larger_is_closer_;
bool use_knowhere_search_pool_ = true;
bool initialized_ = false;
std::vector<DistId> results_;
size_t next_ = 0;
size_t sorted_ = 0;
Expand Down
3 changes: 2 additions & 1 deletion include/knowhere/index/index_node_data_mock_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class IndexNodeDataMockWrapper : public IndexNode {
RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;

expected<std::vector<IteratorPtr>>
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset,
bool use_knowhere_search_pool) const override;

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override;
Expand Down
Loading