Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize feature registry #990

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 17 additions & 27 deletions include/knowhere/index/index_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,47 +103,37 @@ class IndexFactory {
// Please review carefully and select with caution

// register vector index supporting ALL_TYPE(binary, bf16, fp16, fp32, sparse_float32) data types
#define KNOWHERE_SIMPLE_REGISTER_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__);
#define KNOWHERE_SIMPLE_REGISTER_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__);

// register vector index supporting sparse_float32
#define KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::SPARSE_FLOAT32), \
##__VA_ARGS__);

// register vector index supporting ALL_DENSE_TYPE(binary, bf16, fp16, fp32) data types
#define KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::ALL_DENSE_TYPE), \
##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_TYPE), \
##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_TYPE), \
##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_TYPE), \
##__VA_ARGS__);
#define KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__);
// register vector index supporting binary data type
#define KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__);

// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types
#define KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \
##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \
##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \
##__VA_ARGS__);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__);
// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types, but mocked bf16 and fp16
#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \
##__VA_ARGS__); \
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \
##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \
##__VA_ARGS__);
#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \
KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__);

#define KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(name, index_node, data_type, features, thread_size) \
KNOWHERE_REGISTER_STATIC(name, index_node, data_type) \
Expand Down
2 changes: 1 addition & 1 deletion src/index/index_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ IndexFactory::Register(const std::string& name, std::function<Index<IndexNode>(c
feature_mapping_[name] = features;
} else {
// All data types should have the same features; please try to avoid breaking this rule.
feature_mapping_[name] = feature_mapping_[name] & features;
feature_mapping_[name] |= features;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a problem here; if a certain data type supports a feature, then all types will be considered to support this feature; this is unreasonable.This is why ALL_TYPE is used above.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current feature mark mechanism cannot handle different data type support different feature for a same index type.
If it happens, the whole feature mark mechanism need be re-designed.

}
return *this;
}
Expand Down
14 changes: 13 additions & 1 deletion tests/ut/test_index_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,18 +332,30 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") {
REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_BIN_IVFFLAT,
knowhere::feature::SPARSE_FLOAT32));

// HNSW Index
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::BF16));

// HNSW Index
#ifdef KNOWHERE_WITH_CARDINAL
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::BINARY));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::SPARSE_FLOAT32));
#else
REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::SPARSE_FLOAT32));
#endif

REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::BF16));

REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::BF16));

REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::FLOAT32));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::FP16));
REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::BF16));

// Sparse Indexes
REQUIRE_FALSE(
IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_SPARSE_INVERTED_INDEX, knowhere::feature::FLOAT32));
Expand Down
Loading