diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index e3f3d8864..d4a5ee851 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -278,7 +278,6 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); index = std::make_unique(qzr, dim, nlist, metric.value(), is_cosine); index->train(rows, (const float*)data); - index->make_direct_map(true); } if constexpr (std::is_same::value) { const IvfFlatCcConfig& ivf_flat_cc_cfg = static_cast(cfg); @@ -287,6 +286,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { index = std::make_unique(qzr, dim, nlist, ivf_flat_cc_cfg.ssize.value(), metric.value(), is_cosine); index->train(rows, (const float*)data); + // ivfflat_cc has no serialize stage, make map at build stage index->make_direct_map(true, faiss::DirectMap::ConcurrentArray); } if constexpr (std::is_same::value) { @@ -296,7 +296,6 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); index = std::make_unique(qzr, dim, nlist, ivf_pq_cfg.m.value(), nbits, metric.value()); index->train(rows, (const float*)data); - index->make_direct_map(true); } if constexpr (std::is_same::value) { const ScannConfig& scann_cfg = static_cast(cfg); @@ -320,7 +319,6 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { index = std::make_unique(qzr, dim, nlist, faiss::QuantizerType::QT_8bit, metric.value()); index->train(rows, (const float*)data); - index->make_direct_map(true); } if constexpr (std::is_same::value) { const IvfBinConfig& ivf_bin_cfg = static_cast(cfg); @@ -328,7 +326,6 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); index = std::make_unique(qzr, dim, nlist, metric.value()); index->train(rows, (const uint8_t*)data); - index->make_direct_map(true); } index->own_fields = true; } catch (std::exception& e) { @@ -751,6 +748,12 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { } else { index_.reset(static_cast(faiss::read_index(&reader))); } + if constexpr (!std::is_same_v) { + const BaseConfig& base_cfg = static_cast(config); + if (HasRawData(base_cfg.metric_type.value())) { + index_->make_direct_map(true); + } + } } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -773,6 +776,12 @@ IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& } else { index_.reset(static_cast(faiss::read_index(filename.data(), io_flags))); } + if constexpr (!std::is_same_v) { + const BaseConfig& base_cfg = static_cast(config); + if (HasRawData(base_cfg.metric_type.value())) { + index_->make_direct_map(true); + } + } } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error;