diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h index 0db13a3df..3af9b9bba 100644 --- a/include/knowhere/operands.h +++ b/include/knowhere/operands.h @@ -170,6 +170,8 @@ template using KnowhereDataTypeCheck = TypeMatch; template using KnowhereFloatTypeCheck = TypeMatch; +template +using KnowhereHalfPrecisionFloatPointTypeCheck = TypeMatch; template struct MockData { diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index 80ee7154f..1b0f5e97b 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -17,6 +17,7 @@ #include "faiss/MetricType.h" #include "faiss/utils/binary_distances.h" #include "faiss/utils/distances.h" +#include "faiss/utils/half_precision_floating_point_distances.h" #include "knowhere/bitsetview_idselector.h" #include "knowhere/comp/thread_pool.h" #include "knowhere/config.h" @@ -26,6 +27,7 @@ #include "knowhere/range_util.h" #include "knowhere/sparse_utils.h" #include "knowhere/utils.h" +#include "simd/hook.h" #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) #include "knowhere/tracer.h" @@ -61,22 +63,55 @@ GetDocValueComputer(const BruteForceConfig& cfg) { } } +template +std::unique_ptr +GetVecNorms(const DataSetPtr& base) { + typedef float (*NormComputer)(const DataType*, size_t); + NormComputer norm_computer; + if constexpr (std::is_same_v) { + norm_computer = faiss::fvec_norm_L2sqr; + } else if constexpr (std::is_same_v) { + norm_computer = faiss::fp16_vec_norm_L2sqr; + } else if constexpr (std::is_same_v) { + norm_computer = faiss::bf16_vec_norm_L2sqr; + } else { + return nullptr; + } + auto xb = (DataType*)base->GetTensor(); + auto nb = base->GetRows(); + auto dim = base->GetDim(); + auto norms = std::make_unique(nb); + + // use build thread pool to compute norms + auto pool = ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futs; + constexpr int64_t chunk_size = 8192; + auto chunk_num = (nb + chunk_size - 1) / chunk_size; + futs.reserve(chunk_num); + for (auto i = 0; i < nb; i += chunk_size) { + auto last = std::min(i + chunk_size, nb); + futs.emplace_back(pool->push([&, beg_id = i, end_id = last] { + for (auto j = beg_id; j < end_id; j++) { + norms[j] = std::sqrt(norm_computer(xb + j * dim, dim)); + } + })); + } + WaitAllSuccess(futs); + return norms; +} } // namespace template expected BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - auto base = ConvertFromDataTypeIfNeeded(base_dataset); - auto query = ConvertFromDataTypeIfNeeded(query_dataset); - - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto xb_id_offset = base->GetTensorBeginId(); - auto dim = base->GetDim(); + auto xb = base_dataset->GetTensor(); + auto nb = base_dataset->GetRows(); + auto xb_id_offset = base_dataset->GetTensorBeginId(); + auto dim = base_dataset->GetDim(); - auto xq = query->GetTensor(); - auto nq = query->GetRows(); + auto xq = query_dataset->GetTensor(); + auto nq = query_dataset->GetRows(); BruteForceConfig cfg; std::string msg; @@ -114,12 +149,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset int topk = cfg.k.value(); auto labels = std::make_unique(nq * topk); auto distances = std::make_unique(nq * topk); - std::unique_ptr norms = nullptr; - if (is_cosine) { - ThreadPool::ScopedSearchOmpSetter setter(1); - norms = std::make_unique(nb); - faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); - } + std::unique_ptr norms = is_cosine ? GetVecNorms(base_dataset) : nullptr; auto pool = ThreadPool::GetGlobalSearchThreadPool(); std::vector> futs; futs.reserve(nq); @@ -132,22 +162,49 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; + auto cur_query = (const DataType*)xq + dim * index; switch (faiss_metric_type) { case faiss::METRIC_L2: { - auto cur_query = (const float*)xq + dim * index; - faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; - faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector); + if constexpr (std::is_same_v) { + faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; + faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_knn_L2sqr(cur_query, (const DataType*)xb, dim, 1, nb, topk, + cur_distances, cur_labels, nullptr, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric L2 only used in floating point data"; + return Status::faiss_inner_error; + } break; } case faiss::METRIC_INNER_PRODUCT: { - auto cur_query = (const float*)xq + dim * index; faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { - auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf, - id_selector); + if constexpr (std::is_same_v) { + auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); + faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf, + id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + // normalize query vector may cause presision loss, so div query norms in apply function + + faiss::half_precision_floating_point_knn_cosine(cur_query, (const DataType*)xb, norms.get(), + dim, 1, nb, topk, cur_distances, cur_labels, + id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; + return Status::faiss_inner_error; + } } else { - faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); + if constexpr (std::is_same_v) { + faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_knn_inner_product(cur_query, (const DataType*)xb, dim, + 1, nb, topk, cur_distances, + cur_labels, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Inner Product only used in floating point data"; + return Status::faiss_inner_error; + } } break; } @@ -210,16 +267,13 @@ template Status BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset) { - auto base = ConvertFromDataTypeIfNeeded(base_dataset); - auto query = ConvertFromDataTypeIfNeeded(query_dataset); - - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto dim = base->GetDim(); - auto xb_id_offset = base->GetTensorBeginId(); + auto xb = base_dataset->GetTensor(); + auto nb = base_dataset->GetRows(); + auto dim = base_dataset->GetDim(); + auto xb_id_offset = base_dataset->GetTensorBeginId(); - auto xq = query->GetTensor(); - auto nq = query->GetRows(); + auto xq = query_dataset->GetTensor(); + auto nq = query_dataset->GetRows(); BruteForceConfig cfg; RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH)); @@ -254,13 +308,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ auto labels = ids; auto distances = dis; - std::unique_ptr norms = nullptr; - if (is_cosine) { - ThreadPool::ScopedSearchOmpSetter setter(1); - norms = std::make_unique(nb); - faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); - } - + std::unique_ptr norms = is_cosine ? GetVecNorms(base_dataset) : nullptr; auto pool = ThreadPool::GetGlobalSearchThreadPool(); std::vector> futs; futs.reserve(nq); @@ -272,23 +320,49 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - + auto cur_query = (const DataType*)xq + dim * index; switch (faiss_metric_type) { case faiss::METRIC_L2: { - auto cur_query = (const float*)xq + dim * index; - faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; - faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector); + if constexpr (std::is_same_v) { + faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; + faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_knn_L2sqr(cur_query, (const DataType*)xb, dim, 1, nb, topk, + cur_distances, cur_labels, nullptr, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric L2 only used in floating point data"; + return Status::faiss_inner_error; + } break; } case faiss::METRIC_INNER_PRODUCT: { - auto cur_query = (const float*)xq + dim * index; faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { - auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf, - id_selector); + if constexpr (std::is_same_v) { + auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); + faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf, + id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + // normalize query vector may cause presision loss, so div query norms in apply function + + faiss::half_precision_floating_point_knn_cosine(cur_query, (const DataType*)xb, norms.get(), + dim, 1, nb, topk, cur_distances, cur_labels, + id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; + return Status::faiss_inner_error; + } } else { - faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); + if constexpr (std::is_same_v) { + faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_knn_inner_product(cur_query, (const DataType*)xb, dim, + 1, nb, topk, cur_distances, + cur_labels, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Inner Product only used in floating point data"; + return Status::faiss_inner_error; + } } break; } @@ -349,20 +423,14 @@ template expected BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - DataSetPtr base(base_dataset); DataSetPtr query(query_dataset); - bool is_sparse = std::is_same>::value; - if (!is_sparse) { - base = ConvertFromDataTypeIfNeeded(base_dataset); - query = ConvertFromDataTypeIfNeeded(query_dataset); - } - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto dim = base->GetDim(); - auto xb_id_offset = base->GetTensorBeginId(); + auto xb = base_dataset->GetTensor(); + auto nb = base_dataset->GetRows(); + auto dim = base_dataset->GetDim(); + auto xb_id_offset = base_dataset->GetTensorBeginId(); - auto xq = query->GetTensor(); - auto nq = query->GetRows(); + auto xq = query_dataset->GetTensor(); + auto nq = query_dataset->GetRows(); BruteForceConfig cfg; std::string msg; @@ -397,7 +465,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da faiss::MetricType faiss_metric_type; sparse::DocValueComputer sparse_computer; - if (!is_sparse) { + if constexpr (!std::is_same_v>) { auto result = Str2FaissMetricType(metric_str); if (result.error() != Status::success) { return expected::Err(result.error(), result.what()); @@ -422,18 +490,12 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da std::vector> result_id_array(nq); std::vector> result_dist_array(nq); - std::unique_ptr norms = nullptr; - if (is_cosine) { - ThreadPool::ScopedSearchOmpSetter setter(1); - norms = std::make_unique(nb); - faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); - } - + std::unique_ptr norms = is_cosine ? GetVecNorms(base_dataset) : nullptr; std::vector> futs; futs.reserve(nq); for (int i = 0; i < nq; ++i) { futs.emplace_back(pool->push([&, index = i] { - if (is_sparse) { + if constexpr (std::is_same_v>) { auto cur_query = (const sparse::SparseRow*)xq + index; auto xb_sparse = (const sparse::SparseRow*)xb; for (int j = 0; j < nb; ++j) { @@ -454,64 +516,90 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da } } return Status::success; - } - // else not sparse: - ThreadPool::ScopedSearchOmpSetter setter(1); - faiss::RangeSearchResult res(1); - - BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); - faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - - switch (faiss_metric_type) { - case faiss::METRIC_L2: { - auto cur_query = (const float*)xq + dim * index; - faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, id_selector); - break; - } - case faiss::METRIC_INNER_PRODUCT: { - is_ip = true; - auto cur_query = (const float*)xq + dim * index; - if (is_cosine) { - auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::range_search_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, - radius, &res, id_selector); - } else { - faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, - id_selector); + } else { + // else not sparse: + ThreadPool::ScopedSearchOmpSetter setter(1); + faiss::RangeSearchResult res(1); + + BitsetViewIDSelector bw_idselector(bitset, xb_id_offset); + faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; + auto cur_query = (const DataType*)xq + dim * index; + switch (faiss_metric_type) { + case faiss::METRIC_L2: { + if constexpr (std::is_same_v) { + faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, + id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_range_search_L2sqr(cur_query, (const DataType*)xb, dim, + 1, nb, radius, &res, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric L2 only used in floating point data"; + return Status::faiss_inner_error; + } + break; + } + case faiss::METRIC_INNER_PRODUCT: { + is_ip = true; + if (is_cosine) { + if constexpr (std::is_same_v) { + auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); + faiss::range_search_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, + nb, radius, &res, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + // normalize query vector may cause presision loss, so div query norms in apply function + + faiss::half_precision_floating_point_range_search_cosine( + cur_query, (const DataType*)xb, norms.get(), dim, 1, nb, radius, &res, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; + return Status::faiss_inner_error; + } + } else { + if constexpr (std::is_same_v) { + faiss::range_search_inner_product(cur_query, (const DataType*)xb, dim, 1, nb, radius, + &res, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_range_search_inner_product( + cur_query, (const DataType*)xb, dim, 1, nb, radius, &res, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Inner Product only used in floating point data"; + return Status::faiss_inner_error; + } + } + break; + } + case faiss::METRIC_Jaccard: { + auto cur_query = (const uint8_t*)xq + (dim / 8) * index; + faiss::binary_range_search, float>( + faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, + id_selector); + break; + } + case faiss::METRIC_Hamming: { + auto cur_query = (const uint8_t*)xq + (dim / 8) * index; + faiss::binary_range_search, int>( + faiss::METRIC_Hamming, cur_query, (const uint8_t*)xb, 1, nb, (int)radius, dim / 8, &res, + id_selector); + break; + } + default: { + LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value(); + return Status::invalid_metric_type; } - break; - } - case faiss::METRIC_Jaccard: { - auto cur_query = (const uint8_t*)xq + (dim / 8) * index; - faiss::binary_range_search, float>(faiss::METRIC_Jaccard, cur_query, - (const uint8_t*)xb, 1, nb, radius, - dim / 8, &res, id_selector); - break; } - case faiss::METRIC_Hamming: { - auto cur_query = (const uint8_t*)xq + (dim / 8) * index; - faiss::binary_range_search, int>(faiss::METRIC_Hamming, cur_query, - (const uint8_t*)xb, 1, nb, (int)radius, - dim / 8, &res, id_selector); - break; + auto elem_cnt = res.lims[1]; + result_dist_array[index].resize(elem_cnt); + result_id_array[index].resize(elem_cnt); + for (size_t j = 0; j < elem_cnt; j++) { + result_dist_array[index][j] = res.distances[j]; + result_id_array[index][j] = res.labels[j] + xb_id_offset; } - default: { - LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value(); - return Status::invalid_metric_type; + if (cfg.range_filter.value() != defaultRangeFilter) { + FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, radius, + range_filter); } + return Status::success; } - auto elem_cnt = res.lims[1]; - result_dist_array[index].resize(elem_cnt); - result_id_array[index].resize(elem_cnt); - for (size_t j = 0; j < elem_cnt; j++) { - result_dist_array[index][j] = res.distances[j]; - result_id_array[index][j] = res.labels[j] + xb_id_offset; - } - if (cfg.range_filter.value() != defaultRangeFilter) { - FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, radius, - range_filter); - } - return Status::success; })); } auto ret = WaitAllSuccess(futs); @@ -658,16 +746,13 @@ template expected> BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - auto base = ConvertFromDataTypeIfNeeded(base_dataset); - auto query = ConvertFromDataTypeIfNeeded(query_dataset); - - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto dim = base->GetDim(); - auto xb_id_offset = base->GetTensorBeginId(); + auto xb = base_dataset->GetTensor(); + auto nb = base_dataset->GetRows(); + auto dim = base_dataset->GetDim(); + auto xb_id_offset = base_dataset->GetTensorBeginId(); - auto xq = query->GetTensor(); - auto nq = query->GetRows(); + auto xq = query_dataset->GetTensor(); + auto nq = query_dataset->GetRows(); BruteForceConfig cfg; std::string msg; auto status = Config::Load(cfg, config, knowhere::ITERATOR, &msg); @@ -699,13 +784,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da faiss::MetricType faiss_metric_type = result.value(); bool is_cosine = IsMetricType(metric_str, metric::COSINE); - std::unique_ptr norms = nullptr; - if (is_cosine) { - ThreadPool::ScopedSearchOmpSetter setter(1); - norms = std::make_unique(nb); - faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); - } - + std::unique_ptr norms = is_cosine ? GetVecNorms(base_dataset) : nullptr; auto pool = ThreadPool::GetGlobalSearchThreadPool(); auto vec = std::vector(nq, nullptr); std::vector> futs; @@ -720,21 +799,47 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine; auto max_dis = larger_is_closer ? std::numeric_limits::lowest() : std::numeric_limits::max(); std::vector distances_ids(nb, {-1, max_dis}); - + auto cur_query = (const DataType*)xq + dim * index; switch (faiss_metric_type) { case faiss::METRIC_L2: { - auto cur_query = (const float*)xq + dim * index; - faiss::all_L2sqr(cur_query, (const float*)xb, dim, 1, nb, distances_ids, nullptr, id_selector); + auto cur_query = (const DataType*)xq + dim * index; + if constexpr (std::is_same_v) { + faiss::all_L2sqr(cur_query, (const float*)xb, dim, 1, nb, distances_ids, nullptr, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_all_L2sqr(cur_query, (const DataType*)xb, dim, 1, nb, + distances_ids, nullptr, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric L2 only used in floating point data"; + return Status::faiss_inner_error; + } break; } case faiss::METRIC_INNER_PRODUCT: { - auto cur_query = (const float*)xq + dim * index; if (is_cosine) { - auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::all_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, distances_ids, - id_selector); + if constexpr (std::is_same_v) { + auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); + faiss::all_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, + distances_ids, id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + // normalize query vector may cause presision loss, so div query norms in apply function + + faiss::half_precision_floating_point_all_cosine(cur_query, (const DataType*)xb, norms.get(), + dim, 1, nb, distances_ids, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; + return Status::faiss_inner_error; + } } else { - faiss::all_inner_product(cur_query, (const float*)xb, dim, 1, nb, distances_ids, id_selector); + if constexpr (std::is_same_v) { + faiss::all_inner_product(cur_query, (const float*)xb, dim, 1, nb, distances_ids, + id_selector); + } else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck::value) { + faiss::half_precision_floating_point_all_inner_product(cur_query, (const DataType*)xb, dim, + 1, nb, distances_ids, id_selector); + } else { + LOG_KNOWHERE_ERROR_ << "Metric Cosine only used in floating point data"; + return Status::faiss_inner_error; + } } break; } diff --git a/src/simd/distances_avx.cc b/src/simd/distances_avx.cc index a57830a59..914b97766 100644 --- a/src/simd/distances_avx.cc +++ b/src/simd/distances_avx.cc @@ -58,19 +58,22 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END float fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); - __m256 msum_1 = _mm256_setzero_ps(); while (d >= 16) { - auto mx_0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); - auto my_0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y)); - auto mx_1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + 8))); - auto my_1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + 8))); + auto mx = _mm256_loadu_si256((__m256i*)x); + auto mx_0 = _mm256_cvtph_ps(_mm256_extracti128_si256(mx, 0)); + auto mx_1 = _mm256_cvtph_ps(_mm256_extracti128_si256(mx, 1)); + + auto my = _mm256_loadu_si256((__m256i*)y); + auto my_0 = _mm256_cvtph_ps(_mm256_extracti128_si256(my, 0)); + auto my_1 = _mm256_cvtph_ps(_mm256_extracti128_si256(my, 1)); + msum_0 = _mm256_fmadd_ps(mx_0, my_0, msum_0); - msum_1 = _mm256_fmadd_ps(mx_1, my_1, msum_1); + auto msum_1 = _mm256_mul_ps(mx_1, my_1); + msum_0 = msum_0 + msum_1; x += 16; y += 16; d -= 16; } - msum_0 = msum_0 + msum_1; while (d >= 8) { auto mx = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); auto my = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y)); @@ -91,19 +94,22 @@ fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, siz float bf16_vec_inner_product_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); - __m256 msum_1 = _mm256_setzero_ps(); while (d >= 16) { - auto mx_0 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); - auto my_0 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y)); - auto mx_1 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)(x + 8))); - auto my_1 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)(y + 8))); + auto mx = _mm256_loadu_si256((__m256i*)x); + auto mx_0 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(mx, 0)); + auto mx_1 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(mx, 1)); + + auto my = _mm256_loadu_si256((__m256i*)y); + auto my_0 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(my, 0)); + auto my_1 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(my, 1)); + msum_0 = _mm256_fmadd_ps(mx_0, my_0, msum_0); - msum_1 = _mm256_fmadd_ps(mx_1, my_1, msum_1); + auto msum_1 = _mm256_mul_ps(mx_1, my_1); + msum_0 = msum_0 + msum_1; x += 16; y += 16; d -= 16; } - msum_0 = msum_0 + msum_1; while (d >= 8) { auto mx = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); auto my = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y)); @@ -153,21 +159,24 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END float fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); - __m256 msum_1 = _mm256_setzero_ps(); while (d >= 16) { - auto mx_0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); - auto my_0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y)); - auto mx_1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + 8))); - auto my_1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + 8))); + auto mx = _mm256_loadu_si256((__m256i*)x); + auto mx_0 = _mm256_cvtph_ps(_mm256_extracti128_si256(mx, 0)); + auto mx_1 = _mm256_cvtph_ps(_mm256_extracti128_si256(mx, 1)); + + auto my = _mm256_loadu_si256((__m256i*)y); + auto my_0 = _mm256_cvtph_ps(_mm256_extracti128_si256(my, 0)); + auto my_1 = _mm256_cvtph_ps(_mm256_extracti128_si256(my, 1)); + mx_0 = _mm256_sub_ps(mx_0, my_0); mx_1 = _mm256_sub_ps(mx_1, my_1); msum_0 = _mm256_fmadd_ps(mx_0, mx_0, msum_0); - msum_1 = _mm256_fmadd_ps(mx_1, mx_1, msum_1); + msum_0 = _mm256_fmadd_ps(mx_1, mx_1, msum_0); + x += 16; y += 16; d -= 16; } - msum_0 = msum_0 + msum_1; while (d >= 8) { auto mx = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); auto my = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y)); @@ -190,21 +199,22 @@ fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { float bf16_vec_L2sqr_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); - __m256 msum_1 = _mm256_setzero_ps(); while (d >= 16) { - auto mx_0 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); - auto my_0 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y)); - auto mx_1 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)(x + 8))); - auto my_1 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)(y + 8))); + auto mx = _mm256_loadu_si256((__m256i*)x); + auto mx_0 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(mx, 0)); + auto mx_1 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(mx, 1)); + + auto my = _mm256_loadu_si256((__m256i*)y); + auto my_0 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(my, 0)); + auto my_1 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(my, 1)); mx_0 = _mm256_sub_ps(mx_0, my_0); mx_1 = _mm256_sub_ps(mx_1, my_1); msum_0 = _mm256_fmadd_ps(mx_0, mx_0, msum_0); - msum_1 = _mm256_fmadd_ps(mx_1, mx_1, msum_1); + msum_0 = _mm256_fmadd_ps(mx_1, mx_1, msum_0); x += 16; y += 16; d -= 16; } - msum_0 = msum_0 + msum_1; while (d >= 8) { auto mx = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); auto my = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y)); @@ -438,6 +448,200 @@ fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +void +fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + __m256 msum_0 = _mm256_setzero_ps(); + __m256 msum_1 = _mm256_setzero_ps(); + __m256 msum_2 = _mm256_setzero_ps(); + __m256 msum_3 = _mm256_setzero_ps(); + + size_t cur_d = d; + while (cur_d >= 8) { + auto mx = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); + auto my0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y0)); + auto my1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y1)); + auto my2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y2)); + auto my3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y3)); + msum_0 = _mm256_fmadd_ps(mx, my0, msum_0); + msum_1 = _mm256_fmadd_ps(mx, my1, msum_1); + msum_2 = _mm256_fmadd_ps(mx, my2, msum_2); + msum_3 = _mm256_fmadd_ps(mx, my3, msum_3); + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + cur_d -= 8; + } + if (cur_d > 0) { + auto mx = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)x)); + auto my0 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y0)); + auto my1 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y1)); + auto my2 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y2)); + auto my3 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y3)); + msum_0 = _mm256_fmadd_ps(mx, my0, msum_0); + msum_1 = _mm256_fmadd_ps(mx, my1, msum_1); + msum_2 = _mm256_fmadd_ps(mx, my2, msum_2); + msum_3 = _mm256_fmadd_ps(mx, my3, msum_3); + } + dis0 = _mm256_reduce_add_ps(msum_0); + dis1 = _mm256_reduce_add_ps(msum_1); + dis2 = _mm256_reduce_add_ps(msum_2); + dis3 = _mm256_reduce_add_ps(msum_3); + return; +} + +void +bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + __m256 msum_0 = _mm256_setzero_ps(); + __m256 msum_1 = _mm256_setzero_ps(); + __m256 msum_2 = _mm256_setzero_ps(); + __m256 msum_3 = _mm256_setzero_ps(); + size_t cur_d = d; + while (cur_d >= 8) { + auto mx = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); + auto my0 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y0)); + auto my1 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y1)); + auto my2 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y2)); + auto my3 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y3)); + msum_0 = _mm256_fmadd_ps(mx, my0, msum_0); + msum_1 = _mm256_fmadd_ps(mx, my1, msum_1); + msum_2 = _mm256_fmadd_ps(mx, my2, msum_2); + msum_3 = _mm256_fmadd_ps(mx, my3, msum_3); + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + cur_d -= 8; + } + if (cur_d > 0) { + auto mx = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)x)); + auto my0 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y0)); + auto my1 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y1)); + auto my2 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y2)); + auto my3 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y3)); + msum_0 = _mm256_fmadd_ps(mx, my0, msum_0); + msum_1 = _mm256_fmadd_ps(mx, my1, msum_1); + msum_2 = _mm256_fmadd_ps(mx, my2, msum_2); + msum_3 = _mm256_fmadd_ps(mx, my3, msum_3); + } + dis0 = _mm256_reduce_add_ps(msum_0); + dis1 = _mm256_reduce_add_ps(msum_1); + dis2 = _mm256_reduce_add_ps(msum_2); + dis3 = _mm256_reduce_add_ps(msum_3); + + return; +} + +void +fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3) { + __m256 msum_0 = _mm256_setzero_ps(); + __m256 msum_1 = _mm256_setzero_ps(); + __m256 msum_2 = _mm256_setzero_ps(); + __m256 msum_3 = _mm256_setzero_ps(); + auto cur_d = d; + while (cur_d >= 8) { + auto mx = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); + auto my0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y0)); + auto my1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y1)); + auto my2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y2)); + auto my3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)y3)); + my0 = _mm256_sub_ps(mx, my0); + msum_0 = _mm256_fmadd_ps(my0, my0, msum_0); + my1 = _mm256_sub_ps(mx, my1); + msum_1 = _mm256_fmadd_ps(my1, my1, msum_1); + my2 = _mm256_sub_ps(mx, my2); + msum_2 = _mm256_fmadd_ps(my2, my2, msum_2); + my3 = _mm256_sub_ps(mx, my3); + msum_3 = _mm256_fmadd_ps(my3, my3, msum_3); + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + cur_d -= 8; + } + if (cur_d > 0) { + auto mx = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)x)); + auto my0 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y0)); + auto my1 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y1)); + auto my2 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y2)); + auto my3 = _mm256_cvtph_ps(mm_masked_read_short(cur_d, (uint16_t*)y3)); + my0 = _mm256_sub_ps(mx, my0); + my1 = _mm256_sub_ps(mx, my1); + my2 = _mm256_sub_ps(mx, my2); + my3 = _mm256_sub_ps(mx, my3); + msum_0 = _mm256_fmadd_ps(my0, my0, msum_0); + msum_1 = _mm256_fmadd_ps(my1, my1, msum_1); + msum_2 = _mm256_fmadd_ps(my2, my2, msum_2); + msum_3 = _mm256_fmadd_ps(my3, my3, msum_3); + } + dis0 = _mm256_reduce_add_ps(msum_0); + dis1 = _mm256_reduce_add_ps(msum_1); + dis2 = _mm256_reduce_add_ps(msum_2); + dis3 = _mm256_reduce_add_ps(msum_3); + return; +} + +void +bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3) { + __m256 msum_0 = _mm256_setzero_ps(); + __m256 msum_1 = _mm256_setzero_ps(); + __m256 msum_2 = _mm256_setzero_ps(); + __m256 msum_3 = _mm256_setzero_ps(); + size_t cur_d = d; + while (cur_d >= 8) { + auto mx = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); + auto my0 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y0)); + auto my1 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y1)); + auto my2 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y2)); + auto my3 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)y3)); + my0 = _mm256_sub_ps(mx, my0); + my1 = _mm256_sub_ps(mx, my1); + my2 = _mm256_sub_ps(mx, my2); + my3 = _mm256_sub_ps(mx, my3); + msum_0 = _mm256_fmadd_ps(my0, my0, msum_0); + msum_1 = _mm256_fmadd_ps(my1, my1, msum_1); + msum_2 = _mm256_fmadd_ps(my2, my2, msum_2); + msum_3 = _mm256_fmadd_ps(my3, my3, msum_3); + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + cur_d -= 8; + } + if (cur_d > 0) { + auto mx = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)x)); + auto my0 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y0)); + auto my1 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y1)); + auto my2 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y2)); + auto my3 = _mm256_bf16_to_fp32(mm_masked_read_short(cur_d, (uint16_t*)y3)); + my0 = _mm256_sub_ps(mx, my0); + my1 = _mm256_sub_ps(mx, my1); + my2 = _mm256_sub_ps(mx, my2); + my3 = _mm256_sub_ps(mx, my3); + msum_0 = _mm256_fmadd_ps(my0, my0, msum_0); + msum_1 = _mm256_fmadd_ps(my1, my1, msum_1); + msum_2 = _mm256_fmadd_ps(my2, my2, msum_2); + msum_3 = _mm256_fmadd_ps(my3, my3, msum_3); + } + dis0 = _mm256_reduce_add_ps(msum_0); + dis1 = _mm256_reduce_add_ps(msum_1); + dis2 = _mm256_reduce_add_ps(msum_2); + dis3 = _mm256_reduce_add_ps(msum_3); + return; +} + // trust the compiler to unroll this properly int32_t ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) { @@ -464,16 +668,15 @@ ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) { float fvec_norm_L2sqr_avx(const float* x, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); - __m256 msum_1 = _mm256_setzero_ps(); while (d >= 16) { auto mx_0 = _mm256_loadu_ps(x); auto mx_1 = _mm256_loadu_ps(x + 8); msum_0 = _mm256_fmadd_ps(mx_0, mx_0, msum_0); - msum_1 = _mm256_fmadd_ps(mx_1, mx_1, msum_1); + auto msum_1 = _mm256_mul_ps(mx_1, mx_1); + msum_0 = msum_0 + msum_1; x += 16; d -= 16; } - msum_0 = msum_0 + msum_1; if (d >= 8) { auto mx = _mm256_loadu_ps(x); msum_0 = _mm256_fmadd_ps(mx, mx, msum_0); @@ -501,16 +704,16 @@ fvec_norm_L2sqr_avx(const float* x, size_t d) { float fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); - __m256 msum_1 = _mm256_setzero_ps(); while (d >= 16) { - auto mx_0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); - auto mx_1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + 8))); + auto mx = _mm256_loadu_si256((__m256i*)x); + auto mx_0 = _mm256_cvtph_ps(_mm256_extracti128_si256(mx, 0)); + auto mx_1 = _mm256_cvtph_ps(_mm256_extracti128_si256(mx, 1)); msum_0 = _mm256_fmadd_ps(mx_0, mx_0, msum_0); - msum_1 = _mm256_fmadd_ps(mx_1, mx_1, msum_1); + auto msum_1 = _mm256_mul_ps(mx_1, mx_1); + msum_0 = msum_0 + msum_1; x += 16; d -= 16; } - msum_0 = msum_0 + msum_1; while (d >= 8) { auto mx = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)x)); msum_0 = _mm256_fmadd_ps(mx, mx, msum_0); @@ -528,16 +731,16 @@ fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d) { float bf16_vec_norm_L2sqr_avx(const knowhere::bf16* x, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); - __m256 msum_1 = _mm256_setzero_ps(); while (d >= 16) { - auto mx_0 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); - auto mx_1 = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)(x + 8))); + auto mx = _mm256_loadu_si256((__m256i*)x); + auto mx_0 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(mx, 0)); + auto mx_1 = _mm256_bf16_to_fp32(_mm256_extracti128_si256(mx, 1)); msum_0 = _mm256_fmadd_ps(mx_0, mx_0, msum_0); - msum_1 = _mm256_fmadd_ps(mx_1, mx_1, msum_1); + auto msum_1 = _mm256_mul_ps(mx_1, mx_1); + msum_0 = msum_0 + msum_1; x += 16; d -= 16; } - msum_0 = msum_0 + msum_1; while (d >= 8) { auto mx = _mm256_bf16_to_fp32(_mm_loadu_si128((__m128i*)x)); msum_0 = _mm256_fmadd_ps(mx, mx, msum_0); diff --git a/src/simd/distances_avx.h b/src/simd/distances_avx.h index 441daee15..774e5a386 100644 --- a/src/simd/distances_avx.h +++ b/src/simd/distances_avx.h @@ -65,6 +65,16 @@ fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + +void +bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + void fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); @@ -73,6 +83,16 @@ void fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3); + +void +bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3); + int32_t ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d); diff --git a/src/simd/distances_avx512.cc b/src/simd/distances_avx512.cc index 07b74195e..3a5a49267 100644 --- a/src/simd/distances_avx512.cc +++ b/src/simd/distances_avx512.cc @@ -159,19 +159,17 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END float fp16_vec_inner_product_avx512(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) { __m512 m512_res = _mm512_setzero_ps(); - __m512 m512_res_0 = _mm512_setzero_ps(); while (d >= 32) { auto mx_0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x)); auto my_0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y)); auto mx_1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + 16))); auto my_1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + 16))); m512_res = _mm512_fmadd_ps(mx_0, my_0, m512_res); - m512_res_0 = _mm512_fmadd_ps(mx_1, my_1, m512_res_0); + m512_res = _mm512_fmadd_ps(mx_1, my_1, m512_res); x += 32; y += 32; d -= 32; } - m512_res = m512_res + m512_res_0; if (d >= 16) { auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x)); auto my = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y)); @@ -408,6 +406,201 @@ fvec_inner_product_batch_4_avx512_bf16_patch(const float* __restrict x, const fl } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +void +fp16_vec_inner_product_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + __m512 m512_res_0 = _mm512_setzero_ps(); + __m512 m512_res_1 = _mm512_setzero_ps(); + __m512 m512_res_2 = _mm512_setzero_ps(); + __m512 m512_res_3 = _mm512_setzero_ps(); + size_t cur_d = d; + while (cur_d >= 16) { + auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x)); + auto my0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y0)); + auto my1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y1)); + auto my2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y2)); + auto my3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y3)); + m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3); + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + cur_d -= 16; + } + if (cur_d > 0) { + const __mmask16 mask = (1U << cur_d) - 1U; + auto mx = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, x)); + auto my0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y0)); + auto my1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y1)); + auto my2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y2)); + auto my3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y3)); + m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3); + } + dis0 = _mm512_reduce_add_ps(m512_res_0); + dis1 = _mm512_reduce_add_ps(m512_res_1); + dis2 = _mm512_reduce_add_ps(m512_res_2); + dis3 = _mm512_reduce_add_ps(m512_res_3); + return; +} + +void +bf16_vec_inner_product_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + __m512 m512_res_0 = _mm512_setzero_ps(); + __m512 m512_res_1 = _mm512_setzero_ps(); + __m512 m512_res_2 = _mm512_setzero_ps(); + __m512 m512_res_3 = _mm512_setzero_ps(); + size_t cur_d = d; + while (cur_d >= 16) { + auto mx = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)x)); + auto my0 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y0)); + auto my1 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y1)); + auto my2 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y2)); + auto my3 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y3)); + m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3); + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + cur_d -= 16; + } + if (cur_d > 0) { + const __mmask16 mask = (1U << cur_d) - 1U; + auto mx = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, x)); + auto my0 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y0)); + auto my1 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y1)); + auto my2 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y2)); + auto my3 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y3)); + m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3); + } + dis0 = _mm512_reduce_add_ps(m512_res_0); + dis1 = _mm512_reduce_add_ps(m512_res_1); + dis2 = _mm512_reduce_add_ps(m512_res_2); + dis3 = _mm512_reduce_add_ps(m512_res_3); + return; +} + +void +fp16_vec_L2sqr_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + __m512 m512_res_0 = _mm512_setzero_ps(); + __m512 m512_res_1 = _mm512_setzero_ps(); + __m512 m512_res_2 = _mm512_setzero_ps(); + __m512 m512_res_3 = _mm512_setzero_ps(); + size_t cur_d = d; + while (cur_d >= 16) { + auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x)); + auto my0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y0)); + auto my1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y1)); + auto my2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y2)); + auto my3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y3)); + my0 = _mm512_sub_ps(mx, my0); + my1 = _mm512_sub_ps(mx, my1); + my2 = _mm512_sub_ps(mx, my2); + my3 = _mm512_sub_ps(mx, my3); + m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3); + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + cur_d -= 16; + } + if (cur_d > 0) { + const __mmask16 mask = (1U << cur_d) - 1U; + auto mx = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, x)); + auto my0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y0)); + auto my1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y1)); + auto my2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y2)); + auto my3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y3)); + my0 = _mm512_sub_ps(mx, my0); + my1 = _mm512_sub_ps(mx, my1); + my2 = _mm512_sub_ps(mx, my2); + my3 = _mm512_sub_ps(mx, my3); + m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3); + } + dis0 = _mm512_reduce_add_ps(m512_res_0); + dis1 = _mm512_reduce_add_ps(m512_res_1); + dis2 = _mm512_reduce_add_ps(m512_res_2); + dis3 = _mm512_reduce_add_ps(m512_res_3); + return; +} + +void +bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + __m512 m512_res_0 = _mm512_setzero_ps(); + __m512 m512_res_1 = _mm512_setzero_ps(); + __m512 m512_res_2 = _mm512_setzero_ps(); + __m512 m512_res_3 = _mm512_setzero_ps(); + size_t cur_d = d; + while (cur_d >= 16) { + auto mx = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)x)); + auto my0 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y0)); + auto my1 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y1)); + auto my2 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y2)); + auto my3 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y3)); + my0 = mx - my0; + my1 = mx - my1; + my2 = mx - my2; + my3 = mx - my3; + m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3); + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + cur_d -= 16; + } + if (cur_d > 0) { + const __mmask16 mask = (1U << cur_d) - 1U; + auto mx = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, x)); + auto my0 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y0)); + auto my1 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y1)); + auto my2 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y2)); + auto my3 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y3)); + my0 = _mm512_sub_ps(mx, my0); + my1 = _mm512_sub_ps(mx, my1); + my2 = _mm512_sub_ps(mx, my2); + my3 = _mm512_sub_ps(mx, my3); + m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0); + m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1); + m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2); + m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3); + } + dis0 = _mm512_reduce_add_ps(m512_res_0); + dis1 = _mm512_reduce_add_ps(m512_res_1); + dis2 = _mm512_reduce_add_ps(m512_res_2); + dis3 = _mm512_reduce_add_ps(m512_res_3); + return; +} // trust the compiler to unroll this properly FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void diff --git a/src/simd/distances_avx512.h b/src/simd/distances_avx512.h index 36096d005..bea6dc237 100644 --- a/src/simd/distances_avx512.h +++ b/src/simd/distances_avx512.h @@ -64,6 +64,16 @@ fvec_inner_product_batch_4_avx512_bf16_patch(const float* x, const float* y0, co const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_inner_product_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + +void +bf16_vec_inner_product_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + void fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); @@ -72,6 +82,16 @@ void fvec_L2sqr_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_L2sqr_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + +void +bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + int32_t ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d); diff --git a/src/simd/distances_neon.cc b/src/simd/distances_neon.cc index 0b600673f..77e26ce4d 100644 --- a/src/simd/distances_neon.cc +++ b/src/simd/distances_neon.cc @@ -2090,5 +2090,633 @@ fvec_L2sqr_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* dis3 = vaddvq_f32(sum_.val[3]); } +void +fp16_vec_inner_product_batch_4_neon(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + // res store sub result of {4*dis0, 4*dis1, d4*is2, 4*dis3} + float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + auto cur_d = d; + while (cur_d >= 16) { + float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); + float32x4x4_t b0 = vcvt4_f32_f16(vld4_f16((const __fp16*)y0)); + float32x4x4_t b1 = vcvt4_f32_f16(vld4_f16((const __fp16*)y1)); + float32x4x4_t b2 = vcvt4_f32_f16(vld4_f16((const __fp16*)y2)); + float32x4x4_t b3 = vcvt4_f32_f16(vld4_f16((const __fp16*)y3)); + + res.val[0] = vmlaq_f32(res.val[0], a.val[0], b0.val[0]); + res.val[1] = vmlaq_f32(res.val[1], a.val[0], b1.val[0]); + res.val[2] = vmlaq_f32(res.val[2], a.val[0], b2.val[0]); + res.val[3] = vmlaq_f32(res.val[3], a.val[0], b3.val[0]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[1], b0.val[1]); + res.val[1] = vmlaq_f32(res.val[1], a.val[1], b1.val[1]); + res.val[2] = vmlaq_f32(res.val[2], a.val[1], b2.val[1]); + res.val[3] = vmlaq_f32(res.val[3], a.val[1], b3.val[1]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[2], b0.val[2]); + res.val[1] = vmlaq_f32(res.val[1], a.val[2], b1.val[2]); + res.val[2] = vmlaq_f32(res.val[2], a.val[2], b2.val[2]); + res.val[3] = vmlaq_f32(res.val[3], a.val[2], b3.val[2]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[3], b0.val[3]); + res.val[1] = vmlaq_f32(res.val[1], a.val[3], b1.val[3]); + res.val[2] = vmlaq_f32(res.val[2], a.val[3], b2.val[3]); + res.val[3] = vmlaq_f32(res.val[3], a.val[3], b3.val[3]); + + cur_d -= 16; + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + } + if (cur_d >= 8) { + float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); + float32x4x2_t b0 = vcvt2_f32_f16(vld2_f16((const __fp16*)y0)); + float32x4x2_t b1 = vcvt2_f32_f16(vld2_f16((const __fp16*)y1)); + float32x4x2_t b2 = vcvt2_f32_f16(vld2_f16((const __fp16*)y2)); + float32x4x2_t b3 = vcvt2_f32_f16(vld2_f16((const __fp16*)y3)); + res.val[0] = vmlaq_f32(res.val[0], a.val[0], b0.val[0]); + res.val[1] = vmlaq_f32(res.val[1], a.val[0], b1.val[0]); + res.val[2] = vmlaq_f32(res.val[2], a.val[0], b2.val[0]); + res.val[3] = vmlaq_f32(res.val[3], a.val[0], b3.val[0]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[1], b0.val[1]); + res.val[1] = vmlaq_f32(res.val[1], a.val[1], b1.val[1]); + res.val[2] = vmlaq_f32(res.val[2], a.val[1], b2.val[1]); + res.val[3] = vmlaq_f32(res.val[3], a.val[1], b3.val[1]); + + cur_d -= 8; + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + } + if (cur_d >= 4) { + float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); + float32x4_t b0 = vcvt_f32_f16(vld1_f16((const __fp16*)y0)); + float32x4_t b1 = vcvt_f32_f16(vld1_f16((const __fp16*)y1)); + float32x4_t b2 = vcvt_f32_f16(vld1_f16((const __fp16*)y2)); + float32x4_t b3 = vcvt_f32_f16(vld1_f16((const __fp16*)y3)); + res.val[0] = vmlaq_f32(res.val[0], a, b0); + res.val[1] = vmlaq_f32(res.val[1], a, b1); + res.val[2] = vmlaq_f32(res.val[2], a, b2); + res.val[3] = vmlaq_f32(res.val[3], a, b3); + cur_d -= 4; + x += 4; + y0 += 4; + y1 += 4; + y2 += 4; + y3 += 4; + } + if (cur_d >= 0) { + float16x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y0 = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y1 = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y2 = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y3 = {0.0f, 0.0f, 0.0f, 0.0f}; + switch (cur_d) { + case 3: + res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); + res_y0 = vld1_lane_f16((const __fp16*)y0, res_y0, 2); + res_y1 = vld1_lane_f16((const __fp16*)y1, res_y1, 2); + res_y2 = vld1_lane_f16((const __fp16*)y2, res_y2, 2); + res_y3 = vld1_lane_f16((const __fp16*)y3, res_y3, 2); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 2: + res_x = vld1_lane_f16((const __fp16*)x, res_x, 1); + res_y0 = vld1_lane_f16((const __fp16*)y0, res_y0, 1); + res_y1 = vld1_lane_f16((const __fp16*)y1, res_y1, 1); + res_y2 = vld1_lane_f16((const __fp16*)y2, res_y2, 1); + res_y3 = vld1_lane_f16((const __fp16*)y3, res_y3, 1); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 1: + res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); + res_y0 = vld1_lane_f16((const __fp16*)y0, res_y0, 0); + res_y1 = vld1_lane_f16((const __fp16*)y1, res_y1, 0); + res_y2 = vld1_lane_f16((const __fp16*)y2, res_y2, 0); + res_y3 = vld1_lane_f16((const __fp16*)y3, res_y3, 0); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + } + res.val[0] = vmlaq_f32(res.val[0], vcvt_f32_f16(res_x), vcvt_f32_f16(res_y0)); + res.val[1] = vmlaq_f32(res.val[1], vcvt_f32_f16(res_x), vcvt_f32_f16(res_y1)); + res.val[2] = vmlaq_f32(res.val[2], vcvt_f32_f16(res_x), vcvt_f32_f16(res_y2)); + res.val[3] = vmlaq_f32(res.val[3], vcvt_f32_f16(res_x), vcvt_f32_f16(res_y3)); + } + dis0 = vaddvq_f32(res.val[0]); + dis1 = vaddvq_f32(res.val[1]); + dis2 = vaddvq_f32(res.val[2]); + dis3 = vaddvq_f32(res.val[3]); + return; +} + +void +bf16_vec_inner_product_batch_4_neon(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + auto cur_d = d; + while (cur_d >= 16) { + float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); + float32x4x4_t b0 = vcvt4_f32_half(vld4_u16((const uint16_t*)y0)); + float32x4x4_t b1 = vcvt4_f32_half(vld4_u16((const uint16_t*)y1)); + float32x4x4_t b2 = vcvt4_f32_half(vld4_u16((const uint16_t*)y2)); + float32x4x4_t b3 = vcvt4_f32_half(vld4_u16((const uint16_t*)y3)); + + res.val[0] = vmlaq_f32(res.val[0], a.val[0], b0.val[0]); + res.val[1] = vmlaq_f32(res.val[1], a.val[0], b1.val[0]); + res.val[2] = vmlaq_f32(res.val[2], a.val[0], b2.val[0]); + res.val[3] = vmlaq_f32(res.val[3], a.val[0], b3.val[0]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[1], b0.val[1]); + res.val[1] = vmlaq_f32(res.val[1], a.val[1], b1.val[1]); + res.val[2] = vmlaq_f32(res.val[2], a.val[1], b2.val[1]); + res.val[3] = vmlaq_f32(res.val[3], a.val[1], b3.val[1]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[2], b0.val[2]); + res.val[1] = vmlaq_f32(res.val[1], a.val[2], b1.val[2]); + res.val[2] = vmlaq_f32(res.val[2], a.val[2], b2.val[2]); + res.val[3] = vmlaq_f32(res.val[3], a.val[2], b3.val[2]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[3], b0.val[3]); + res.val[1] = vmlaq_f32(res.val[1], a.val[3], b1.val[3]); + res.val[2] = vmlaq_f32(res.val[2], a.val[3], b2.val[3]); + res.val[3] = vmlaq_f32(res.val[3], a.val[3], b3.val[3]); + + cur_d -= 16; + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + } + if (cur_d >= 8) { + float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); + float32x4x2_t b0 = vcvt2_f32_half(vld2_u16((const uint16_t*)y0)); + float32x4x2_t b1 = vcvt2_f32_half(vld2_u16((const uint16_t*)y1)); + float32x4x2_t b2 = vcvt2_f32_half(vld2_u16((const uint16_t*)y2)); + float32x4x2_t b3 = vcvt2_f32_half(vld2_u16((const uint16_t*)y3)); + res.val[0] = vmlaq_f32(res.val[0], a.val[0], b0.val[0]); + res.val[1] = vmlaq_f32(res.val[1], a.val[0], b1.val[0]); + res.val[2] = vmlaq_f32(res.val[2], a.val[0], b2.val[0]); + res.val[3] = vmlaq_f32(res.val[3], a.val[0], b3.val[0]); + + res.val[0] = vmlaq_f32(res.val[0], a.val[1], b0.val[1]); + res.val[1] = vmlaq_f32(res.val[1], a.val[1], b1.val[1]); + res.val[2] = vmlaq_f32(res.val[2], a.val[1], b2.val[1]); + res.val[3] = vmlaq_f32(res.val[3], a.val[1], b3.val[1]); + + cur_d -= 8; + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + } + if (cur_d >= 4) { + float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); + float32x4_t b0 = vcvt_f32_half(vld1_u16((const uint16_t*)y0)); + float32x4_t b1 = vcvt_f32_half(vld1_u16((const uint16_t*)y1)); + float32x4_t b2 = vcvt_f32_half(vld1_u16((const uint16_t*)y2)); + float32x4_t b3 = vcvt_f32_half(vld1_u16((const uint16_t*)y3)); + res.val[0] = vmlaq_f32(res.val[0], a, b0); + res.val[1] = vmlaq_f32(res.val[1], a, b1); + res.val[2] = vmlaq_f32(res.val[2], a, b2); + res.val[3] = vmlaq_f32(res.val[3], a, b3); + cur_d -= 4; + x += 4; + y0 += 4; + y1 += 4; + y2 += 4; + y3 += 4; + } + if (cur_d >= 0) { + uint16x4_t res_x = vdup_n_u16(0); + uint16x4_t res_y0 = vdup_n_u16(0); + uint16x4_t res_y1 = vdup_n_u16(0); + uint16x4_t res_y2 = vdup_n_u16(0); + uint16x4_t res_y3 = vdup_n_u16(0); + switch (cur_d) { + case 3: + res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); + res_y0 = vld1_lane_u16((const uint16_t*)y0, res_y0, 2); + res_y1 = vld1_lane_u16((const uint16_t*)y1, res_y1, 2); + res_y2 = vld1_lane_u16((const uint16_t*)y2, res_y2, 2); + res_y3 = vld1_lane_u16((const uint16_t*)y3, res_y3, 2); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 2: + res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); + res_y0 = vld1_lane_u16((const uint16_t*)y0, res_y0, 1); + res_y1 = vld1_lane_u16((const uint16_t*)y1, res_y1, 1); + res_y2 = vld1_lane_u16((const uint16_t*)y2, res_y2, 1); + res_y3 = vld1_lane_u16((const uint16_t*)y3, res_y3, 1); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 1: + res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); + res_y0 = vld1_lane_u16((const uint16_t*)y0, res_y0, 0); + res_y1 = vld1_lane_u16((const uint16_t*)y1, res_y1, 0); + res_y2 = vld1_lane_u16((const uint16_t*)y2, res_y2, 0); + res_y3 = vld1_lane_u16((const uint16_t*)y3, res_y3, 0); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + } + res.val[0] = vmlaq_f32(res.val[0], vcvt_f32_half(res_x), vcvt_f32_half(res_y0)); + res.val[1] = vmlaq_f32(res.val[1], vcvt_f32_half(res_x), vcvt_f32_half(res_y1)); + res.val[2] = vmlaq_f32(res.val[2], vcvt_f32_half(res_x), vcvt_f32_half(res_y2)); + res.val[3] = vmlaq_f32(res.val[3], vcvt_f32_half(res_x), vcvt_f32_half(res_y3)); + } + dis0 = vaddvq_f32(res.val[0]); + dis1 = vaddvq_f32(res.val[1]); + dis2 = vaddvq_f32(res.val[2]); + dis3 = vaddvq_f32(res.val[3]); + return; +} + +void +fp16_vec_L2sqr_batch_4_neon(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + auto cur_d = d; + while (cur_d >= 16) { + float32x4x4_t a = vcvt4_f32_f16(vld4_f16((const __fp16*)x)); + float32x4x4_t b0 = vcvt4_f32_f16(vld4_f16((const __fp16*)y0)); + float32x4x4_t b1 = vcvt4_f32_f16(vld4_f16((const __fp16*)y1)); + float32x4x4_t b2 = vcvt4_f32_f16(vld4_f16((const __fp16*)y2)); + float32x4x4_t b3 = vcvt4_f32_f16(vld4_f16((const __fp16*)y3)); + + b0.val[0] = vsubq_f32(a.val[0], b0.val[0]); + b0.val[1] = vsubq_f32(a.val[1], b0.val[1]); + b0.val[2] = vsubq_f32(a.val[2], b0.val[2]); + b0.val[3] = vsubq_f32(a.val[3], b0.val[3]); + + b1.val[0] = vsubq_f32(a.val[0], b1.val[0]); + b1.val[1] = vsubq_f32(a.val[1], b1.val[1]); + b1.val[2] = vsubq_f32(a.val[2], b1.val[2]); + b1.val[3] = vsubq_f32(a.val[3], b1.val[3]); + + b2.val[0] = vsubq_f32(a.val[0], b2.val[0]); + b2.val[1] = vsubq_f32(a.val[1], b2.val[1]); + b2.val[2] = vsubq_f32(a.val[2], b2.val[2]); + b2.val[3] = vsubq_f32(a.val[3], b2.val[3]); + + b3.val[0] = vsubq_f32(a.val[0], b3.val[0]); + b3.val[1] = vsubq_f32(a.val[1], b3.val[1]); + b3.val[2] = vsubq_f32(a.val[2], b3.val[2]); + b3.val[3] = vsubq_f32(a.val[3], b3.val[3]); + + res.val[0] = vaddq_f32(vmulq_f32(b0.val[0], b0.val[0]), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(b1.val[0], b1.val[0]), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(b2.val[0], b2.val[0]), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(b3.val[0], b3.val[0]), res.val[3]); + + res.val[0] = vaddq_f32(vmulq_f32(b0.val[1], b0.val[1]), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(b1.val[1], b1.val[1]), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(b2.val[1], b2.val[1]), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(b3.val[1], b3.val[1]), res.val[3]); + + res.val[0] = vaddq_f32(vmulq_f32(b0.val[2], b0.val[2]), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(b1.val[2], b1.val[2]), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(b2.val[2], b2.val[2]), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(b3.val[2], b3.val[2]), res.val[3]); + + res.val[0] = vaddq_f32(vmulq_f32(b0.val[3], b0.val[3]), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(b1.val[3], b1.val[3]), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(b2.val[3], b2.val[3]), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(b3.val[3], b3.val[3]), res.val[3]); + + cur_d -= 16; + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + } + if (cur_d >= 8) { + float32x4x2_t a = vcvt2_f32_f16(vld2_f16((const __fp16*)x)); + float32x4x2_t b0 = vcvt2_f32_f16(vld2_f16((const __fp16*)y0)); + float32x4x2_t b1 = vcvt2_f32_f16(vld2_f16((const __fp16*)y1)); + float32x4x2_t b2 = vcvt2_f32_f16(vld2_f16((const __fp16*)y2)); + float32x4x2_t b3 = vcvt2_f32_f16(vld2_f16((const __fp16*)y3)); + b0.val[0] = vsubq_f32(a.val[0], b0.val[0]); + b0.val[1] = vsubq_f32(a.val[1], b0.val[1]); + + b1.val[0] = vsubq_f32(a.val[0], b1.val[0]); + b1.val[1] = vsubq_f32(a.val[1], b1.val[1]); + + b2.val[0] = vsubq_f32(a.val[0], b2.val[0]); + b2.val[1] = vsubq_f32(a.val[1], b2.val[1]); + + b3.val[0] = vsubq_f32(a.val[0], b3.val[0]); + b3.val[1] = vsubq_f32(a.val[1], b3.val[1]); + + res.val[0] = vaddq_f32(vmulq_f32(b0.val[0], b0.val[0]), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(b1.val[0], b1.val[0]), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(b2.val[0], b2.val[0]), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(b3.val[0], b3.val[0]), res.val[3]); + + res.val[0] = vaddq_f32(vmulq_f32(b0.val[1], b0.val[1]), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(b1.val[1], b1.val[1]), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(b2.val[1], b2.val[1]), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(b3.val[1], b3.val[1]), res.val[3]); + + cur_d -= 8; + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + } + if (cur_d >= 4) { + float32x4_t a = vcvt_f32_f16(vld1_f16((const __fp16*)x)); + float32x4_t b0 = vcvt_f32_f16(vld1_f16((const __fp16*)y0)); + float32x4_t b1 = vcvt_f32_f16(vld1_f16((const __fp16*)y1)); + float32x4_t b2 = vcvt_f32_f16(vld1_f16((const __fp16*)y2)); + float32x4_t b3 = vcvt_f32_f16(vld1_f16((const __fp16*)y3)); + + b0 = vsubq_f32(a, b0); + b1 = vsubq_f32(a, b1); + b2 = vsubq_f32(a, b2); + b3 = vsubq_f32(a, b3); + + res.val[0] = vaddq_f32(vmulq_f32(b0, b0), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(b1, b1), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(b2, b2), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(b3, b3), res.val[3]); + cur_d -= 4; + x += 4; + y0 += 4; + y1 += 4; + y2 += 4; + y3 += 4; + } + if (cur_d >= 0) { + float16x4_t res_x = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y0 = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y1 = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y2 = {0.0f, 0.0f, 0.0f, 0.0f}; + float16x4_t res_y3 = {0.0f, 0.0f, 0.0f, 0.0f}; + switch (cur_d) { + case 3: + res_x = vld1_lane_f16((const __fp16*)x, res_x, 2); + res_y0 = vld1_lane_f16((const __fp16*)y0, res_y0, 2); + res_y1 = vld1_lane_f16((const __fp16*)y1, res_y1, 2); + res_y2 = vld1_lane_f16((const __fp16*)y2, res_y2, 2); + res_y3 = vld1_lane_f16((const __fp16*)y3, res_y3, 2); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 2: + res_x = vld1_lane_f16((const __fp16*)x, res_x, 1); + res_y0 = vld1_lane_f16((const __fp16*)y0, res_y0, 1); + res_y1 = vld1_lane_f16((const __fp16*)y1, res_y1, 1); + res_y2 = vld1_lane_f16((const __fp16*)y2, res_y2, 1); + res_y3 = vld1_lane_f16((const __fp16*)y3, res_y3, 1); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 1: + res_x = vld1_lane_f16((const __fp16*)x, res_x, 0); + res_y0 = vld1_lane_f16((const __fp16*)y0, res_y0, 0); + res_y1 = vld1_lane_f16((const __fp16*)y1, res_y1, 0); + res_y2 = vld1_lane_f16((const __fp16*)y2, res_y2, 0); + res_y3 = vld1_lane_f16((const __fp16*)y3, res_y3, 0); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + } + float32x4_t diff0 = vsubq_f32(vcvt_f32_f16(res_x), vcvt_f32_f16(res_y0)); + float32x4_t diff1 = vsubq_f32(vcvt_f32_f16(res_x), vcvt_f32_f16(res_y1)); + float32x4_t diff2 = vsubq_f32(vcvt_f32_f16(res_x), vcvt_f32_f16(res_y2)); + float32x4_t diff3 = vsubq_f32(vcvt_f32_f16(res_x), vcvt_f32_f16(res_y3)); + res.val[0] = vaddq_f32(vmulq_f32(diff0, diff0), res.val[0]); + res.val[1] = vaddq_f32(vmulq_f32(diff1, diff1), res.val[1]); + res.val[2] = vaddq_f32(vmulq_f32(diff2, diff2), res.val[2]); + res.val[3] = vaddq_f32(vmulq_f32(diff3, diff3), res.val[3]); + } + dis0 = vaddvq_f32(res.val[0]); + dis1 = vaddvq_f32(res.val[1]); + dis2 = vaddvq_f32(res.val[2]); + dis3 = vaddvq_f32(res.val[3]); + return; +} + +void +bf16_vec_L2sqr_batch_4_neon(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + float32x4x4_t res = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + auto cur_d = d; + while (cur_d >= 16) { + float32x4x4_t a = vcvt4_f32_half(vld4_u16((const uint16_t*)x)); + float32x4x4_t b0 = vcvt4_f32_half(vld4_u16((const uint16_t*)y0)); + float32x4x4_t b1 = vcvt4_f32_half(vld4_u16((const uint16_t*)y1)); + float32x4x4_t b2 = vcvt4_f32_half(vld4_u16((const uint16_t*)y2)); + float32x4x4_t b3 = vcvt4_f32_half(vld4_u16((const uint16_t*)y3)); + + b0.val[0] = vsubq_f32(a.val[0], b0.val[0]); + b0.val[1] = vsubq_f32(a.val[1], b0.val[1]); + b0.val[2] = vsubq_f32(a.val[2], b0.val[2]); + b0.val[3] = vsubq_f32(a.val[3], b0.val[3]); + + b1.val[0] = vsubq_f32(a.val[0], b1.val[0]); + b1.val[1] = vsubq_f32(a.val[1], b1.val[1]); + b1.val[2] = vsubq_f32(a.val[2], b1.val[2]); + b1.val[3] = vsubq_f32(a.val[3], b1.val[3]); + + b2.val[0] = vsubq_f32(a.val[0], b2.val[0]); + b2.val[1] = vsubq_f32(a.val[1], b2.val[1]); + b2.val[2] = vsubq_f32(a.val[2], b2.val[2]); + b2.val[3] = vsubq_f32(a.val[3], b2.val[3]); + + b3.val[0] = vsubq_f32(a.val[0], b3.val[0]); + b3.val[1] = vsubq_f32(a.val[1], b3.val[1]); + b3.val[2] = vsubq_f32(a.val[2], b3.val[2]); + b3.val[3] = vsubq_f32(a.val[3], b3.val[3]); + + res.val[0] = vmlaq_f32(res.val[0], b0.val[0], b0.val[0]); + res.val[1] = vmlaq_f32(res.val[1], b1.val[0], b1.val[0]); + res.val[2] = vmlaq_f32(res.val[2], b2.val[0], b2.val[0]); + res.val[3] = vmlaq_f32(res.val[3], b3.val[0], b3.val[0]); + + res.val[0] = vmlaq_f32(res.val[0], b0.val[1], b0.val[1]); + res.val[1] = vmlaq_f32(res.val[1], b1.val[1], b1.val[1]); + res.val[2] = vmlaq_f32(res.val[2], b2.val[1], b2.val[1]); + res.val[3] = vmlaq_f32(res.val[3], b3.val[1], b3.val[1]); + + res.val[0] = vmlaq_f32(res.val[0], b0.val[2], b0.val[2]); + res.val[1] = vmlaq_f32(res.val[1], b1.val[2], b1.val[2]); + res.val[2] = vmlaq_f32(res.val[2], b2.val[2], b2.val[2]); + res.val[3] = vmlaq_f32(res.val[3], b3.val[2], b3.val[2]); + + res.val[0] = vmlaq_f32(res.val[0], b0.val[3], b0.val[3]); + res.val[1] = vmlaq_f32(res.val[1], b1.val[3], b1.val[3]); + res.val[2] = vmlaq_f32(res.val[2], b2.val[3], b2.val[3]); + res.val[3] = vmlaq_f32(res.val[3], b3.val[3], b3.val[3]); + + cur_d -= 16; + x += 16; + y0 += 16; + y1 += 16; + y2 += 16; + y3 += 16; + } + if (cur_d >= 8) { + float32x4x2_t a = vcvt2_f32_half(vld2_u16((const uint16_t*)x)); + float32x4x2_t b0 = vcvt2_f32_half(vld2_u16((const uint16_t*)y0)); + float32x4x2_t b1 = vcvt2_f32_half(vld2_u16((const uint16_t*)y1)); + float32x4x2_t b2 = vcvt2_f32_half(vld2_u16((const uint16_t*)y2)); + float32x4x2_t b3 = vcvt2_f32_half(vld2_u16((const uint16_t*)y3)); + + b0.val[0] = vsubq_f32(a.val[0], b0.val[0]); + b0.val[1] = vsubq_f32(a.val[1], b0.val[1]); + + b1.val[0] = vsubq_f32(a.val[0], b1.val[0]); + b1.val[1] = vsubq_f32(a.val[1], b1.val[1]); + + b2.val[0] = vsubq_f32(a.val[0], b2.val[0]); + b2.val[1] = vsubq_f32(a.val[1], b2.val[1]); + + b3.val[0] = vsubq_f32(a.val[0], b3.val[0]); + b3.val[1] = vsubq_f32(a.val[1], b3.val[1]); + + res.val[0] = vmlaq_f32(res.val[0], b0.val[0], b0.val[0]); + res.val[1] = vmlaq_f32(res.val[1], b1.val[0], b1.val[0]); + res.val[2] = vmlaq_f32(res.val[2], b2.val[0], b2.val[0]); + res.val[3] = vmlaq_f32(res.val[3], b3.val[0], b3.val[0]); + + res.val[0] = vmlaq_f32(res.val[0], b0.val[1], b0.val[1]); + res.val[1] = vmlaq_f32(res.val[1], b1.val[1], b1.val[1]); + res.val[2] = vmlaq_f32(res.val[2], b2.val[1], b2.val[1]); + res.val[3] = vmlaq_f32(res.val[3], b3.val[1], b3.val[1]); + + cur_d -= 8; + x += 8; + y0 += 8; + y1 += 8; + y2 += 8; + y3 += 8; + } + if (cur_d >= 4) { + float32x4_t a = vcvt_f32_half(vld1_u16((const uint16_t*)x)); + float32x4_t b0 = vcvt_f32_half(vld1_u16((const uint16_t*)y0)); + float32x4_t b1 = vcvt_f32_half(vld1_u16((const uint16_t*)y1)); + float32x4_t b2 = vcvt_f32_half(vld1_u16((const uint16_t*)y2)); + float32x4_t b3 = vcvt_f32_half(vld1_u16((const uint16_t*)y3)); + b0 = vsubq_f32(a, b0); + b1 = vsubq_f32(a, b1); + b2 = vsubq_f32(a, b2); + b3 = vsubq_f32(a, b3); + + res.val[0] = vmlaq_f32(res.val[0], b0, b0); + res.val[1] = vmlaq_f32(res.val[1], b1, b1); + res.val[2] = vmlaq_f32(res.val[2], b2, b2); + res.val[3] = vmlaq_f32(res.val[3], b3, b3); + cur_d -= 4; + x += 4; + y0 += 4; + y1 += 4; + y2 += 4; + y3 += 4; + } + if (cur_d >= 0) { + uint16x4_t res_x = vdup_n_u16(0); + uint16x4_t res_y0 = vdup_n_u16(0); + uint16x4_t res_y1 = vdup_n_u16(0); + uint16x4_t res_y2 = vdup_n_u16(0); + uint16x4_t res_y3 = vdup_n_u16(0); + switch (cur_d) { + case 3: + res_x = vld1_lane_u16((const uint16_t*)x, res_x, 2); + res_y0 = vld1_lane_u16((const uint16_t*)y0, res_y0, 2); + res_y1 = vld1_lane_u16((const uint16_t*)y1, res_y1, 2); + res_y2 = vld1_lane_u16((const uint16_t*)y2, res_y2, 2); + res_y3 = vld1_lane_u16((const uint16_t*)y3, res_y3, 2); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 2: + res_x = vld1_lane_u16((const uint16_t*)x, res_x, 1); + res_y0 = vld1_lane_u16((const uint16_t*)y0, res_y0, 1); + res_y1 = vld1_lane_u16((const uint16_t*)y1, res_y1, 1); + res_y2 = vld1_lane_u16((const uint16_t*)y2, res_y2, 1); + res_y3 = vld1_lane_u16((const uint16_t*)y3, res_y3, 1); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + case 1: + res_x = vld1_lane_u16((const uint16_t*)x, res_x, 0); + res_y0 = vld1_lane_u16((const uint16_t*)y0, res_y0, 0); + res_y1 = vld1_lane_u16((const uint16_t*)y1, res_y1, 0); + res_y2 = vld1_lane_u16((const uint16_t*)y2, res_y2, 0); + res_y3 = vld1_lane_u16((const uint16_t*)y3, res_y3, 0); + x++; + y0++; + y1++; + y2++; + y3++; + cur_d--; + } + float32x4_t diff0 = vsubq_f32(vcvt_f32_half(res_x), vcvt_f32_half(res_y0)); + float32x4_t diff1 = vsubq_f32(vcvt_f32_half(res_x), vcvt_f32_half(res_y1)); + float32x4_t diff2 = vsubq_f32(vcvt_f32_half(res_x), vcvt_f32_half(res_y2)); + float32x4_t diff3 = vsubq_f32(vcvt_f32_half(res_x), vcvt_f32_half(res_y3)); + res.val[0] = vmlaq_f32(res.val[0], diff0, diff0); + res.val[1] = vmlaq_f32(res.val[1], diff1, diff1); + res.val[2] = vmlaq_f32(res.val[2], diff2, diff2); + res.val[3] = vmlaq_f32(res.val[3], diff3, diff3); + } + dis0 = vaddvq_f32(res.val[0]); + dis1 = vaddvq_f32(res.val[1]); + dis2 = vaddvq_f32(res.val[2]); + dis3 = vaddvq_f32(res.val[3]); + return; +} } // namespace faiss #endif diff --git a/src/simd/distances_neon.h b/src/simd/distances_neon.h index bb4fd5423..fa6c48b9e 100644 --- a/src/simd/distances_neon.h +++ b/src/simd/distances_neon.h @@ -92,6 +92,16 @@ fvec_inner_product_batch_4_neon_bf16_patch(const float* x, const float* y0, cons const float* y3, const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_inner_product_batch_4_neon(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + +void +bf16_vec_inner_product_batch_4_neon(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. void @@ -102,6 +112,16 @@ void fvec_L2sqr_batch_4_neon_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t dim, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_L2sqr_batch_4_neon(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + +void +bf16_vec_L2sqr_batch_4_neon(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + } // namespace faiss #endif /* DISTANCES_NEON_H */ diff --git a/src/simd/distances_ref.cc b/src/simd/distances_ref.cc index 69734845b..5b2fd47a1 100644 --- a/src/simd/distances_ref.cc +++ b/src/simd/distances_ref.cc @@ -343,6 +343,102 @@ fvec_L2sqr_batch_4_ref_bf16_patch(const float* x, const float* y0, const float* dis3 = d3; } +void +fp16_vec_inner_product_batch_4_ref(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + for (size_t i = 0; i < d; ++i) { + auto x_i = (float)x[i]; + d0 += x_i * (float)y0[i]; + d1 += x_i * (float)y1[i]; + d2 += x_i * (float)y2[i]; + d3 += x_i * (float)y3[i]; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} + +void +bf16_vec_inner_product_batch_4_ref(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + for (size_t i = 0; i < d; ++i) { + auto x_i = (float)x[i]; + d0 += x_i * (float)y0[i]; + d1 += x_i * (float)y1[i]; + d2 += x_i * (float)y2[i]; + d3 += x_i * (float)y3[i]; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} + +void +fp16_vec_L2sqr_batch_4_ref(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + for (size_t i = 0; i < d; ++i) { + auto x_i = (float)x[i]; + const float q0 = x_i - (float)y0[i]; + const float q1 = x_i - (float)y1[i]; + const float q2 = x_i - (float)y2[i]; + const float q3 = x_i - (float)y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} + +void +bf16_vec_L2sqr_batch_4_ref(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + for (size_t i = 0; i < d; ++i) { + auto x_i = (float)x[i]; + const float q0 = x_i - (float)y0[i]; + const float q1 = x_i - (float)y1[i]; + const float q2 = x_i - (float)y2[i]; + const float q3 = x_i - (float)y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} + int32_t ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) { size_t i; diff --git a/src/simd/distances_ref.h b/src/simd/distances_ref.h index d44320391..6281ffe92 100644 --- a/src/simd/distances_ref.h +++ b/src/simd/distances_ref.h @@ -97,6 +97,16 @@ fvec_inner_product_batch_4_ref_bf16_patch(const float* x, const float* y0, const const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_inner_product_batch_4_ref(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + +void +bf16_vec_inner_product_batch_4_ref(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, + float& dis1, float& dis2, float& dis3); + /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. void @@ -107,6 +117,16 @@ void fvec_L2sqr_batch_4_ref_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +void +fp16_vec_L2sqr_batch_4_ref(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1, + const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3); + +void +bf16_vec_L2sqr_batch_4_ref(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1, + const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1, + float& dis2, float& dis3); + int32_t ivec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d); diff --git a/src/simd/hook.cc b/src/simd/hook.cc index d7cbd497c..a44351577 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -75,6 +75,11 @@ decltype(bf16_vec_inner_product) bf16_vec_inner_product = bf16_vec_inner_product decltype(fp16_vec_norm_L2sqr) fp16_vec_norm_L2sqr = fp16_vec_norm_L2sqr_ref; decltype(bf16_vec_norm_L2sqr) bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_ref; +decltype(fp16_vec_inner_product_batch_4) fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_ref; +decltype(bf16_vec_inner_product_batch_4) bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_ref; +decltype(fp16_vec_L2sqr_batch_4) fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_ref; +decltype(bf16_vec_L2sqr_batch_4) bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_ref; + #if defined(__x86_64__) bool cpu_support_avx512() { @@ -205,6 +210,11 @@ fvec_hook(std::string& simd_type) { bf16_vec_L2sqr = bf16_vec_L2sqr_avx512; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_avx512; + fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_avx512; + bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_avx512; + fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_avx512; + bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_avx512; + simd_type = "AVX512"; support_pq_fast_scan = true; } else if (use_avx2 && cpu_support_avx2()) { @@ -233,6 +243,11 @@ fvec_hook(std::string& simd_type) { bf16_vec_L2sqr = bf16_vec_L2sqr_avx; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_avx; + fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_avx; + bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_avx; + fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_avx; + bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_avx; + simd_type = "AVX2"; support_pq_fast_scan = true; } else if (use_sse4_2 && cpu_support_sse4_2()) { @@ -261,6 +276,11 @@ fvec_hook(std::string& simd_type) { bf16_vec_L2sqr = bf16_vec_L2sqr_sse; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_sse; + fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_ref; + bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_ref; + fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_ref; + bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_ref; + simd_type = "SSE4_2"; support_pq_fast_scan = false; } else { @@ -289,6 +309,11 @@ fvec_hook(std::string& simd_type) { bf16_vec_L2sqr = bf16_vec_L2sqr_ref; bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_ref; + fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_ref; + bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_ref; + fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_ref; + bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_ref; + simd_type = "GENERIC"; support_pq_fast_scan = false; } @@ -320,6 +345,11 @@ fvec_hook(std::string& simd_type) { fvec_inner_product_batch_4 = fvec_inner_product_batch_4_neon; fvec_L2sqr_batch_4 = fvec_L2sqr_batch_4_neon; + fp16_vec_inner_product_batch_4 = fp16_vec_inner_product_batch_4_neon; + bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_neon; + fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_neon; + bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_neon; + simd_type = "NEON"; support_pq_fast_scan = true; diff --git a/src/simd/hook.h b/src/simd/hook.h index 47744f6d9..8f984c97b 100644 --- a/src/simd/hook.h +++ b/src/simd/hook.h @@ -78,12 +78,28 @@ extern int (*fvec_madd_and_argmin)(size_t, const float*, float, const float*, fl extern void (*fvec_inner_product_batch_4)(const float*, const float*, const float*, const float*, const float*, const size_t, float&, float&, float&, float&); +extern void (*fp16_vec_inner_product_batch_4)(const knowhere::fp16*, const knowhere::fp16*, const knowhere::fp16*, + const knowhere::fp16*, const knowhere::fp16*, const size_t, float&, + float&, float&, float&); + +extern void (*bf16_vec_inner_product_batch_4)(const knowhere::bf16*, const knowhere::bf16*, const knowhere::bf16*, + const knowhere::bf16*, const knowhere::bf16*, const size_t, float&, + float&, float&, float&); + /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. /// todo aguzhva: bring non-ref versions extern void (*fvec_L2sqr_batch_4)(const float*, const float*, const float*, const float*, const float*, const size_t, float&, float&, float&, float&); +extern void (*fp16_vec_L2sqr_batch_4)(const knowhere::fp16*, const knowhere::fp16*, const knowhere::fp16*, + const knowhere::fp16*, const knowhere::fp16*, const size_t, float&, float&, + float&, float&); + +extern void (*bf16_vec_L2sqr_batch_4)(const knowhere::bf16*, const knowhere::bf16*, const knowhere::bf16*, + const knowhere::bf16*, const knowhere::bf16*, const size_t, float&, float&, + float&, float&); + extern int32_t (*ivec_inner_product)(const int8_t*, const int8_t*, size_t); extern int32_t (*ivec_L2sqr)(const int8_t*, const int8_t*, size_t); diff --git a/tests/ut/test_bruteforce.cc b/tests/ut/test_bruteforce.cc index a5c129ab1..c3206f500 100644 --- a/tests/ut/test_bruteforce.cc +++ b/tests/ut/test_bruteforce.cc @@ -16,6 +16,7 @@ #include "knowhere/comp/brute_force.h" #include "knowhere/comp/index_param.h" #include "knowhere/utils.h" +#include "simd/hook.h" #include "utils.h" template @@ -89,6 +90,61 @@ check_range_search(const knowhere::DataSetPtr train_ds, const knowhere::DataSetP } } +template +void +check_search_with_out_ids(const uint64_t nb, const uint64_t nq, const uint64_t dim, const int64_t k, + const knowhere::MetricType metric, const knowhere::Json& conf) { + auto total_train_ds = knowhere::ConvertToDataTypeIfNeeded(GenDataSet(nb, dim)); + auto query_ds = knowhere::ConvertToDataTypeIfNeeded(GenDataSet(nq, dim)); + std::vector block_prefix = {0, 111, 333, 500, 555, 666, 888, 1000}; + + // generate filter id and data + auto filter_bits = GenerateBitsetWithRandomTbitsSet(nb, 100); + knowhere::BitsetView bitset(filter_bits.data(), nb); + + std::vector dis(nq * k, std::numeric_limits::quiet_NaN()); + std::vector ids(nq * k, -1); + if (metric == knowhere::metric::L2) { + faiss::float_maxheap_array_t heaps{nq, k, ids.data(), dis.data()}; + heaps.heapify(); + for (auto i = 0; i < block_prefix.size() - 1; i++) { + auto begin_id = block_prefix[i]; + auto end_id = block_prefix[i + 1]; + auto blk_rows = end_id - begin_id; + auto tensor = (const T*)total_train_ds->GetTensor() + dim * begin_id; + auto blk_train_ds = knowhere::GenDataSet(blk_rows, dim, tensor, begin_id); + auto partial_v = knowhere::BruteForce::Search(blk_train_ds, query_ds, conf, bitset); + REQUIRE(partial_v.has_value()); + auto partial_res = partial_v.value(); + heaps.addn_with_ids(k, partial_res->GetDistance(), partial_res->GetIds(), k, 0, nq); + } + heaps.reorder(); + } else { + faiss::float_minheap_array_t heaps{nq, k, ids.data(), dis.data()}; + heaps.heapify(); + for (auto i = 0; i < block_prefix.size() - 1; i++) { + auto begin_id = block_prefix[i]; + auto end_id = block_prefix[i + 1]; + auto blk_rows = end_id - begin_id; + auto tensor = (const T*)total_train_ds->GetTensor() + dim * begin_id; + auto blk_train_ds = knowhere::GenDataSet(blk_rows, dim, tensor, begin_id); + auto partial_v = knowhere::BruteForce::Search(blk_train_ds, query_ds, conf, bitset); + REQUIRE(partial_v.has_value()); + auto partial_res = partial_v.value(); + heaps.addn_with_ids(k, partial_res->GetDistance(), partial_res->GetIds(), k, 0, nq); + } + heaps.reorder(); + } + + auto gt = knowhere::BruteForce::Search(total_train_ds, query_ds, conf, bitset); + auto gt_ids = gt.value()->GetIds(); + const float* gt_dis = gt.value()->GetDistance(); + for (auto i = 0; i < nq * k; i++) { + REQUIRE(gt_ids[i] == ids[i]); + REQUIRE(GetRelativeLoss(gt_dis[i], dis[i]) < 0.00001); + } +} + TEST_CASE("Test Brute Force", "[float vector]") { using Catch::Approx; @@ -204,42 +260,13 @@ TEST_CASE("Test Brute Force with input ids", "[float vector]") { const int64_t nq = 10; const int64_t dim = 128; const int64_t k = 10; + auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::IP, knowhere::metric::COSINE); const knowhere::Json conf = { {knowhere::meta::DIM, dim}, - {knowhere::meta::METRIC_TYPE, "L2"}, + {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, k}, }; - std::vector block_prefix = {0, 111, 333, 500, 555, 666, 888, 1000}; - - // generate filter id and data - auto filter_bits = GenerateBitsetWithRandomTbitsSet(nb, 100); - knowhere::BitsetView bitset(filter_bits.data(), nb); - - const auto total_train_ds = GenDataSet(nb, dim); - const auto query_ds = GenDataSet(nq, dim); - - std::vector dis(nq * k, std::numeric_limits::quiet_NaN()); - std::vector ids(nq * k, -1); - faiss::float_maxheap_array_t heaps{nq, k, ids.data(), dis.data()}; - heaps.heapify(); - for (auto i = 0; i < block_prefix.size() - 1; i++) { - auto begin_id = block_prefix[i]; - auto end_id = block_prefix[i + 1]; - auto blk_rows = end_id - begin_id; - auto tensor = (const float*)total_train_ds->GetTensor() + dim * begin_id; - auto blk_train_ds = knowhere::GenDataSet(blk_rows, dim, tensor, begin_id); - auto partial_v = knowhere::BruteForce::Search(blk_train_ds, query_ds, conf, bitset); - REQUIRE(partial_v.has_value()); - auto partial_res = partial_v.value(); - heaps.addn_with_ids(k, partial_res->GetDistance(), partial_res->GetIds(), k, 0, nq); - } - heaps.reorder(); - - auto gt = knowhere::BruteForce::Search(total_train_ds, query_ds, conf, bitset); - auto gt_ids = gt.value()->GetIds(); - auto gt_dis = gt.value()->GetDistance(); - for (auto i = 0; i < nq * k; i++) { - REQUIRE(gt_ids[i] == ids[i]); - REQUIRE(gt_dis[i] == dis[i]); - } + check_search_with_out_ids(nb, nq, dim, k, metric, conf); + check_search_with_out_ids(nb, nq, dim, k, metric, conf); + check_search_with_out_ids(nb, nq, dim, k, metric, conf); } diff --git a/tests/ut/test_simd.cc b/tests/ut/test_simd.cc index a6bcd809c..e024c6984 100644 --- a/tests/ut/test_simd.cc +++ b/tests/ut/test_simd.cc @@ -197,16 +197,16 @@ TEST_CASE("Test distance") { faiss::fvec_inner_product_batch_4(x.get(), y0.get(), y1.get(), y2.get(), y3.get(), dim, dis0, dis1, dis2, dis3); - REQUIRE_THAT(dis0, Catch::Matchers::WithinRel(ref_ip_0, batch_tolerance)); - REQUIRE_THAT(dis1, Catch::Matchers::WithinRel(ref_ip_1, batch_tolerance)); - REQUIRE_THAT(dis2, Catch::Matchers::WithinRel(ref_ip_2, batch_tolerance)); - REQUIRE_THAT(dis3, Catch::Matchers::WithinRel(ref_ip_3, batch_tolerance)); + REQUIRE(GetRelativeLoss(dis0, ref_ip_0) < batch_tolerance); + REQUIRE(GetRelativeLoss(dis1, ref_ip_1) < batch_tolerance); + REQUIRE(GetRelativeLoss(dis2, ref_ip_2) < batch_tolerance); + REQUIRE(GetRelativeLoss(dis3, ref_ip_3) < batch_tolerance); faiss::fvec_L2sqr_batch_4(x.get(), y0.get(), y1.get(), y2.get(), y3.get(), dim, dis0, dis1, dis2, dis3); - REQUIRE_THAT(dis0, Catch::Matchers::WithinRel(ref_l2_0, batch_tolerance)); - REQUIRE_THAT(dis1, Catch::Matchers::WithinRel(ref_l2_1, batch_tolerance)); - REQUIRE_THAT(dis2, Catch::Matchers::WithinRel(ref_l2_2, batch_tolerance)); - REQUIRE_THAT(dis3, Catch::Matchers::WithinRel(ref_l2_3, batch_tolerance)); + REQUIRE(GetRelativeLoss(dis0, ref_l2_0) < batch_tolerance); + REQUIRE(GetRelativeLoss(dis1, ref_l2_1) < batch_tolerance); + REQUIRE(GetRelativeLoss(dis2, ref_l2_2) < batch_tolerance); + REQUIRE(GetRelativeLoss(dis3, ref_l2_3) < batch_tolerance); }; tolerance = 0.02f; @@ -220,4 +220,55 @@ TEST_CASE("Test distance") { knowhere::KnowhereConfig::DisablePatchForComputeFP32AsBF16(); run_test(); } + + SECTION("test batch_4 distance calculation of bf16/fp16") { + auto fp16_ds = GenRandomVector(dim, 5, 100); + auto bf16_ds = GenRandomVector(dim, 5, 100); + std::vector fp16_data{fp16_ds.get(), fp16_ds.get() + dim, fp16_ds.get() + 2 * dim, + fp16_ds.get() + 3 * dim, fp16_ds.get() + 4 * dim}; + std::vector bf16_data{bf16_ds.get(), bf16_ds.get() + dim, bf16_ds.get() + 2 * dim, + bf16_ds.get() + 3 * dim, bf16_ds.get() + 4 * dim}; + std::array fp16_l2_dis_gt, fp16_ip_dis_gt, bf16_l2_dis_gt, bf16_ip_dis_gt; + + faiss::fp16_vec_inner_product_batch_4_ref(fp16_data[0], fp16_data[1], fp16_data[2], fp16_data[3], fp16_data[4], + dim, fp16_ip_dis_gt[0], fp16_ip_dis_gt[1], fp16_ip_dis_gt[2], + fp16_ip_dis_gt[3]); + faiss::fp16_vec_L2sqr_batch_4_ref(fp16_data[0], fp16_data[1], fp16_data[2], fp16_data[3], fp16_data[4], dim, + fp16_l2_dis_gt[0], fp16_l2_dis_gt[1], fp16_l2_dis_gt[2], fp16_l2_dis_gt[3]); + faiss::bf16_vec_inner_product_batch_4_ref(bf16_data[0], bf16_data[1], bf16_data[2], bf16_data[3], bf16_data[4], + dim, bf16_ip_dis_gt[0], bf16_ip_dis_gt[1], bf16_ip_dis_gt[2], + bf16_ip_dis_gt[3]); + faiss::bf16_vec_L2sqr_batch_4_ref(bf16_data[0], bf16_data[1], bf16_data[2], bf16_data[3], bf16_data[4], dim, + bf16_l2_dis_gt[0], bf16_l2_dis_gt[1], bf16_l2_dis_gt[2], bf16_l2_dis_gt[3]); + + std::array fp16_l2_dis, fp16_ip_dis, bf16_l2_dis, bf16_ip_dis; + faiss::fp16_vec_inner_product_batch_4(fp16_data[0], fp16_data[1], fp16_data[2], fp16_data[3], fp16_data[4], dim, + fp16_ip_dis[0], fp16_ip_dis[1], fp16_ip_dis[2], fp16_ip_dis[3]); + faiss::fp16_vec_L2sqr_batch_4(fp16_data[0], fp16_data[1], fp16_data[2], fp16_data[3], fp16_data[4], dim, + fp16_l2_dis[0], fp16_l2_dis[1], fp16_l2_dis[2], fp16_l2_dis[3]); + faiss::bf16_vec_inner_product_batch_4(bf16_data[0], bf16_data[1], bf16_data[2], bf16_data[3], bf16_data[4], dim, + bf16_ip_dis[0], bf16_ip_dis[1], bf16_ip_dis[2], bf16_ip_dis[3]); + faiss::bf16_vec_L2sqr_batch_4(bf16_data[0], bf16_data[1], bf16_data[2], bf16_data[3], bf16_data[4], dim, + bf16_l2_dis[0], bf16_l2_dis[1], bf16_l2_dis[2], bf16_l2_dis[3]); + const float tolerance = 0.00001f; + REQUIRE_THAT(fp16_ip_dis[0], Catch::Matchers::WithinRel(fp16_ip_dis_gt[0], tolerance)); + REQUIRE_THAT(fp16_ip_dis[1], Catch::Matchers::WithinRel(fp16_ip_dis_gt[1], tolerance)); + REQUIRE_THAT(fp16_ip_dis[2], Catch::Matchers::WithinRel(fp16_ip_dis_gt[2], tolerance)); + REQUIRE_THAT(fp16_ip_dis[3], Catch::Matchers::WithinRel(fp16_ip_dis_gt[3], tolerance)); + + REQUIRE_THAT(fp16_l2_dis[0], Catch::Matchers::WithinRel(fp16_l2_dis_gt[0], tolerance)); + REQUIRE_THAT(fp16_l2_dis[1], Catch::Matchers::WithinRel(fp16_l2_dis_gt[1], tolerance)); + REQUIRE_THAT(fp16_l2_dis[2], Catch::Matchers::WithinRel(fp16_l2_dis_gt[2], tolerance)); + REQUIRE_THAT(fp16_l2_dis[3], Catch::Matchers::WithinRel(fp16_l2_dis_gt[3], tolerance)); + + REQUIRE_THAT(bf16_ip_dis[0], Catch::Matchers::WithinRel(bf16_ip_dis_gt[0], tolerance)); + REQUIRE_THAT(bf16_ip_dis[1], Catch::Matchers::WithinRel(bf16_ip_dis_gt[1], tolerance)); + REQUIRE_THAT(bf16_ip_dis[2], Catch::Matchers::WithinRel(bf16_ip_dis_gt[2], tolerance)); + REQUIRE_THAT(bf16_ip_dis[3], Catch::Matchers::WithinRel(bf16_ip_dis_gt[3], tolerance)); + + REQUIRE_THAT(bf16_l2_dis[0], Catch::Matchers::WithinRel(bf16_l2_dis_gt[0], tolerance)); + REQUIRE_THAT(bf16_l2_dis[1], Catch::Matchers::WithinRel(bf16_l2_dis_gt[1], tolerance)); + REQUIRE_THAT(bf16_l2_dis[2], Catch::Matchers::WithinRel(bf16_l2_dis_gt[2], tolerance)); + REQUIRE_THAT(bf16_l2_dis[3], Catch::Matchers::WithinRel(bf16_l2_dis_gt[3], tolerance)); + } } diff --git a/tests/ut/utils.h b/tests/ut/utils.h index b02cd8bdb..3bbc3d794 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -239,6 +239,14 @@ GetRangeSearchRecall(const knowhere::DataSet& gt, const knowhere::DataSet& resul return (1 + precision) * recall / 2; } +inline float +GetRelativeLoss(float gt_res, float res) { + if (gt_res == 0.0 || std::abs(gt_res) < 0.000001) { + return gt_res; + } + return std::abs((gt_res - res) / gt_res); +} + inline bool CheckDistanceInScope(const knowhere::DataSet& result, int topk, float low_bound, float high_bound) { auto ids = result.GetDistance(); diff --git a/thirdparty/faiss/faiss/utils/distances_if.h b/thirdparty/faiss/faiss/utils/distances_if.h index f42fa5b98..9f43d3335 100644 --- a/thirdparty/faiss/faiss/utils/distances_if.h +++ b/thirdparty/faiss/faiss/utils/distances_if.h @@ -147,6 +147,10 @@ struct ByIdxRemapping { } // namespace +/*************************************************************************** + * float32 search functions + ***************************************************************************/ + template< // A predicate for filtering elements. // std::optional Pred(const size_t idx); @@ -241,8 +245,8 @@ void internal_fvec_inner_products_ny_if( using idx_type = std::invoke_result_t; // compute a distance from the query to 1 element - auto distance1 = [x, y, d](const idx_type idx) { - return fvec_inner_product(x, y + idx * d, d); + auto distance1 = [x, y, d](const idx_type idx) { + return fvec_inner_product(x, y + idx * d, d); }; // compute distances from the query to 4 elements @@ -569,5 +573,428 @@ void distance_compute_by_idx_if( internal_distance_compute_if(ny, dc, pred, remapper, apply); } +/*************************************************************************** + * float16 & bfloat16 search functions + ***************************************************************************/ +namespace { +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Maps an iteration for-loop index to a database index. + // It is needed for calls with indirect indexing like + // fvec_L2sqr_by_idx(). + // auto IndexRemapper(const size_t idx); + typename IndexRemapper, + // Apply an element. + // void Apply(const float dis, const auto idx); + typename Apply> +void internal_fp16_vec_inner_products_ny_if( + const knowhere::fp16* __restrict x, + const knowhere::fp16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + IndexRemapper remapper, + Apply apply) { + using idx_type = std::invoke_result_t; + // compute a distance from the query to 1 element + auto distance1 = [x, y, d](const idx_type idx) { + return fp16_vec_inner_product(x, y + idx * d, d); + }; + // compute distances from the query to 4 elements + auto distance4 = [x, y, d]( + const std::array indices, + std::array& dis) { + fp16_vec_inner_product_batch_4( + x, + y + indices[0] * d, + y + indices[1] * d, + y + indices[2] * d, + y + indices[3] * d, + d, + dis[0], + dis[1], + dis[2], + dis[3]); + }; + fvec_distance_ny_if< + Pred, + decltype(distance1), + decltype(distance4), + IndexRemapper, + Apply, + 4, + DEFAULT_BUFFER_SIZE>( + ny, pred, distance1, distance4, remapper, apply); +} + +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Maps an iteration for-loop index to a database index. + // It is needed for calls with indirect indexing like + // fvec_L2sqr_by_idx(). + // auto IndexRemapper(const size_t idx); + typename IndexRemapper, + // Apply an element. + // void Apply(const float dis, const auto idx); + typename Apply> +void internal_fp16_vec_L2sqr_ny_if( + const knowhere::fp16* __restrict x, + const knowhere::fp16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + IndexRemapper remapper, + Apply apply) { + using idx_type = std::invoke_result_t; + // compute a distance from the query to 1 element + auto distance1 = [x, y, d](const idx_type idx) { + return fp16_vec_L2sqr(x, y + idx * d, d); + }; + // compute distances from the query to 4 elements + auto distance4 = [x, y, d]( + const std::array indices, + std::array& dis) { + fp16_vec_L2sqr_batch_4( + x, + y + indices[0] * d, + y + indices[1] * d, + y + indices[2] * d, + y + indices[3] * d, + d, + dis[0], + dis[1], + dis[2], + dis[3]); + }; + fvec_distance_ny_if< + Pred, + decltype(distance1), + decltype(distance4), + IndexRemapper, + Apply, + 4, + DEFAULT_BUFFER_SIZE>( + ny, pred, distance1, distance4, remapper, apply); +} +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Maps an iteration for-loop index to a database index. + // It is needed for calls with indirect indexing like + // fvec_L2sqr_by_idx(). + // auto IndexRemapper(const size_t idx); + typename IndexRemapper, + // Apply an element. + // void Apply(const float dis, const auto idx); + typename Apply> +void internal_bf16_vec_inner_products_ny_if( + const knowhere::bf16* __restrict x, + const knowhere::bf16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + IndexRemapper remapper, + Apply apply) { + using idx_type = std::invoke_result_t; + // compute a distance from the query to 1 element + auto distance1 = [x, y, d](const idx_type idx) { + return bf16_vec_inner_product(x, y + idx * d, d); + }; + // compute distances from the query to 4 elements + auto distance4 = [x, y, d]( + const std::array indices, + std::array& dis) { + bf16_vec_inner_product_batch_4( + x, + y + indices[0] * d, + y + indices[1] * d, + y + indices[2] * d, + y + indices[3] * d, + d, + dis[0], + dis[1], + dis[2], + dis[3]); + }; + fvec_distance_ny_if< + Pred, + decltype(distance1), + decltype(distance4), + IndexRemapper, + Apply, + 4, + DEFAULT_BUFFER_SIZE>( + ny, pred, distance1, distance4, remapper, apply); +} + +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Maps an iteration for-loop index to a database index. + // It is needed for calls with indirect indexing like + // fvec_L2sqr_by_idx(). + // auto IndexRemapper(const size_t idx); + typename IndexRemapper, + // Apply an element. + // void Apply(const float dis, const auto idx); + typename Apply> +void internal_bf16_vec_L2sqr_ny_if( + const knowhere::bf16* __restrict x, + const knowhere::bf16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + IndexRemapper remapper, + Apply apply) { + using idx_type = std::invoke_result_t; + // compute a distance from the query to 1 element + auto distance1 = [x, y, d](const idx_type idx) { + return bf16_vec_L2sqr(x, y + idx * d, d); + }; + // compute distances from the query to 4 elements + auto distance4 = [x, y, d]( + const std::array indices, + std::array& dis) { + // todo: optimze it with bf16_vec_L2sqr_batch_4 + bf16_vec_L2sqr_batch_4( + x, + y + indices[0] * d, + y + indices[1] * d, + y + indices[2] * d, + y + indices[3] * d, + d, + dis[0], + dis[1], + dis[2], + dis[3]); + }; + fvec_distance_ny_if< + Pred, + decltype(distance1), + decltype(distance4), + IndexRemapper, + Apply, + 4, + DEFAULT_BUFFER_SIZE>( + ny, pred, distance1, distance4, remapper, apply); +} +} // namespace +// compute ny inner product between x vectors x and a set of contiguous y +// vectors +// with filtering and applying filtered elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const size_t idx); + typename Apply> +void fp16_vec_inner_products_ny_if( + const knowhere::fp16* __restrict x, + const knowhere::fp16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + internal_fp16_vec_inner_products_ny_if( + x, y, d, ny, pred, NoRemapping(), apply); +} + +// compute ny square L2 distance between x vectors x and a set of contiguous y +// vectors +// with filtering and applying filtered elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const size_t idx); + typename Apply> +void fp16_vec_L2sqr_ny_if( + const knowhere::fp16* __restrict x, + const knowhere::fp16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + internal_fp16_vec_L2sqr_ny_if(x, y, d, ny, pred, NoRemapping(), apply); +} + +// compute ny inner product between x vectors x and a set of contiguous y +// vectors +// whose indices are given by idy with filtering and applying filtered +// elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const int64_t idx); + typename Apply> +void fp16_vec_inner_products_ny_by_idx_if( + const float* __restrict x, + const float* __restrict y, + const int64_t* __restrict ids, /* ids of y vecs */ + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + ByIdxRemapping remapper{ids}; + internal_fp16_vec_inner_products_ny_if(x, y, d, ny, pred, remapper, apply); +} + +// compute ny square L2 distance between x vectors x and a set of contiguous y +// vectors +// whose indices are given by idy with filtering and applying filtered +// elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const int64_t idx); + typename Apply> +void fp16_vec_L2sqr_ny_by_idx_if( + const float* __restrict x, + const float* __restrict y, + const int64_t* __restrict ids, /* ids of y vecs */ + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + ByIdxRemapping remapper{ids}; + internal_fp16_vec_L2sqr_ny_if(x, y, d, ny, pred, remapper, apply); +} + +// compute ny inner product between x vectors x and a set of contiguous y +// vectors +// with filtering and applying filtered elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const size_t idx); + typename Apply> +void bf16_vec_inner_products_ny_if( + const knowhere::bf16* __restrict x, + const knowhere::bf16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + internal_bf16_vec_inner_products_ny_if( + x, y, d, ny, pred, NoRemapping(), apply); +} + +// compute ny square L2 distance between x vectors x and a set of contiguous y +// vectors +// with filtering and applying filtered elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const size_t idx); + typename Apply> +void bf16_vec_L2sqr_ny_if( + const knowhere::bf16* __restrict x, + const knowhere::bf16* __restrict y, + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + internal_bf16_vec_L2sqr_ny_if(x, y, d, ny, pred, NoRemapping(), apply); +} + +// compute ny inner product between x vectors x and a set of contiguous y +// vectors +// whose indices are given by idy with filtering and applying filtered +// elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const int64_t idx); + typename Apply> +void bf16_vec_inner_products_ny_by_idx_if( + const float* __restrict x, + const float* __restrict y, + const int64_t* __restrict ids, /* ids of y vecs */ + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + ByIdxRemapping remapper{ids}; + internal_bf16_vec_inner_products_ny_if(x, y, d, ny, pred, remapper, apply); +} + +// compute ny square L2 distance between x vectors x and a set of contiguous y +// vectors +// whose indices are given by idy with filtering and applying filtered +// elements. +template < + // A predicate for filtering elements. + // std::optional Pred(const size_t idx); + // * return true to accept an element. + // * return false to reject an element. + // * return std::nullopt to break the iteration loop. + typename Pred, + // Apply an element. + // void Apply(const float dis, const int64_t idx); + typename Apply> +void bf16_vec_L2sqr_ny_by_idx_if( + const float* __restrict x, + const float* __restrict y, + const int64_t* __restrict ids, /* ids of y vecs */ + size_t d, + const size_t ny, + Pred pred, + Apply apply) { + ByIdxRemapping remapper{ids}; + internal_bf16_vec_L2sqr_ny_if(x, y, d, ny, pred, remapper, apply); +} + + } //namespace faiss diff --git a/thirdparty/faiss/faiss/utils/half_precision_floating_point_distances.cpp b/thirdparty/faiss/faiss/utils/half_precision_floating_point_distances.cpp new file mode 100644 index 000000000..e51cf8dff --- /dev/null +++ b/thirdparty/faiss/faiss/utils/half_precision_floating_point_distances.cpp @@ -0,0 +1,692 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License + +#include + +#include +#include +#include +#include +#include +#include "knowhere/bitsetview_idselector.h" +#include "knowhere/operands.h" +#include "simd/hook.h" +namespace faiss { +namespace { +template +void half_precision_floating_point_exhaustive_inner_product_impl( + const DataType* __restrict x, + const DataType* __restrict y, + size_t d, + size_t nx, + size_t ny, + BlockResultHandler& res, + const IDSelector& selector) { + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; + + SingleResultHandler resi(res); + for (int64_t i = 0; i < nx; i++) { + const DataType* x_i = x + i * d; + resi.begin(i); + + // the lambda that applies a filtered element. + auto apply = [&resi](const float ip, const idx_t j) { + resi.add_result(ip, j); + }; + if constexpr (std::is_same_v) { + // todo: need more tests about this branch + auto filter = [](const size_t j) { return true; }; + if constexpr (std::is_same_v) { + fp16_vec_inner_products_ny_by_idx_if( + x_i, y, selector.ids, d, selector.n, filter, apply); + } else { + bf16_vec_inner_products_ny_by_idx_if( + x_i, y, selector.ids, d, selector.n, filter, apply); + } + } else { + // the lambda that filters acceptable elements. + auto filter = [&selector](const size_t j) { + return selector.is_member(j); + }; + if constexpr (std::is_same_v) { + fp16_vec_inner_products_ny_if(x_i, y, d, ny, filter, apply); + } else { + bf16_vec_inner_products_ny_if(x_i, y, d, ny, filter, apply); + } + } + resi.end(); + } +} + +template +void half_precision_floating_point_exhaustive_L2sqr_seq_impl( + const DataType* __restrict x, + const DataType* __restrict y, + size_t d, + size_t nx, + size_t ny, + BlockResultHandler& res, + const IDSelector& selector) { + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; + + SingleResultHandler resi(res); + for (int64_t i = 0; i < nx; i++) { + const DataType* x_i = x + i * d; + resi.begin(i); + + // the lambda that applies a filtered element. + auto apply = [&resi](const float ip, const idx_t j) { + resi.add_result(ip, j); + }; + if constexpr (std::is_same_v) { + // todo: need more tests about this branch + auto filter = [](const size_t j) { return true; }; + if constexpr (std::is_same_v) { + fp16_vec_L2sqr_ny_by_idx_if( + x_i, y, selector.ids, d, selector->n, filter, apply); + } else { + bf16_vec_L2sqr_ny_by_idx_if( + x_i, y, selector.ids, d, selector->n, filter, apply); + } + } else { + // the lambda that filters acceptable elements. + auto filter = [&selector](const size_t j) { + return selector.is_member(j); + }; + if constexpr (std::is_same_v) { + fp16_vec_L2sqr_ny_if(x_i, y, d, ny, filter, apply); + } else { + bf16_vec_L2sqr_ny_if(x_i, y, d, ny, filter, apply); + } + } + resi.end(); + } +} + +template +void half_precision_floating_point_exhaustive_cosine_seq_impl( + const DataType* __restrict x, + const DataType* __restrict y, + const float* __restrict y_norms, + size_t d, + size_t nx, + size_t ny, + BlockResultHandler& res, + const IDSelector& selector) { + using SingleResultHandler = + typename BlockResultHandler::SingleResultHandler; + typedef float (*NormComputer)(const DataType*, size_t); + NormComputer norm_computer; + if constexpr (std::is_same_v) { + norm_computer = fp16_vec_norm_L2sqr; + } else { + norm_computer = bf16_vec_norm_L2sqr; + } + SingleResultHandler resi(res); + for (int64_t i = 0; i < nx; i++) { + const DataType* x_i = x + i * d; + // distance div x_norm before pushing into the heap + auto x_norm = sqrtf(norm_computer(x_i, d)); + x_norm = (x_norm == 0.0 ? 1.0 : x_norm); + auto apply = [&resi, x_norm, y, y_norms, d, norm_computer]( + const float ip, const idx_t j) { + float y_norm = (y_norms != nullptr) + ? y_norms[j] + : sqrtf(norm_computer(y + j * d, d)); + + y_norm = (y_norm == 0.0 ? 1.0 : y_norm); + resi.add_result(ip / (x_norm * y_norm), j); + }; + resi.begin(i); + // the lambda that applies a filtered element + if constexpr (std::is_same_v) { + // todo: need more tests about this branch + auto filter = [](const size_t j) { return true; }; + if constexpr (std::is_same_v) { + fp16_vec_inner_products_ny_by_idx_if( + x_i, y, selector.ids, d, selector->n, filter, apply); + } else { + bf16_vec_inner_products_ny_by_idx_if( + x_i, y, selector.ids, d, selector->n, filter, apply); + } + } else { + // the lambda that filters acceptable elements. + auto filter = [&selector](const size_t j) { + return selector.is_member(j); + }; + if constexpr (std::is_same_v) { + fp16_vec_inner_products_ny_if(x_i, y, d, ny, filter, apply); + } else { + bf16_vec_inner_products_ny_if(x_i, y, d, ny, filter, apply); + } + } + resi.end(); + } +} +} // namespace + +template +void half_precision_floating_point_knn_inner_product( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + size_t k, + float* vals, + int64_t* ids, + const IDSelector* sel) { + int64_t imin = 0; + if (auto selr = dynamic_cast(sel)) { + imin = std::max(selr->imin, int64_t(0)); + int64_t imax = std::min(selr->imax, int64_t(ny)); + ny = imax - imin; + y += d * imin; + sel = nullptr; + } + if (k < distance_compute_min_k_reservoir) { + HeapBlockResultHandler> res(nx, vals, ids, k); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, *sel); + } + } else { + ReservoirBlockResultHandler> res(nx, vals, ids, k); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, *sel); + } + } + + if (imin != 0) { + for (size_t i = 0; i < nx * k; i++) { + if (ids[i] >= 0) { + ids[i] += imin; + } + } + } + return; +} + +template +void half_precision_floating_point_all_inner_product( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + std::vector& output, + const IDSelector* sel) { + CollectAllResultHandler> res(nx, ny, output); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, res, *sel); + } + return; +} + +template +void half_precision_floating_point_knn_L2sqr( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + size_t k, + float* vals, + int64_t* ids, + const float* y_norm2, + const IDSelector* sel) { + int64_t imin = 0; + if (auto selr = dynamic_cast(sel)) { + imin = std::max(selr->imin, int64_t(0)); + int64_t imax = std::min(selr->imax, int64_t(ny)); + ny = imax - imin; + y += d * imin; + sel = nullptr; + } + if (k < distance_compute_min_k_reservoir) { + HeapBlockResultHandler> res(nx, vals, ids, k); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, *sel); + } + } else { + ReservoirBlockResultHandler> res(nx, vals, ids, k); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, *sel); + } + } + if (imin != 0) { + for (size_t i = 0; i < nx * k; i++) { + if (ids[i] >= 0) { + ids[i] += imin; + } + } + } + return; +} + +template +void half_precision_floating_point_all_L2sqr( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + std::vector& output, + const float* y_norms, + const IDSelector* sel) { + CollectAllResultHandler> res(nx, ny, output); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, res, *sel); + } + return; +} + +template +void half_precision_floating_point_knn_cosine( + const DataType* x, + const DataType* y, + const float* y_norm2, + size_t d, + size_t nx, + size_t ny, + size_t k, + float* vals, + int64_t* ids, + const IDSelector* sel) { + int64_t imin = 0; + if (auto selr = dynamic_cast(sel)) { + imin = std::max(selr->imin, int64_t(0)); + int64_t imax = std::min(selr->imax, int64_t(ny)); + ny = imax - imin; + y += d * imin; + sel = nullptr; + } + if (k < distance_compute_min_k_reservoir) { + HeapBlockResultHandler> res(nx, vals, ids, k); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norm2, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norm2, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norm2, d, nx, ny, res, *sel); + } + } else { + ReservoirBlockResultHandler> res(nx, vals, ids, k); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norm2, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norm2, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norm2, d, nx, ny, res, *sel); + } + } + if (imin != 0) { + for (size_t i = 0; i < nx * k; i++) { + if (ids[i] >= 0) { + ids[i] += imin; + } + } + } + return; +} + +template +void half_precision_floating_point_all_cosine( + const DataType* x, + const DataType* y, + const float* y_norms, + size_t d, + size_t nx, + size_t ny, + std::vector& output, + const IDSelector* sel) { + CollectAllResultHandler> res(nx, ny, output); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norms, d, nx, ny, res, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norms, d, nx, ny, res, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norms, d, nx, ny, res, *sel); + } + return; +} + +/*************************************************************************** + * Range search + ***************************************************************************/ +struct RangeSearchResult; + +template +void half_precision_floating_point_range_search_L2sqr( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* res, + const IDSelector* sel) { + RangeSearchBlockResultHandler> resh(res, radius); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, resh, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, resh, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_L2sqr_seq_impl( + x, y, d, nx, ny, resh, *sel); + } + return; +} + +template +void half_precision_floating_point_range_search_inner_product( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* res, + const IDSelector* sel) { + RangeSearchBlockResultHandler> resh(res, radius); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, resh, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, resh, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_inner_product_impl( + x, y, d, nx, ny, resh, *sel); + } + return; +} + +// Knowhere-specific function +template +void half_precision_floating_point_range_search_cosine( + const DataType* x, + const DataType* y, + const float* y_norms, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* res, + const IDSelector* sel) { + RangeSearchBlockResultHandler> resh(res, radius); + if (const auto* sel_bs = + dynamic_cast(sel)) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norms, d, nx, ny, resh, *sel_bs); + } else if (sel == nullptr) { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norms, d, nx, ny, resh, IDSelectorAll()); + } else { + half_precision_floating_point_exhaustive_cosine_seq_impl( + x, y, y_norms, d, nx, ny, resh, *sel); + } + return; +} + +// knn functions +template void faiss::half_precision_floating_point_knn_inner_product< + knowhere::fp16>( + const knowhere::fp16*, + const knowhere::fp16*, + size_t, + size_t, + size_t, + size_t, + float*, + int64_t*, + const IDSelector*); +template void faiss::half_precision_floating_point_knn_inner_product< + knowhere::bf16>( + const knowhere::bf16*, + const knowhere::bf16*, + size_t, + size_t, + size_t, + size_t, + float*, + int64_t*, + const IDSelector*); +template void faiss::half_precision_floating_point_all_inner_product< + knowhere::fp16>( + const knowhere::fp16*, + const knowhere::fp16*, + size_t, + size_t, + size_t, + std::vector&, + const IDSelector*); +template void faiss::half_precision_floating_point_all_inner_product< + knowhere::bf16>( + const knowhere::bf16*, + const knowhere::bf16*, + size_t, + size_t, + size_t, + std::vector&, + const IDSelector*); +template void faiss::half_precision_floating_point_knn_L2sqr( + const knowhere::fp16*, + const knowhere::fp16*, + size_t, + size_t, + size_t, + size_t, + float*, + int64_t*, + const float*, + const IDSelector*); +template void faiss::half_precision_floating_point_knn_L2sqr( + const knowhere::bf16*, + const knowhere::bf16*, + size_t, + size_t, + size_t, + size_t, + float*, + int64_t*, + const float*, + const IDSelector*); +template void faiss::half_precision_floating_point_all_L2sqr( + const knowhere::fp16*, + const knowhere::fp16*, + size_t, + size_t, + size_t, + std::vector&, + const float*, + const IDSelector*); +template void faiss::half_precision_floating_point_all_L2sqr( + const knowhere::bf16*, + const knowhere::bf16*, + size_t, + size_t, + size_t, + std::vector&, + const float*, + const IDSelector*); +template void faiss::half_precision_floating_point_knn_cosine( + const knowhere::fp16*, + const knowhere::fp16*, + const float*, + size_t, + size_t, + size_t, + size_t, + float*, + int64_t*, + const IDSelector*); +template void faiss::half_precision_floating_point_knn_cosine( + const knowhere::bf16*, + const knowhere::bf16*, + const float*, + size_t, + size_t, + size_t, + size_t, + float*, + int64_t*, + const IDSelector*); +template void faiss::half_precision_floating_point_all_cosine( + const knowhere::fp16*, + const knowhere::fp16*, + const float*, + size_t, + size_t, + size_t, + std::vector&, + const IDSelector*); +template void faiss::half_precision_floating_point_all_cosine( + const knowhere::bf16*, + const knowhere::bf16*, + const float*, + size_t, + size_t, + size_t, + std::vector&, + const IDSelector*); +// range search functions +template void faiss::half_precision_floating_point_range_search_L2sqr< + knowhere::fp16>( + const knowhere::fp16*, + const knowhere::fp16*, + size_t, + size_t, + size_t, + float, + faiss::RangeSearchResult*, + const faiss::IDSelector*); +template void faiss::half_precision_floating_point_range_search_L2sqr< + knowhere::bf16>( + const knowhere::bf16*, + const knowhere::bf16*, + size_t, + size_t, + size_t, + float, + faiss::RangeSearchResult*, + const faiss::IDSelector*); +template void faiss::half_precision_floating_point_range_search_inner_product< + knowhere::fp16>( + const knowhere::fp16*, + const knowhere::fp16*, + size_t, + size_t, + size_t, + float, + faiss::RangeSearchResult*, + const faiss::IDSelector*); +template void faiss::half_precision_floating_point_range_search_inner_product< + knowhere::bf16>( + const knowhere::bf16*, + const knowhere::bf16*, + size_t, + size_t, + size_t, + float, + faiss::RangeSearchResult*, + const faiss::IDSelector*); +template void faiss::half_precision_floating_point_range_search_cosine< + knowhere::fp16>( + const knowhere::fp16*, + const knowhere::fp16*, + const float*, + size_t, + size_t, + size_t, + float, + faiss::RangeSearchResult*, + const faiss::IDSelector*); +template void faiss::half_precision_floating_point_range_search_cosine< + knowhere::bf16>( + const knowhere::bf16*, + const knowhere::bf16*, + const float*, + size_t, + size_t, + size_t, + float, + faiss::RangeSearchResult*, + const faiss::IDSelector*); +} // namespace faiss \ No newline at end of file diff --git a/thirdparty/faiss/faiss/utils/half_precision_floating_point_distances.h b/thirdparty/faiss/faiss/utils/half_precision_floating_point_distances.h new file mode 100644 index 000000000..d53a94704 --- /dev/null +++ b/thirdparty/faiss/faiss/utils/half_precision_floating_point_distances.h @@ -0,0 +1,138 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License + +#ifndef FAISS_HALF_PRECISION_FLOATING_POINT_DISTANCES_H +#define FAISS_HALF_PRECISION_FLOATING_POINT_DISTANCES_H +#pragma once +#include +#include +#include +#include "knowhere/object.h" +namespace faiss { +struct IDSelector; +/*************************************************************************** + * KNN functions + ***************************************************************************/ +// Knowhere-specific function +template +void half_precision_floating_point_knn_inner_product( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + size_t k, + float* distances, + int64_t* indexes, + const IDSelector* sel = nullptr); + +template +void half_precision_floating_point_all_inner_product( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + std::vector& output, + const IDSelector* sel = nullptr); + +template +void half_precision_floating_point_knn_L2sqr( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + size_t k, + float* distances, + int64_t* indexes, + const float* y_norm2 = nullptr, + const IDSelector* sel = nullptr); + +template +void half_precision_floating_point_all_L2sqr( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + std::vector& output, + const float* y_norms = nullptr, + const IDSelector* sel = nullptr); + +template +void half_precision_floating_point_knn_cosine( + const DataType* x, + const DataType* y, + const float* y_norms, + size_t d, + size_t nx, + size_t ny, + size_t k, + float* distances, + int64_t* indexes, + const IDSelector* sel = nullptr); + +template +void half_precision_floating_point_all_cosine( + const DataType* x, + const DataType* y, + const float* y_norms, + size_t d, + size_t nx, + size_t ny, + std::vector& output, + const IDSelector* sel = nullptr); + +/*************************************************************************** + * Range search + ***************************************************************************/ +struct RangeSearchResult; +template +void half_precision_floating_point_range_search_L2sqr( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* result, + const IDSelector* sel = nullptr); + +/// same as range_search_L2sqr for the inner product similarity +template +void half_precision_floating_point_range_search_inner_product( + const DataType* x, + const DataType* y, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* result, + const IDSelector* sel = nullptr); + +// Knowhere-specific function +template +void half_precision_floating_point_range_search_cosine( + const DataType* x, + const DataType* y, + const float* y_norms, + size_t d, + size_t nx, + size_t ny, + float radius, + RangeSearchResult* result, + const IDSelector* sel = nullptr); +} // namespace faiss +#endif // FAISS_HALF_PRECISION_FLOATING_POINT_DISTANCES_H \ No newline at end of file