From 540d15b6014b97118be62f2ea0d62e5c9c578266 Mon Sep 17 00:00:00 2001 From: cqy123456 Date: Wed, 25 Dec 2024 18:38:36 +0800 Subject: [PATCH] enhance: remove extra memory copy of bf16/bf16 brustforce search Signed-off-by: cqy123456 --- internal/core/src/query/SearchBruteForce.cpp | 30 ++++---------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 38887a150909e..7c8cdd138c698 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -82,18 +82,6 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds, if (data_type == DataType::VECTOR_SPARSE_FLOAT) { base_dataset->SetIsSparse(true); query_dataset->SetIsSparse(true); - } else if (data_type == DataType::VECTOR_BFLOAT16) { - //todo: if knowhere support real fp16/bf16 bf, remove convert - base_dataset = - knowhere::ConvertFromDataTypeIfNeeded(base_dataset); - query_dataset = - knowhere::ConvertFromDataTypeIfNeeded(query_dataset); - } else if (data_type == DataType::VECTOR_FLOAT16) { - //todo: if knowhere support real fp16/bf16 bf, remove convert - base_dataset = - knowhere::ConvertFromDataTypeIfNeeded(base_dataset); - query_dataset = - knowhere::ConvertFromDataTypeIfNeeded(query_dataset); } base_dataset->SetTensorBeginId(raw_ds.begin_id); return std::make_pair(query_dataset, base_dataset); @@ -133,12 +121,10 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, res = knowhere::BruteForce::RangeSearch( base_dataset, query_dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_FLOAT16) { - //todo: if knowhere support real fp16/bf16 bf, change it - res = knowhere::BruteForce::RangeSearch( + res = knowhere::BruteForce::RangeSearch( base_dataset, query_dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_BFLOAT16) { - //todo: if knowhere support real fp16/bf16 bf, change it - res = knowhere::BruteForce::RangeSearch( + res = knowhere::BruteForce::RangeSearch( base_dataset, query_dataset, search_cfg, bitset); } else if (data_type == DataType::VECTOR_BINARY) { res = knowhere::BruteForce::RangeSearch( @@ -178,8 +164,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, search_cfg, bitset); } else if (data_type == DataType::VECTOR_FLOAT16) { - //todo: if knowhere support real fp16/bf16 bf, change it - stat = knowhere::BruteForce::SearchWithBuf( + stat = knowhere::BruteForce::SearchWithBuf( base_dataset, query_dataset, sub_result.mutable_seg_offsets().data(), @@ -187,8 +172,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, search_cfg, bitset); } else if (data_type == DataType::VECTOR_BFLOAT16) { - //todo: if knowhere support real fp16/bf16 bf, change it - stat = knowhere::BruteForce::SearchWithBuf( + stat = knowhere::BruteForce::SearchWithBuf( base_dataset, query_dataset, sub_result.mutable_seg_offsets().data(), @@ -238,13 +222,11 @@ DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset, base_dataset, query_dataset, config, bitset); break; case DataType::VECTOR_FLOAT16: - //todo: if knowhere support real fp16/bf16 bf, change it - return knowhere::BruteForce::AnnIterator( + return knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, config, bitset); break; case DataType::VECTOR_BFLOAT16: - //todo: if knowhere support real fp16/bf16 bf, change it - return knowhere::BruteForce::AnnIterator( + return knowhere::BruteForce::AnnIterator( base_dataset, query_dataset, config, bitset); break; case DataType::VECTOR_SPARSE_FLOAT: