From b5a58460bb657882ee61342d2a8bf34ddf9fed72 Mon Sep 17 00:00:00 2001 From: Michael Norris Date: Fri, 6 Dec 2024 11:11:52 -0800 Subject: [PATCH] Add more unit testing for HNSW [3/n] (#4059) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4059 Add unit test to compare reference version with optimized version. Reviewed By: mengdilin Differential Revision: D66793367 fbshipit-source-id: 8da25e79f66d079f76d237c10fc3db4a0def767d --- faiss/impl/HNSW.cpp | 19 ++++++++++++------- faiss/impl/HNSW.h | 10 ++++++++++ tests/test_hnsw.cpp | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 7 deletions(-) diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 642bf7c532..09b10e2b97 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -351,6 +351,8 @@ void add_link( } } +} // namespace + /// search neighbors on a single level, starting from an entry point void search_neighbors_to_add( HNSW& hnsw, @@ -359,10 +361,8 @@ void search_neighbors_to_add( int entry_point, float d_entry_point, int level, - VisitedTable& vt) { - // selects a version - const bool reference_version = false; - + VisitedTable& vt, + bool reference_version) { // top is nearest candidate std::priority_queue candidates; @@ -385,7 +385,14 @@ void search_neighbors_to_add( size_t begin, end; hnsw.neighbor_range(currNode, level, &begin, &end); - // select a version, based on a flag + // The reference version is not used, but kept here because: + // 1. It is easier to switch back if the optimized version has a problem + // 2. It serves as a starting point for new optimizations + // 3. It helps understand the code + // 4. It ensures the reference version is still compilable if the + // optimized version changes + // The reference and the optimized versions' results are compared in + // test_hnsw.cpp if (reference_version) { // a reference version for (size_t i = begin; i < end; i++) { @@ -470,8 +477,6 @@ void search_neighbors_to_add( vt.advance(); } -} // namespace - /// Finds neighbors and builds links with them, starting from an entry /// point. The own neighbor list is assumed to be locked. void HNSW::add_links_starting_from( diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index 71419edbb5..aad26b1eda 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -281,4 +281,14 @@ std::priority_queue search_from_candidate_unbounded( VisitedTable* vt, HNSWStats& stats); +void search_neighbors_to_add( + HNSW& hnsw, + DistanceComputer& qdis, + std::priority_queue& results, + int entry_point, + float d_entry_point, + int level, + VisitedTable& vt, + bool reference_version = false); + } // namespace faiss diff --git a/tests/test_hnsw.cpp b/tests/test_hnsw.cpp index c546a76778..a878564a6d 100644 --- a/tests/test_hnsw.cpp +++ b/tests/test_hnsw.cpp @@ -541,3 +541,43 @@ TEST_F(HNSWTest, TEST_search_from_candidates) { EXPECT_EQ(reference_stats.n1, stats.n1); EXPECT_EQ(reference_stats.n2, stats.n2); } + +TEST_F(HNSWTest, TEST_search_neighbors_to_add) { + omp_set_num_threads(1); + + faiss::VisitedTable vt(index->ntotal); + faiss::VisitedTable reference_vt(index->ntotal); + + std::priority_queue link_targets; + std::priority_queue reference_link_targets; + + faiss::search_neighbors_to_add( + index->hnsw, + *dis, + link_targets, + index->hnsw.entry_point, + (*dis)(index->hnsw.entry_point), + index->hnsw.max_level, + vt, + false); + + faiss::search_neighbors_to_add( + index->hnsw, + *dis, + reference_link_targets, + index->hnsw.entry_point, + (*dis)(index->hnsw.entry_point), + index->hnsw.max_level, + reference_vt, + true); + + EXPECT_EQ(link_targets.size(), reference_link_targets.size()); + while (!link_targets.empty()) { + auto val = link_targets.top(); + auto reference_val = reference_link_targets.top(); + EXPECT_EQ(val.d, reference_val.d); + EXPECT_EQ(val.id, reference_val.id); + link_targets.pop(); + reference_link_targets.pop(); + } +}