From c0946f98a097ac3ee0e9b8368d96a6a8c4ad65b6 Mon Sep 17 00:00:00 2001 From: Shawn Wang Date: Fri, 10 Jan 2025 14:11:00 +0800 Subject: [PATCH] sparse: readd GetVectorByIds support for growing segments with IP metric Signed-off-by: Shawn Wang --- src/index/sparse/sparse_index_node.cc | 65 ++++++++++++++++++--- tests/ut/test_sparse.cc | 82 ++++++++++++++++----------- 2 files changed, 107 insertions(+), 40 deletions(-) diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc index b3bccf3a7..03f3affa2 100644 --- a/src/index/sparse/sparse_index_node.cc +++ b/src/index/sparse/sparse_index_node.cc @@ -11,10 +11,13 @@ #include +#include + #include "index/sparse/sparse_inverted_index.h" #include "index/sparse/sparse_inverted_index_config.h" #include "io/file_io.h" #include "io/memory_io.h" +#include "knowhere/comp/index_param.h" #include "knowhere/comp/thread_pool.h" #include "knowhere/config.h" #include "knowhere/dataset.h" @@ -389,7 +392,7 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode { } Status - Add(const DataSetPtr dataset, std::shared_ptr cfg) override { + Add(const DataSetPtr dataset, std::shared_ptr config) override { std::unique_lock lock(mutex_); uint64_t task_id = next_task_id_++; add_tasks_.push(task_id); @@ -397,7 +400,15 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode { // add task is allowed to run only after all search tasks that come before it have finished. cv_.wait(lock, [this, task_id]() { return current_task_id_ == task_id && active_readers_ == 0; }); - auto res = SparseInvertedIndexNode::Add(dataset, cfg); + auto res = SparseInvertedIndexNode::Add(dataset, config); + + auto cfg = static_cast(*config); + if (IsMetricType(cfg.metric_type.value(), metric::IP)) { + // insert dataset to raw data if metric type is IP + auto data = static_cast*>(dataset->GetTensor()); + auto rows = dataset->GetRows(); + raw_data_.insert(raw_data_.end(), data, data + rows); + } add_tasks_.pop(); current_task_id_++; @@ -431,12 +442,6 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode { return SparseInvertedIndexNode::RangeSearch(dataset, std::move(cfg), bitset); } - expected - GetVectorByIds(const DataSetPtr dataset) const override { - ReadPermission permission(*this); - return SparseInvertedIndexNode::GetVectorByIds(dataset); - } - int64_t Dim() const override { ReadPermission permission(*this); @@ -461,6 +466,49 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode { : knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC; } + expected + GetVectorByIds(const DataSetPtr dataset) const override { + ReadPermission permission(*this); + + if (raw_data_.empty()) { + return expected::Err(Status::invalid_args, "GetVectorByIds failed: raw data is empty"); + } + + auto rows = dataset->GetRows(); + auto ids = dataset->GetIds(); + auto data = std::make_unique[]>(rows); + int64_t dim = 0; + + try { + for (int64_t i = 0; i < rows; ++i) { + data[i] = raw_data_[ids[i]]; + dim = std::max(dim, data[i].dim()); + } + } catch (std::exception& e) { + return expected::Err(Status::invalid_args, "GetVectorByIds failed: " + std::string(e.what())); + } + + auto res = GenResultDataSet(rows, dim, data.release()); + res->SetIsSparse(true); + + return res; + } + + [[nodiscard]] bool + HasRawData(const std::string& metric_type) const override { + return IsMetricType(metric_type, metric::IP); + } + + Status + Deserialize(const BinarySet& binset, std::shared_ptr config) override { + return Status::not_implemented; + } + + Status + DeserializeFromFile(const std::string& filename, std::shared_ptr config) override { + return Status::not_implemented; + } + private: struct ReadPermission { ReadPermission(const SparseInvertedIndexNodeCC& node) : node_(node) { @@ -490,6 +538,7 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode { mutable std::queue add_tasks_; mutable uint64_t next_task_id_ = 0; mutable uint64_t current_task_id_ = 0; + mutable std::vector> raw_data_ = {}; }; // class SparseInvertedIndexNodeCC KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, knowhere::feature::MMAP, diff --git a/tests/ut/test_sparse.cc b/tests/ut/test_sparse.cc index c84886b32..380678c36 100644 --- a/tests/ut/test_sparse.cc +++ b/tests/ut/test_sparse.cc @@ -504,41 +504,42 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") { auto test_time = 10; - SECTION("Test Search") { - using std::make_tuple; - auto [name, gen] = GENERATE_REF(table>({ - make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen), - make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen), - })); + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen), + make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen), + })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); - auto cfg_json = gen().dump(); - CAPTURE(name, cfg_json); - knowhere::Json json = knowhere::Json::parse(cfg_json); - REQUIRE(idx.Type() == name); - // build the index with some initial data - REQUIRE(idx.Build(doc_vector_gen(nb, dim), json) == knowhere::Status::success); - - auto add_task = [&]() { - auto start = std::chrono::steady_clock::now(); - while (std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() < - test_time) { - auto doc_ds = doc_vector_gen(nb, dim); - auto res = idx.Add(doc_ds, json); - REQUIRE(res == knowhere::Status::success); - } - }; + auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + // build the index with some initial data + auto train_ds = doc_vector_gen(nb, dim); + REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success); - auto search_task = [&]() { - auto start = std::chrono::steady_clock::now(); - while (std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() < - test_time) { - auto results = idx.Search(query_ds, json, nullptr); - REQUIRE(results.has_value()); - check_result(*results.value()); - } - }; + auto add_task = [&]() { + auto start = std::chrono::steady_clock::now(); + while (std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() < + test_time) { + auto doc_ds = doc_vector_gen(nb, dim); + auto res = idx.Add(doc_ds, json); + REQUIRE(res == knowhere::Status::success); + } + }; + auto search_task = [&]() { + auto start = std::chrono::steady_clock::now(); + while (std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count() < + test_time) { + auto results = idx.Search(query_ds, json, nullptr); + REQUIRE(results.has_value()); + check_result(*results.value()); + } + }; + + SECTION("Test Search") { std::vector> task_list; for (int thread = 0; thread < 5; thread++) { task_list.push_back(std::async(std::launch::async, search_task)); @@ -548,4 +549,21 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") { task.wait(); } } + + SECTION("Test GetVectorByIds") { + std::vector ids = {0, 1, 2}; + REQUIRE(idx.HasRawData(metric)); + auto results = idx.GetVectorByIds(GenIdsDataSet(3, ids)); + REQUIRE(results.has_value()); + auto xb = (knowhere::sparse::SparseRow*)train_ds->GetTensor(); + auto res_data = (knowhere::sparse::SparseRow*)results.value()->GetTensor(); + for (int i = 0; i < 3; ++i) { + const auto& truth_row = xb[i]; + const auto& res_row = res_data[i]; + REQUIRE(truth_row.size() == res_row.size()); + for (size_t j = 0; j < truth_row.size(); ++j) { + REQUIRE(truth_row[j] == res_row[j]); + } + } + } }