diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 8727cc8a..371847ad 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -107,27 +107,17 @@ class BruteforceSearch : public AlgorithmInterface { searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { assert(k <= cur_element_count); std::priority_queue> topResults; - if (cur_element_count == 0) return topResults; - for (int i = 0; i < k; i++) { + dist_t lastdist = std::numeric_limits::max(); + for (int i = 0; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.emplace(dist, label); - } - } - dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; - for (int i = k; i < cur_element_count; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - if (dist <= lastdist) { + if (dist <= lastdist || topResults.size() < k) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.emplace(dist, label); - } - if (topResults.size() > k) - topResults.pop(); - - if (!topResults.empty()) { - lastdist = topResults.top().first; + if (topResults.size() > k) + topResults.pop(); + if (!topResults.empty()) + lastdist = topResults.top().first; } } } diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index dd09e80a..56ce9beb 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -871,16 +871,39 @@ class BFIndex { CustomFilterFunctor idFilter(filter); CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; - ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { - std::priority_queue> result = alg->searchKnn( - (void*)items.data(row), k, p_idFilter); - for (int i = k - 1; i >= 0; i--) { - auto& result_tuple = result.top(); - data_numpy_d[row * k + i] = result_tuple.first; - data_numpy_l[row * k + i] = result_tuple.second; - result.pop(); - } - }); + if (!normalize) { + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg->searchKnn( + (void*)items.data(row), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contiguous 2D array. There are not enough elements."); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } else { + std::vector norm_array(num_threads * features); + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), norm_array.data() + start_idx); + + std::priority_queue> result = alg->searchKnn( + (void*)(norm_array.data() + start_idx), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contiguous 2D array. There are not enough elements."); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } } py::capsule free_when_done_l(data_numpy_l, [](void *f) {