diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 9cda0df76..4526534a2 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -43,6 +43,11 @@ constexpr const char* INDEX_RAFT_IVFFLAT = "GPU_RAFT_IVF_FLAT"; constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ"; constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA"; +constexpr const char* INDEX_GPU_BRUTEFORCE = "GPU_BRUTE_FORCE"; +constexpr const char* INDEX_GPU_IVFFLAT = "GPU_IVF_FLAT"; +constexpr const char* INDEX_GPU_IVFPQ = "GPU_IVF_PQ"; +constexpr const char* INDEX_GPU_CAGRA = "GPU_CAGRA"; + constexpr const char* INDEX_HNSW = "HNSW"; constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8"; constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE"; diff --git a/include/knowhere/comp/knowhere_check.h b/include/knowhere/comp/knowhere_check.h index 1499a7f34..15734dda5 100644 --- a/include/knowhere/comp/knowhere_check.h +++ b/include/knowhere/comp/knowhere_check.h @@ -18,27 +18,14 @@ #include "knowhere/index/index_factory.h" namespace knowhere { namespace KnowhereCheck { -bool +static bool IndexTypeAndDataTypeCheck(const std::string& index_name, VecType data_type) { - auto& index_factory = IndexFactory::Instance(); - switch (data_type) { - case VecType::VECTOR_BINARY: - return index_factory.HasIndex(index_name); - case VecType::VECTOR_FLOAT: - return index_factory.HasIndex(index_name); - case VecType::VECTOR_BFLOAT16: - return index_factory.HasIndex(index_name); - case VecType::VECTOR_FLOAT16: - return index_factory.HasIndex(index_name); - case VecType::VECTOR_SPARSE_FLOAT: - if (index_name != IndexEnum::INDEX_SPARSE_INVERTED_INDEX && index_name != IndexEnum::INDEX_SPARSE_WAND && - index_name != IndexEnum::INDEX_HNSW) { - return false; - } else { - return index_factory.HasIndex(index_name); - } - default: - return false; + auto& static_index_table = IndexFactory::StaticIndexTableInstance(); + auto key = std::pair(index_name, data_type); + if (static_index_table.find(key) != static_index_table.end()) { + return true; + } else { + return false; } } } // namespace KnowhereCheck diff --git a/include/knowhere/index/index_factory.h b/include/knowhere/index/index_factory.h index 1eca451b5..f66cfe9a3 100644 --- a/include/knowhere/index/index_factory.h +++ b/include/knowhere/index/index_factory.h @@ -13,6 +13,7 @@ #define INDEX_FACTORY_H #include +#include #include #include @@ -28,11 +29,11 @@ class IndexFactory { template const IndexFactory& Register(const std::string& name, std::function(const int32_t&, const Object&)> func); - template - bool - HasIndex(const std::string& name); static IndexFactory& Instance(); + typedef std::set> GlobalIndexTable; + static GlobalIndexTable& + StaticIndexTableInstance(); private: struct FunMapValueBase { @@ -76,6 +77,12 @@ class IndexFactory { std::make_unique::type>>(version, object), thread_size)); \ }, \ data_type) +#define KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(name, index_table) \ + static int name = []() -> int { \ + auto& static_index_table = IndexFactory::StaticIndexTableInstance(); \ + static_index_table.insert(index_table.begin(), index_table.end()); \ + return 0; \ + }(); } // namespace knowhere #endif /* INDEX_FACTORY_H */ diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h new file mode 100644 index 000000000..275bdaf07 --- /dev/null +++ b/include/knowhere/index/index_table.h @@ -0,0 +1,79 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef INDEX_TABLE_H +#define INDEX_TABLE_H +#include +#include + +#include "knowhere/comp/index_param.h" +#include "knowhere/index/index_factory.h" +namespace knowhere { +static std::set> legal_knowhere_index = { + // binary ivf + {IndexEnum::INDEX_FAISS_BIN_IDMAP, VecType::VECTOR_BINARY}, + {IndexEnum::INDEX_FAISS_BIN_IVFFLAT, VecType::VECTOR_BINARY}, + // ivf + {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16}, + // gpu index + {IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_FLOAT}, + // hnsw + {IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_HNSW, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_BFLOAT16}, + + {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16}, + // diskann + {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT}, + {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16}, + {IndexEnum::INDEX_DISKANN, VecType::VECTOR_BFLOAT16}, + // sparse index + {IndexEnum::INDEX_SPARSE_INVERTED_INDEX, VecType::VECTOR_SPARSE_FLOAT}, + {IndexEnum::INDEX_SPARSE_WAND, VecType::VECTOR_SPARSE_FLOAT}, +}; +KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(KNOWHERE_STATIC_INDEX, legal_knowhere_index) +} // namespace knowhere +#endif /* INDEX_TABLE_H */ diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index 1fce95a34..35ae12733 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -31,8 +31,8 @@ IsMetricType(const std::string& str, const knowhere::MetricType& metric_type) { inline bool IsFlatIndex(const knowhere::IndexType& index_type) { - static std::vector flat_index_list = {IndexEnum::INDEX_FAISS_IDMAP, - IndexEnum::INDEX_FAISS_GPU_IDMAP}; + static std::vector flat_index_list = { + IndexEnum::INDEX_FAISS_IDMAP, IndexEnum::INDEX_FAISS_GPU_IDMAP, IndexEnum::INDEX_GPU_BRUTEFORCE}; return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end(); } diff --git a/src/index/index_factory.cc b/src/index/index_factory.cc index b090425e9..8bce34916 100644 --- a/src/index/index_factory.cc +++ b/src/index/index_factory.cc @@ -11,6 +11,8 @@ #include "knowhere/index/index_factory.h" +#include "knowhere/index/index_table.h" + #ifdef KNOWHERE_WITH_RAFT #include #endif @@ -72,14 +74,6 @@ IndexFactory::Register(const std::string& name, std::function(c return *this; } -template -bool -IndexFactory::HasIndex(const std::string& name) { - auto& func_mapping_ = MapInstance(); - auto key = GetKey(name); - return (func_mapping_.find(key) != func_mapping_.end()); -} - IndexFactory& IndexFactory::Instance() { static IndexFactory factory; @@ -93,6 +87,11 @@ IndexFactory::MapInstance() { static FuncMap func_map; return func_map; } +IndexFactory::GlobalIndexTable& +IndexFactory::StaticIndexTableInstance() { + static GlobalIndexTable static_index_table; + return static_index_table; +} } // namespace knowhere // @@ -116,11 +115,3 @@ knowhere::IndexFactory::Register( template const knowhere::IndexFactory& knowhere::IndexFactory::Register( const std::string&, std::function(const int32_t&, const Object&)>); -template bool -knowhere::IndexFactory::HasIndex(const std::string&); -template bool -knowhere::IndexFactory::HasIndex(const std::string&); -template bool -knowhere::IndexFactory::HasIndex(const std::string&); -template bool -knowhere::IndexFactory::HasIndex(const std::string&); diff --git a/tests/ut/test_index_check.cc b/tests/ut/test_index_check.cc index 9f14236cc..fb8dc059c 100644 --- a/tests/ut/test_index_check.cc +++ b/tests/ut/test_index_check.cc @@ -24,8 +24,18 @@ TEST_CASE("Test index and data type check", "[IndexCheckTest]") { knowhere::VecType::VECTOR_FLOAT) == true); REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW, knowhere::VecType::VECTOR_BFLOAT16) == true); + +#ifndef KNOWHERE_WITH_CARDINAL + REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW, + knowhere::VecType::VECTOR_BINARY) == false); + REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW, + knowhere::VecType::VECTOR_SPARSE_FLOAT) == false); +#else REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW, knowhere::VecType::VECTOR_BINARY) == true); + REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_HNSW, + knowhere::VecType::VECTOR_SPARSE_FLOAT) == true); +#endif REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_DISKANN, knowhere::VecType::VECTOR_FLOAT) == true); REQUIRE(knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(knowhere::IndexEnum::INDEX_DISKANN,