sparse: readd GetVectorByIds support for growing segments with IP metric #1022

Merged · 1 commit · Jan 10, 2025
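For readers skimming the diff below, here is a minimal standalone sketch of the approach this PR takes: when the metric type is IP, rows passed to `Add()` are also appended to a raw-data cache, so `GetVectorByIds()` can later return the original vectors by offset. `RawDataCache`, `SparseRow`, and the string comparison on the metric name are simplifications for illustration only, not the knowhere API; the actual change stores `sparse::SparseRow<T>` objects in the CC node's `raw_data_` member and returns them through `GenResultDataSet` under the node's existing read/add synchronization.

```cpp
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// A simplified sparse row: dimension index -> value pairs.
using SparseRow = std::map<uint32_t, float>;

// Minimal cache mirroring the idea of this PR: when the metric is IP, rows
// passed to Add() are also kept verbatim, so GetVectorByIds() can serve the
// original vectors later.
class RawDataCache {
 public:
    explicit RawDataCache(std::string metric) : metric_(std::move(metric)) {}

    void Add(const std::vector<SparseRow>& rows) {
        if (metric_ == "IP") {
            raw_data_.insert(raw_data_.end(), rows.begin(), rows.end());
        }
    }

    // Raw data is only retained (and therefore retrievable) for the IP metric.
    bool HasRawData() const {
        return metric_ == "IP";
    }

    std::vector<SparseRow> GetVectorByIds(const std::vector<int64_t>& ids) const {
        if (raw_data_.empty()) {
            throw std::invalid_argument("GetVectorByIds failed: raw data is empty");
        }
        std::vector<SparseRow> out;
        out.reserve(ids.size());
        for (int64_t id : ids) {
            out.push_back(raw_data_.at(static_cast<size_t>(id)));
        }
        return out;
    }

 private:
    std::string metric_;
    std::vector<SparseRow> raw_data_;
};
```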
65 changes: 57 additions & 8 deletions src/index/sparse/sparse_index_node.cc
@@ -11,10 +11,13 @@

#include <sys/mman.h>

#include <exception>

#include "index/sparse/sparse_inverted_index.h"
#include "index/sparse/sparse_inverted_index_config.h"
#include "io/file_io.h"
#include "io/memory_io.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/config.h"
#include "knowhere/dataset.h"
@@ -389,15 +392,23 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
}

Status
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
Add(const DataSetPtr dataset, std::shared_ptr<Config> config) override {
std::unique_lock<std::mutex> lock(mutex_);
uint64_t task_id = next_task_id_++;
add_tasks_.push(task_id);

// add task is allowed to run only after all search tasks that come before it have finished.
cv_.wait(lock, [this, task_id]() { return current_task_id_ == task_id && active_readers_ == 0; });

auto res = SparseInvertedIndexNode<T, use_wand>::Add(dataset, cfg);
auto res = SparseInvertedIndexNode<T, use_wand>::Add(dataset, config);

auto cfg = static_cast<const SparseInvertedIndexConfig&>(*config);
if (IsMetricType(cfg.metric_type.value(), metric::IP)) {
// insert dataset to raw data if metric type is IP
auto data = static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor());
auto rows = dataset->GetRows();
raw_data_.insert(raw_data_.end(), data, data + rows);
}

add_tasks_.pop();
current_task_id_++;
@@ -431,12 +442,6 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
return SparseInvertedIndexNode<T, use_wand>::RangeSearch(dataset, std::move(cfg), bitset);
}

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::GetVectorByIds(dataset);
}

int64_t
Dim() const override {
ReadPermission permission(*this);
@@ -461,6 +466,49 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
: knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC;
}

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override {
ReadPermission permission(*this);

if (raw_data_.empty()) {
return expected<DataSetPtr>::Err(Status::invalid_args, "GetVectorByIds failed: raw data is empty");
}

auto rows = dataset->GetRows();
auto ids = dataset->GetIds();
auto data = std::make_unique<sparse::SparseRow<T>[]>(rows);
int64_t dim = 0;

try {
for (int64_t i = 0; i < rows; ++i) {
data[i] = raw_data_[ids[i]];
dim = std::max(dim, data[i].dim());
}
} catch (std::exception& e) {
return expected<DataSetPtr>::Err(Status::invalid_args, "GetVectorByIds failed: " + std::string(e.what()));
}

auto res = GenResultDataSet(rows, dim, data.release());
res->SetIsSparse(true);

return res;
}

[[nodiscard]] bool
HasRawData(const std::string& metric_type) const override {
return IsMetricType(metric_type, metric::IP);
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
return Status::not_implemented;
}

Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
return Status::not_implemented;
}

private:
struct ReadPermission {
ReadPermission(const SparseInvertedIndexNodeCC& node) : node_(node) {
@@ -490,6 +538,7 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
mutable std::queue<uint64_t> add_tasks_;
mutable uint64_t next_task_id_ = 0;
mutable uint64_t current_task_id_ = 0;
mutable std::vector<sparse::SparseRow<T>> raw_data_ = {};
}; // class SparseInvertedIndexNodeCC

KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, knowhere::feature::MMAP,
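The new `GetVectorByIds` override acquires the same `ReadPermission` that `Search` and `RangeSearch` use, while `Add()` waits for its turn behind earlier readers (see the comment and the `cv_.wait` predicate in the first hunk). The full `ReadPermission` definition is collapsed in this diff, so the sketch below is only a generic writer-preference gate of the same general shape, offered as an illustration of the pattern; all names here are hypothetical and do not come from the knowhere code.

```cpp
#include <condition_variable>
#include <mutex>

// Illustrative writer-preference gate: an add waits until no searches are in
// flight, and newly arriving searches wait while an add is pending.
class AddSearchGate {
 public:
    void BeginAdd() {
        std::unique_lock<std::mutex> lock(mutex_);
        ++pending_adds_;  // newly arriving readers will now block
        cv_.wait(lock, [&] { return active_readers_ == 0 && !add_running_; });
        add_running_ = true;
    }

    void EndAdd() {
        std::lock_guard<std::mutex> lock(mutex_);
        add_running_ = false;
        --pending_adds_;
        cv_.notify_all();
    }

    void BeginRead() {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [&] { return pending_adds_ == 0; });
        ++active_readers_;
    }

    void EndRead() {
        std::lock_guard<std::mutex> lock(mutex_);
        --active_readers_;
        cv_.notify_all();
    }

 private:
    std::mutex mutex_;
    std::condition_variable cv_;
    int active_readers_ = 0;
    int pending_adds_ = 0;
    bool add_running_ = false;
};
```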
82 changes: 50 additions & 32 deletions tests/ut/test_sparse.cc
@@ -504,41 +504,42 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {

auto test_time = 10;

SECTION("Test Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen),
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen),
}));
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen),
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen),
}));

auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
// build the index with some initial data
REQUIRE(idx.Build(doc_vector_gen(nb, dim), json) == knowhere::Status::success);

auto add_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto doc_ds = doc_vector_gen(nb, dim);
auto res = idx.Add(doc_ds, json);
REQUIRE(res == knowhere::Status::success);
}
};
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
// build the index with some initial data
auto train_ds = doc_vector_gen(nb, dim);
REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success);

auto search_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
check_result(*results.value());
}
};
auto add_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto doc_ds = doc_vector_gen(nb, dim);
auto res = idx.Add(doc_ds, json);
REQUIRE(res == knowhere::Status::success);
}
};

auto search_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
check_result(*results.value());
}
};

SECTION("Test Search") {
std::vector<std::future<void>> task_list;
for (int thread = 0; thread < 5; thread++) {
task_list.push_back(std::async(std::launch::async, search_task));
@@ -548,4 +549,21 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
task.wait();
}
}

SECTION("Test GetVectorByIds") {
std::vector<int64_t> ids = {0, 1, 2};
REQUIRE(idx.HasRawData(metric));
auto results = idx.GetVectorByIds(GenIdsDataSet(3, ids));
REQUIRE(results.has_value());
auto xb = (knowhere::sparse::SparseRow<float>*)train_ds->GetTensor();
auto res_data = (knowhere::sparse::SparseRow<float>*)results.value()->GetTensor();
for (int i = 0; i < 3; ++i) {
const auto& truth_row = xb[i];
const auto& res_row = res_data[i];
REQUIRE(truth_row.size() == res_row.size());
for (size_t j = 0; j < truth_row.size(); ++j) {
REQUIRE(truth_row[j] == res_row[j]);
}
}
}
}