diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 78c67d9c8..063a41b1e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -308,6 +308,7 @@ add_library( src/cluster/kmeans_transform_double.cu src/cluster/kmeans_transform_float.cu src/cluster/single_linkage_float.cu + src/core/bitset.cu src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu @@ -406,10 +407,6 @@ add_library( src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu - src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_float_int64_t.cu - src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu - src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_int8_t_int64_t.cu - src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_uint8_t_int64_t.cu src/neighbors/nn_descent.cu src/neighbors/nn_descent_float.cu src/neighbors/nn_descent_half.cu diff --git a/cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu b/cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu index 4c38b3420..55d5b8c70 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_brute_force_knn.cu @@ -134,7 +134,7 @@ class BruteForceKNNBenchmark { search_queries.data(), params_.num_queries, params_.dim), indices, distances, - std::nullopt); + cuvs::neighbors::filtering::none_sample_filter{}); flush_l2_cache(); raft::resource::sync_stream(handle_, stream_); } @@ -158,7 +158,7 @@ class BruteForceKNNBenchmark { search_queries.data(), params_.num_queries, params_.dim), indices, distances, - std::nullopt); + cuvs::neighbors::filtering::none_sample_filter{}); raft::resource::sync_stream(handle_, stream_); end = std::chrono::high_resolution_clock::now(); search_dur = end - start; @@ -178,7 +178,7 
@@ class BruteForceKNNBenchmark { search_queries.data(), params_.num_queries, params_.dim), indices, distances, - std::nullopt); + cuvs::neighbors::filtering::none_sample_filter{}); flush_l2_cache(); raft::resource::sync_stream(handle_, stream_); } @@ -202,7 +202,7 @@ class BruteForceKNNBenchmark { search_queries.data(), params_.num_queries, params_.dim), indices, distances, - std::nullopt); + cuvs::neighbors::filtering::none_sample_filter{}); raft::resource::sync_stream(handle_, stream_); end = std::chrono::high_resolution_clock::now(); search_dur = end - start; diff --git a/cpp/bench/ann/src/cuvs/cuvs_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_wrapper.h index ea052533d..bf0fa5934 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_wrapper.h @@ -155,8 +155,12 @@ void cuvs_gpu::search( raft::make_device_matrix_view(neighbors, batch_size, k); auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - cuvs::neighbors::brute_force::search( - handle_, *index_, queries_view, neighbors_view, distances_view, std::nullopt); + cuvs::neighbors::brute_force::search(handle_, + *index_, + queries_view, + neighbors_view, + distances_view, + cuvs::neighbors::filtering::none_sample_filter{}); } template diff --git a/cpp/include/cuvs/core/bitset.hpp b/cpp/include/cuvs/core/bitset.hpp index 99942e21c..8236bbf07 100644 --- a/cpp/include/cuvs/core/bitset.hpp +++ b/cpp/include/cuvs/core/bitset.hpp @@ -18,6 +18,12 @@ #include +extern template struct raft::core::bitset; +extern template struct raft::core::bitset; +extern template struct raft::core::bitset; +extern template struct raft::core::bitset; +extern template struct raft::core::bitset; + namespace cuvs::core { /* To use bitset functions containing CUDA code, include */ diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 5408eb1a0..428fa592a 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ 
b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -291,7 +291,8 @@ void search(raft::resources const& handle, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - std::optional> sample_filter); + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -326,7 +327,8 @@ void search(raft::resources const& handle, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - std::optional> sample_filter); + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. * @@ -346,7 +348,8 @@ void search(raft::resources const& handle, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - std::optional> sample_filter); + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. * @@ -366,7 +369,8 @@ void search(raft::resources const& handle, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - std::optional> sample_filter); + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @} */ diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 20db7e8b7..e48050756 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -1055,6 +1055,8 @@ void extend( * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. 
(none_sample_filter for no filtering) */ void search(raft::resources const& res, @@ -1062,7 +1064,9 @@ void search(raft::resources const& res, const cuvs::neighbors::cagra::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1077,13 +1081,17 @@ void search(raft::resources const& res, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, const cuvs::neighbors::cagra::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1098,13 +1106,17 @@ void search(raft::resources const& res, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. 
(none_sample_filter for no filtering) */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, const cuvs::neighbors::cagra::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1119,13 +1131,18 @@ void search(raft::resources const& res, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, const cuvs::neighbors::cagra::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + /** * @} */ diff --git a/cpp/include/cuvs/neighbors/common.h b/cpp/include/cuvs/neighbors/common.h index 02cbeea96..d7ca878b9 100644 --- a/cpp/include/cuvs/neighbors/common.h +++ b/cpp/include/cuvs/neighbors/common.h @@ -44,7 +44,7 @@ enum cuvsFilterType { }; /** - * @brief Struct to hold address of cuvs::neighbor::prefilter and its type + * @brief Struct to hold address of cuvs::neighbors::prefilter and its type * */ typedef struct { diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 8218b5f52..73ce80b41 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -383,8 +383,12 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset::value; namespace 
filtering { +struct base_filter { + virtual ~base_filter() = default; +}; + /* A filter that filters nothing. This is the default behavior. */ -struct none_ivf_sample_filter { +struct none_sample_filter : public base_filter { inline _RAFT_HOST_DEVICE bool operator()( // query index const uint32_t query_ix, @@ -392,10 +396,7 @@ struct none_ivf_sample_filter { const uint32_t cluster_ix, // the index of the current sample inside the current inverted list const uint32_t sample_ix) const; -}; -/* A filter that filters nothing. This is the default behavior. */ -struct none_cagra_sample_filter { inline _RAFT_HOST_DEVICE bool operator()( // query index const uint32_t query_ix, @@ -431,13 +432,33 @@ struct ivf_to_sample_filter { const uint32_t sample_ix) const; }; +/** + * @brief Filter an index with a bitmap + * + * @tparam bitmap_t Data type of the bitmap + * @tparam index_t Indexing type + */ +template +struct bitmap_filter : public base_filter { + // View of the bitmap to use as a filter + const cuvs::core::bitmap_view bitmap_view_; + + bitmap_filter(const cuvs::core::bitmap_view bitmap_for_filtering); + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the index of the current sample + const uint32_t sample_ix) const; +}; + /** * @brief Filter an index with a bitset * + * @tparam bitset_t Data type of the bitset * @tparam index_t Indexing type */ template -struct bitset_filter { +struct bitset_filter : public base_filter { // View of the bitset to use as a filter const cuvs::core::bitset_view bitset_view_; diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index 44502f942..67d1b46c0 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -1163,13 +1163,17 @@ void extend(raft::resources const& handle, * dataset [n_queries, k] * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors * [n_queries, k] + * 
@param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_flat::search_params& params, cuvs::neighbors::ivf_flat::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1200,13 +1204,17 @@ void search(raft::resources const& handle, * dataset [n_queries, k] * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors * [n_queries, k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_flat::search_params& params, cuvs::neighbors::ivf_flat::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1237,112 +1245,18 @@ void search(raft::resources const& handle, * dataset [n_queries, k] * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors * [n_queries, k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. 
(none_sample_filter for no filtering) */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_flat::search_params& params, cuvs::neighbors::ivf_flat::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-flat constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device bitset filter function that greenlights samples for a given - * query. 
- */ -void search_with_filtering( - raft::resources const& handle, - const search_params& params, - index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - cuvs::neighbors::filtering::bitset_filter sample_filter); - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-flat constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device bitset filter function that greenlights samples for a given - * query. 
- */ -void search_with_filtering( - raft::resources const& handle, - const search_params& params, - index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - cuvs::neighbors::filtering::bitset_filter sample_filter); - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-flat constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device bitset filter function that greenlights samples for a given - * query. 
- */ -void search_with_filtering( - raft::resources const& handle, - const search_params& params, - index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - cuvs::neighbors::filtering::bitset_filter sample_filter); /** * @} */ @@ -2039,18 +1953,18 @@ void reset_index(const raft::resources& res, index* index); * using namespace cuvs::neighbors; * raft::resources res; * // use default index parameters - * ivf_pq::index_params index_params; + * ivf_flat::index_params index_params; * // initialize an empty index - * ivf_pq::index index(res, index_params, D); - * ivf_pq::helpers::reset_index(res, &index); + * ivf_flat::index index(res, index_params, D); + * ivf_flat::helpers::reset_index(res, &index); * // resize the first IVF list to hold 5 records - * auto spec = list_spec{ - * index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; + * auto spec = list_spec{ + * index->dim(), index->conservative_memory_allocation()}; * uint32_t new_size = 5; * ivf::resize_list(res, list, spec, new_size, 0); * raft::update_device(index.list_sizes(), &new_size, 1, stream); * // recompute the internal state of the index - * ivf_pq::helpers::recompute_internal_state(res, index); + * ivf_flat::helpers::recompute_internal_state(res, index); * @endcode * * @param[in] res raft resource @@ -2067,18 +1981,18 @@ void recompute_internal_state(const raft::resources& res, index* * using namespace cuvs::neighbors; * raft::resources res; * // use default index parameters - * ivf_pq::index_params index_params; + * ivf_flat::index_params index_params; * // initialize an empty index - * ivf_pq::index index(res, index_params, D); - * ivf_pq::helpers::reset_index(res, &index); + * ivf_flat::index index(res, index_params, D); + * ivf_flat::helpers::reset_index(res, &index); * // resize the first IVF list to hold 5 records - * auto spec = list_spec{ - * index->pq_bits(), index->pq_dim(), 
index->conservative_memory_allocation()}; + * auto spec = list_spec{ + * index->dim(), index->conservative_memory_allocation()}; * uint32_t new_size = 5; * ivf::resize_list(res, list, spec, new_size, 0); * raft::update_device(index.list_sizes(), &new_size, 1, stream); * // recompute the internal state of the index - * ivf_pq::helpers::recompute_internal_state(res, index); + * ivf_flat::helpers::recompute_internal_state(res, index); * @endcode * * @param[in] res raft resource @@ -2095,18 +2009,18 @@ void recompute_internal_state(const raft::resources& res, index * using namespace cuvs::neighbors; * raft::resources res; * // use default index parameters - * ivf_pq::index_params index_params; + * ivf_flat::index_params index_params; * // initialize an empty index - * ivf_pq::index index(res, index_params, D); - * ivf_pq::helpers::reset_index(res, &index); + * ivf_flat::index index(res, index_params, D); + * ivf_flat::helpers::reset_index(res, &index); * // resize the first IVF list to hold 5 records - * auto spec = list_spec{ - * index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; + * auto spec = list_spec{ + * index->dim(), index->conservative_memory_allocation()}; * uint32_t new_size = 5; * ivf::resize_list(res, list, spec, new_size, 0); * raft::update_device(index.list_sizes(), &new_size, 1, stream); * // recompute the internal state of the index - * ivf_pq::helpers::recompute_internal_state(res, index); + * ivf_flat::helpers::recompute_internal_state(res, index); * @endcode * * @param[in] res raft resource diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 8c378b1f0..3ce5f382f 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -1400,13 +1400,17 @@ void extend(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an 
optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_pq::search_params& search_params, cuvs::neighbors::ivf_pq::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1441,13 +1445,17 @@ void search(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_pq::search_params& search_params, cuvs::neighbors::ivf_pq::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1482,13 +1490,17 @@ void search(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. 
(none_sample_filter for no filtering) */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_pq::search_params& search_params, cuvs::neighbors::ivf_pq::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1523,145 +1535,18 @@ void search(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] + * @param[in] sample_filter an optional device filter function object that greenlights samples + * for a given query. (none_sample_filter for no filtering) */ void search(raft::resources const& handle, const cuvs::neighbors::ivf_pq::search_params& search_params, cuvs::neighbors::ivf_pq::index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances); + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. 
However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device bitset filter function that greenlights samples for a given - * query. - */ -void search_with_filtering( - raft::resources const& handle, - const search_params& params, - index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - cuvs::neighbors::filtering::bitset_filter sample_filter); - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. 
- * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device bitset filter function that greenlights samples for a given - * query. - */ -void search_with_filtering( - raft::resources const& handle, - const search_params& params, - index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - cuvs::neighbors::filtering::bitset_filter sample_filter); - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. 
- * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device bitset filter function that greenlights samples for a given - * query. - */ -void search_with_filtering( - raft::resources const& handle, - const search_params& params, - index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - cuvs::neighbors::filtering::bitset_filter sample_filter); - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. 
- * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device bitset filter function that greenlights samples for a given - * query. - */ -void search_with_filtering( - raft::resources const& handle, - const search_params& params, - index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - cuvs::neighbors::filtering::bitset_filter sample_filter); /** * @} */ diff --git a/cpp/src/core/bitset.cu b/cpp/src/core/bitset.cu new file mode 100644 index 000000000..c791747a9 --- /dev/null +++ b/cpp/src/core/bitset.cu @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include + +template struct raft::core::bitset; +template struct raft::core::bitset; +template struct raft::core::bitset; +template struct raft::core::bitset; +template struct raft::core::bitset; diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index c76feb015..b0f87e9ac 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -145,54 +145,45 @@ void index::update_dataset( dataset_view_ = raft::make_const_mdspan(dataset_.view()); } -#define CUVS_INST_BFKNN(T, DistT) \ - auto build(raft::resources const& res, \ - raft::device_matrix_view dataset, \ - cuvs::distance::DistanceType metric, \ - DistT metric_arg) \ - ->cuvs::neighbors::brute_force::index \ - { \ - return detail::build(res, dataset, metric, metric_arg); \ - } \ - auto build(raft::resources const& res, \ - raft::device_matrix_view dataset, \ - cuvs::distance::DistanceType metric, \ - DistT metric_arg) \ - ->cuvs::neighbors::brute_force::index \ - { \ - return detail::build(res, dataset, metric, metric_arg); \ - } \ - \ - void search( \ - raft::resources const& res, \ - const cuvs::neighbors::brute_force::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - std::optional> sample_filter = std::nullopt) \ - { \ - if (!sample_filter.has_value()) { \ - detail::brute_force_search(res, idx, queries, neighbors, distances); \ - } else { \ - detail::brute_force_search_filtered( \ - res, idx, queries, *sample_filter, neighbors, distances); \ - } \ - } \ - void search( \ - raft::resources const& res, \ - const cuvs::neighbors::brute_force::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - std::optional> sample_filter = std::nullopt) \ - { \ - if (!sample_filter.has_value()) { \ - detail::brute_force_search(res, idx, queries, neighbors, distances); \ - } else { \ - 
RAFT_FAIL("filtered search isn't available with col_major queries yet"); \ - } \ - } \ - \ +#define CUVS_INST_BFKNN(T, DistT) \ + auto build(raft::resources const& res, \ + raft::device_matrix_view dataset, \ + cuvs::distance::DistanceType metric, \ + DistT metric_arg) \ + ->cuvs::neighbors::brute_force::index \ + { \ + return detail::build(res, dataset, metric, metric_arg); \ + } \ + auto build(raft::resources const& res, \ + raft::device_matrix_view dataset, \ + cuvs::distance::DistanceType metric, \ + DistT metric_arg) \ + ->cuvs::neighbors::brute_force::index \ + { \ + return detail::build(res, dataset, metric, metric_arg); \ + } \ + \ + void search(raft::resources const& res, \ + const cuvs::neighbors::brute_force::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + detail::search( \ + res, idx, queries, neighbors, distances, sample_filter); \ + } \ + void search(raft::resources const& res, \ + const cuvs::neighbors::brute_force::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + detail::search( \ + res, idx, queries, neighbors, distances, sample_filter); \ + } \ + \ template struct cuvs::neighbors::brute_force::index; CUVS_INST_BFKNN(float, float); @@ -200,4 +191,4 @@ CUVS_INST_BFKNN(half, float); #undef CUVS_INST_BFKNN -} // namespace cuvs::neighbors::brute_force +} // namespace cuvs::neighbors::brute_force \ No newline at end of file diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index f3ca2e730..eda79aa31 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -64,28 +64,31 @@ void _search(cuvsResources_t res, using neighbors_mdspan_type = raft::device_matrix_view; using 
distances_mdspan_type = raft::device_matrix_view; using prefilter_mds_type = raft::device_vector_view; - using prefilter_opt_type = cuvs::core::bitmap_view; + using prefilter_bmp_type = cuvs::core::bitmap_view; auto queries_mds = cuvs::core::from_dlpack(queries_tensor); auto neighbors_mds = cuvs::core::from_dlpack(neighbors_tensor); auto distances_mds = cuvs::core::from_dlpack(distances_tensor); - std::optional> filter_opt; - if (prefilter.type == NO_FILTER) { - filter_opt = std::nullopt; - } else { + cuvs::neighbors::brute_force::search(*res_ptr, + *index_ptr, + queries_mds, + neighbors_mds, + distances_mds, + cuvs::neighbors::filtering::none_sample_filter{}); + } else if (prefilter.type == BITMAP) { auto prefilter_ptr = reinterpret_cast(prefilter.addr); auto prefilter_mds = cuvs::core::from_dlpack(prefilter_ptr); - auto prefilter_view = prefilter_opt_type((const uint32_t*)prefilter_mds.data_handle(), - queries_mds.extent(0), - index_ptr->dataset().extent(0)); - - filter_opt = std::make_optional(prefilter_view); + auto prefilter_view = cuvs::neighbors::filtering::bitmap_filter( + prefilter_bmp_type((const uint32_t*)prefilter_mds.data_handle(), + queries_mds.extent(0), + index_ptr->dataset().extent(0))); + cuvs::neighbors::brute_force::search( + *res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds, prefilter_view); + } else { + RAFT_FAIL("Unsupported prefilter type: BITSET"); } - - cuvs::neighbors::brute_force::search( - *res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds, filter_opt); } } // namespace diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index 033f080e2..dacfd6f63 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -332,11 +332,29 @@ void search(raft::resources const& res, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& 
sample_filter_ref) { - using none_filter_type = cuvs::neighbors::filtering::none_cagra_sample_filter; - return cagra::search_with_filtering( - res, params, idx, queries, neighbors, distances, none_filter_type{}); + try { + using none_filter_type = cuvs::neighbors::filtering::none_sample_filter; + auto& sample_filter = dynamic_cast(sample_filter_ref); + auto sample_filter_copy = sample_filter; + return search_with_filtering( + res, params, idx, queries, neighbors, distances, sample_filter_copy); + return; + } catch (const std::bad_cast&) { + } + + try { + auto& sample_filter = + dynamic_cast&>( + sample_filter_ref); + auto sample_filter_copy = sample_filter; + return search_with_filtering( + res, params, idx, queries, neighbors, distances, sample_filter_copy); + } catch (const std::bad_cast&) { + RAFT_FAIL("Unsupported sample filter type"); + } } template diff --git a/cpp/src/neighbors/cagra_search_float.cu b/cpp/src/neighbors/cagra_search_float.cu index e981d9127..3aca84f74 100644 --- a/cpp/src/neighbors/cagra_search_float.cu +++ b/cpp/src/neighbors/cagra_search_float.cu @@ -19,15 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::cagra::search(handle, params, index, queries, neighbors, distances); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, 
queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(float, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_half.cu b/cpp/src/neighbors/cagra_search_half.cu index d80f2bc00..02be12731 100644 --- a/cpp/src/neighbors/cagra_search_half.cu +++ b/cpp/src/neighbors/cagra_search_half.cu @@ -19,15 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::cagra::search(handle, params, index, queries, neighbors, distances); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(half, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_int8.cu b/cpp/src/neighbors/cagra_search_int8.cu index b44a7507d..3442ef55f 100644 --- a/cpp/src/neighbors/cagra_search_int8.cu +++ b/cpp/src/neighbors/cagra_search_int8.cu @@ -18,15 +18,17 @@ #include namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::cagra::search(handle, params, index, queries, neighbors, distances); \ 
+#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(int8_t, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_uint8.cu b/cpp/src/neighbors/cagra_search_uint8.cu index cbb7d6652..08fe1861b 100644 --- a/cpp/src/neighbors/cagra_search_uint8.cu +++ b/cpp/src/neighbors/cagra_search_uint8.cu @@ -19,15 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::cagra::search(handle, params, index, queries, neighbors, distances); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 6dc601f32..4c15b8e14 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ 
b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -17,6 +17,7 @@ #pragma once #include "factory.cuh" +#include "sample_filter_utils.cuh" #include "search_plan.cuh" #include "search_single_cta_inst.cuh" @@ -42,48 +43,6 @@ namespace cuvs::neighbors::cagra::detail { -template -struct CagraSampleFilterWithQueryIdOffset { - const uint32_t offset; - CagraSampleFilterT filter; - - CagraSampleFilterWithQueryIdOffset(const uint32_t offset, const CagraSampleFilterT filter) - : offset(offset), filter(filter) - { - } - - _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id) - { - return filter(query_id + offset, sample_id); - } -}; - -template -struct CagraSampleFilterT_Selector { - using type = CagraSampleFilterWithQueryIdOffset; -}; -template <> -struct CagraSampleFilterT_Selector { - using type = cuvs::neighbors::filtering::none_cagra_sample_filter; -}; - -// A helper function to set a query id offset -template -inline typename CagraSampleFilterT_Selector::type set_offset( - CagraSampleFilterT filter, const uint32_t offset) -{ - typename CagraSampleFilterT_Selector::type new_filter(offset, filter); - return new_filter; -} -template <> -inline - typename CagraSampleFilterT_Selector::type - set_offset( - cuvs::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t) -{ - return filter; -} - template void search_main_core(raft::resources const& res, search_params params, diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 2f201de3b..abc907da5 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -29,7 +29,7 @@ namespace cuvs::neighbors::cagra::detail { template + typename CagraSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> class factory { public: /** diff --git a/cpp/src/neighbors/detail/cagra/sample_filter_utils.cuh b/cpp/src/neighbors/detail/cagra/sample_filter_utils.cuh new file mode 100644 index 
000000000..cd77b9b6b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/sample_filter_utils.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../sample_filter.cuh" + +#include + +namespace cuvs::neighbors::cagra::detail { + +template +struct CagraSampleFilterWithQueryIdOffset { + const uint32_t offset; + CagraSampleFilterT filter; + + CagraSampleFilterWithQueryIdOffset(const uint32_t offset, const CagraSampleFilterT filter) + : offset(offset), filter(filter) + { + } + + _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id) + { + return filter(query_id + offset, sample_id); + } +}; + +template +struct CagraSampleFilterT_Selector { + using type = CagraSampleFilterWithQueryIdOffset; +}; +template <> +struct CagraSampleFilterT_Selector { + using type = cuvs::neighbors::filtering::none_sample_filter; +}; + +// A helper function to set a query id offset +template +inline typename CagraSampleFilterT_Selector::type set_offset( + CagraSampleFilterT filter, const uint32_t offset) +{ + typename CagraSampleFilterT_Selector::type new_filter(offset, filter); + return new_filter; +} +template <> +inline typename CagraSampleFilterT_Selector::type +set_offset( + cuvs::neighbors::filtering::none_sample_filter filter, const uint32_t) +{ + return filter; +} +} // namespace cuvs::neighbors::cagra::detail diff --git 
a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py index 4e3983e3f..b05afd2c9 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -38,6 +38,9 @@ */ #include "search_multi_cta_inst.cuh" +#include "sample_filter_utils.cuh" + +#define COMMA , namespace cuvs::neighbors::cagra::detail::multi_cta_search { """ @@ -65,7 +68,10 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, cuvs::neighbors::filtering::none_cagra_sample_filter);\n" + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, cuvs::neighbors::filtering::none_sample_filter);\n" + ) + f.write( + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, CagraSampleFilterWithQueryIdOffset>);\n" ) f.write(trailer) # For pasting into CMakeLists.txt diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32.cu index fae5a9387..0ee0fa082 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32.cu @@ -23,12 +23,20 @@ * */ +#include "sample_filter_utils.cuh" #include "search_multi_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection(float, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection(float, + uint32_t, + float, + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32.cu index 
9606d510f..3bd4df172 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32.cu @@ -23,12 +23,17 @@ * */ +#include "sample_filter_utils.cuh" #include "search_multi_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(half, uint32_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(half, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32.cu index a3322c435..4e7389b4b 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32.cu @@ -23,12 +23,20 @@ * */ +#include "sample_filter_utils.cuh" #include "search_multi_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection(int8_t, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection(int8_t, + uint32_t, + float, + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 4dfc46256..9fa9d5894 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -282,7 +282,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( // Filtering if constexpr 
(!std::is_same::value) { + cuvs::neighbors::filtering::none_sample_filter>::value) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; const INDEX_T invalid_index = utils::get_max_value(); @@ -305,7 +305,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( // Post process for filtering if constexpr (!std::is_same::value) { + cuvs::neighbors::filtering::none_sample_filter>::value) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; const INDEX_T invalid_index = utils::get_max_value(); diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32.cu index 51fc6526f..ed0e0387c 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32.cu @@ -23,12 +23,20 @@ * */ +#include "sample_filter_utils.cuh" #include "search_multi_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection(uint8_t, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection(uint8_t, + uint32_t, + float, + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index 0daae17b3..9c22134a6 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -365,7 +365,7 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( } if constexpr (!std::is_same::value) { + cuvs::neighbors::filtering::none_sample_filter>::value) { if (!sample_filter(query_id, parent_index)) { parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); 
parent_distance_ptr[parent_list_index + (lds * query_id)] = @@ -779,7 +779,7 @@ struct search : search_plan_impl { // Topk hint can not be used when applying a filter uint32_t* const top_hint_ptr = - std::is_same::value + std::is_same::value ? topk_hint.data() : nullptr; // Init topk_hint @@ -878,7 +878,7 @@ struct search : search_plan_impl { auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size; if constexpr (!std::is_same::value) { + cuvs::neighbors::filtering::none_sample_filter>::value) { // Remove parent bit in search results remove_parent_bit(num_queries, result_buffer_size, diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 6ecbbc2e8..f23b96631 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -361,7 +361,7 @@ struct search_plan_impl : public search_plan_impl_base { std::to_string(hashmap_max_fill_rate) + " has been given."; } if constexpr (!std::is_same::value) { + cuvs::neighbors::filtering::none_sample_filter>::value) { if (hashmap_mode == hash_mode::SMALL) { error_message += "`SMALL` hash is not available when filtering"; } else { diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index 4693cd54d..d59201061 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -37,8 +37,11 @@ * */ +#include "sample_filter_utils.cuh" #include "search_single_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::single_cta_search { """ @@ -68,7 +71,10 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, cuvs::neighbors::filtering::none_cagra_sample_filter);\n" + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, 
cuvs::neighbors::filtering::none_sample_filter);\n" + ) + f.write( + f"instantiate_kernel_selection(\n {data_t}, {idx_t}, {distance_t}, CagraSampleFilterWithQueryIdOffset>);\n" ) f.write(trailer) diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32.cu index f8495bc01..7de479e97 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32.cu @@ -23,12 +23,20 @@ * */ +#include "sample_filter_utils.cuh" #include "search_single_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::single_cta_search { instantiate_kernel_selection(float, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection(float, + uint32_t, + float, + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32.cu index c21e6d1f4..10abe1b24 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32.cu @@ -23,12 +23,17 @@ * */ +#include "sample_filter_utils.cuh" #include "search_single_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(half, uint32_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(half, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git 
a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32.cu index 56a0d8ba9..ec0ea974c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32.cu @@ -23,12 +23,20 @@ * */ +#include "sample_filter_utils.cuh" #include "search_single_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::single_cta_search { instantiate_kernel_selection(int8_t, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection(int8_t, + uint32_t, + float, + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 21a0f6bb2..79cb6bc10 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -627,8 +627,7 @@ __device__ void search_core( // topk with bitonic sort _CLK_START(); - if (std::is_same::value || + if (std::is_same::value || *filter_flag == 0) { topk_by_bitonic_sort(result_distances_buffer, result_indices_buffer, @@ -716,7 +715,7 @@ __device__ void search_core( // Filtering if constexpr (!std::is_same::value) { + cuvs::neighbors::filtering::none_sample_filter>::value) { if (threadIdx.x == 0) { *filter_flag = 0; } __syncthreads(); @@ -742,7 +741,7 @@ __device__ void search_core( // Post process for filtering if constexpr (!std::is_same::value) { + cuvs::neighbors::filtering::none_sample_filter>::value) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; const INDEX_T invalid_index = utils::get_max_value(); diff --git 
a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32.cu index ee6427170..9df50513c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32.cu @@ -23,12 +23,20 @@ * */ +#include "sample_filter_utils.cuh" #include "search_single_cta_inst.cuh" +#define COMMA , + namespace cuvs::neighbors::cagra::detail::single_cta_search { instantiate_kernel_selection(uint8_t, uint32_t, float, - cuvs::neighbors::filtering::none_cagra_sample_filter); + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection(uint8_t, + uint32_t, + float, + CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::bitset_filter>); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 3aa1d7529..e5eeecbc9 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -718,6 +718,38 @@ void brute_force_search_filtered( return; } +template +void search(raft::resources const& res, + const cuvs::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) +{ + try { + auto& sample_filter = + dynamic_cast(sample_filter_ref); + return brute_force_search(res, idx, queries, neighbors, distances); + } catch (const std::bad_cast&) { + } + + try { + auto& sample_filter = + dynamic_cast&>( + sample_filter_ref); + if constexpr (std::is_same_v) { + RAFT_FAIL("filtered search isn't available with col_major queries yet"); + } else { + cuvs::core::bitmap_view sample_filter_view = + sample_filter.bitmap_view_; + return brute_force_search_filtered( + res, idx, queries, sample_filter_view, neighbors, 
distances); + } + } catch (const std::bad_cast&) { + RAFT_FAIL("Unsupported sample filter type"); + } +} + template cuvs::neighbors::brute_force::index build( raft::resources const& res, diff --git a/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py index e739bddd4..1fabcca8c 100644 --- a/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py +++ b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py @@ -140,28 +140,18 @@ """ search_macro = """ -#define CUVS_INST_IVF_FLAT_SEARCH(T, IdxT) \\ - void search(raft::resources const& handle, \\ - const cuvs::neighbors::ivf_flat::search_params& params, \\ - cuvs::neighbors::ivf_flat::index& index, \\ - raft::device_matrix_view queries, \\ - raft::device_matrix_view neighbors, \\ - raft::device_matrix_view distances) \\ - { \\ - cuvs::neighbors::ivf_flat::detail::search( \\ - handle, params, index, queries, neighbors, distances); \\ - } \\ - void search_with_filtering( \\ - raft::resources const& handle, \\ - const search_params& params, \\ - index& idx, \\ - raft::device_matrix_view queries, \\ - raft::device_matrix_view neighbors, \\ - raft::device_matrix_view distances, \\ - cuvs::neighbors::filtering::bitset_filter sample_filter) \\ - { \\ - cuvs::neighbors::ivf_flat::detail::search_with_filtering( \\ - handle, params, idx, queries, neighbors, distances, sample_filter); \\ +#define CUVS_INST_IVF_FLAT_SEARCH(T, IdxT) \\ + void search( \\ + raft::resources const& handle, \\ + const cuvs::neighbors::ivf_flat::search_params& params, \\ + cuvs::neighbors::ivf_flat::index& index, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbors, \\ + raft::device_matrix_view distances, \\ + const cuvs::neighbors::filtering::base_filter& sample_filter) \\ + { \\ + cuvs::neighbors::ivf_flat::detail::search( \\ + handle, params, index, queries, neighbors, distances, sample_filter); \\ } """ diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh 
b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index a4f769741..9626b2ce5 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -1304,7 +1304,7 @@ struct select_interleaved_scan_kernel { * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) * @param stream * @param sample_filter - * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to + * A filter that selects samples for a given query. Use an instance of none_sample_filter to * provide a green light for every sample. */ template diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index b7dac3ef8..032b6a8ff 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -20,13 +20,14 @@ #include "../detail/ann_utils.cuh" #include "../ivf_common.cuh" // cuvs::neighbors::detail::ivf #include "ivf_flat_interleaved_scan.cuh" // interleaved_scan -#include // none_ivf_sample_filter +#include // none_sample_filter #include // raft::neighbors::ivf_flat::index #include "../detail/ann_utils.cuh" // utils::mapping #include // is_min_close, DistanceType #include // cuvs::selection::select_k -#include // RAFT_LOG_TRACE +#include +#include // RAFT_LOG_TRACE #include #include // raft::resources #include // raft::linalg::gemm @@ -307,7 +308,7 @@ void search_impl(raft::resources const& handle, /** See raft::neighbors::ivf_flat::search docs */ template + typename IvfSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> inline void search_with_filtering(raft::resources const& handle, const search_params& params, const index& index, @@ -402,15 +403,24 @@ void search(raft::resources const& handle, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances, + 
const cuvs::neighbors::filtering::base_filter& sample_filter_ref) { - search_with_filtering(handle, - params, - idx, - queries, - neighbors, - distances, - cuvs::neighbors::filtering::none_ivf_sample_filter()); + try { + auto& sample_filter = + dynamic_cast(sample_filter_ref); + return search_with_filtering(handle, params, idx, queries, neighbors, distances, sample_filter); + } catch (const std::bad_cast&) { + } + + try { + auto& sample_filter = + dynamic_cast&>( + sample_filter_ref); + return search_with_filtering(handle, params, idx, queries, neighbors, distances, sample_filter); + } catch (const std::bad_cast&) { + RAFT_FAIL("Unsupported sample filter type"); + } } } // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cu index 93e46cbef..3f262d612 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cu @@ -35,22 +35,11 @@ namespace cuvs::neighbors::ivf_flat { cuvs::neighbors::ivf_flat::index& index, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ { \ cuvs::neighbors::ivf_flat::detail::search( \ - handle, params, index, queries, neighbors, distances); \ - } \ - void search_with_filtering( \ - raft::resources const& handle, \ - const search_params& params, \ - index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - cuvs::neighbors::filtering::bitset_filter sample_filter) \ - { \ - cuvs::neighbors::ivf_flat::detail::search_with_filtering( \ - handle, params, idx, queries, neighbors, distances, sample_filter); \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } 
CUVS_INST_IVF_FLAT_SEARCH(float, int64_t); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu index 5f75d3d48..4357afb0a 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu @@ -35,22 +35,11 @@ namespace cuvs::neighbors::ivf_flat { cuvs::neighbors::ivf_flat::index& index, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ { \ cuvs::neighbors::ivf_flat::detail::search( \ - handle, params, index, queries, neighbors, distances); \ - } \ - void search_with_filtering( \ - raft::resources const& handle, \ - const search_params& params, \ - index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - cuvs::neighbors::filtering::bitset_filter sample_filter) \ - { \ - cuvs::neighbors::ivf_flat::detail::search_with_filtering( \ - handle, params, idx, queries, neighbors, distances, sample_filter); \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_IVF_FLAT_SEARCH(int8_t, int64_t); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu index a2696dc84..8265a3e17 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu @@ -35,22 +35,11 @@ namespace cuvs::neighbors::ivf_flat { cuvs::neighbors::ivf_flat::index& index, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ { \ 
cuvs::neighbors::ivf_flat::detail::search( \ - handle, params, index, queries, neighbors, distances); \ - } \ - void search_with_filtering( \ - raft::resources const& handle, \ - const search_params& params, \ - index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - cuvs::neighbors::filtering::bitset_filter sample_filter) \ - { \ - cuvs::neighbors::ivf_flat::detail::search_with_filtering( \ - handle, params, idx, queries, neighbors, distances, sample_filter); \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_IVF_FLAT_SEARCH(uint8_t, int64_t); diff --git a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py index 9b3083c3b..a5a829967 100644 --- a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py +++ b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py @@ -67,29 +67,15 @@ search_macro = """ #define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \\ void search(raft::resources const& handle, \\ - const cuvs::neighbors::ivf_pq::search_params& params, \\ - cuvs::neighbors::ivf_pq::index& index, \\ - raft::device_matrix_view queries, \\ - raft::device_matrix_view neighbors, \\ - raft::device_matrix_view distances) \\ - { \\ - cuvs::neighbors::ivf_pq::detail::search( \\ - handle, params, index, queries, neighbors, distances); \\ - } -""" -search_with_filter_macro = """ -#define CUVS_INST_IVF_PQ_SEARCH_FILTER(T, IdxT) \\ - void search_with_filtering(raft::resources const& handle, \\ const cuvs::neighbors::ivf_pq::search_params& params, \\ cuvs::neighbors::ivf_pq::index& index, \\ raft::device_matrix_view queries, \\ raft::device_matrix_view neighbors, \\ raft::device_matrix_view distances, \\ - cuvs::neighbors::filtering::bitset_filter< \\ - uint32_t, IdxT> sample_filter) \\ + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) \\ { \\ - cuvs::neighbors::ivf_pq::detail::search_with_filtering( \\ - handle, 
params, index, queries, neighbors, distances, sample_filter); \\ + cuvs::neighbors::ivf_pq::detail::search( \\ + handle, params, index, queries, neighbors, distances, sample_filter_ref); \\ } """ @@ -104,11 +90,6 @@ definition=search_macro, name="CUVS_INST_IVF_PQ_SEARCH", ), - search_with_filter=dict( - include=search_include_macro, - definition=search_with_filter_macro, - name="CUVS_INST_IVF_PQ_SEARCH_FILTER", - ), ) for type_path, (T, IdxT) in types.items(): diff --git a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py index 4c35b2836..75373e746 100644 --- a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py +++ b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq_compute_similarity.py @@ -86,7 +86,7 @@ """ none_filter_int64 = "cuvs::neighbors::filtering::ivf_to_sample_filter" \ - "" + "" bitset_filter64 = "cuvs::neighbors::filtering::ivf_to_sample_filter" \ ">" diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_float.cu index 26312a4ae..bc73ff5a3 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_float.cu @@ -71,4 +71,4 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, float, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_false.cu index f08f1700c..2aa0bacf4 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ 
-71,4 +71,4 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_true.cu index 588c89604..d4e3fdf5c 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -71,4 +71,4 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_half.cu index 6c2f77412..02e118158 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_half.cu @@ -71,4 +71,4 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, half, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 7170e49db..cde961c72 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ 
b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -71,4 +71,4 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( half, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_true.cu index c552065ab..f1efe79f9 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -71,4 +71,4 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( half, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_half.cu index 8d9399da3..bb56fd08d 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_half.cu @@ -71,4 +71,4 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( half, half, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_float_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_float_int64_t.cu index 0f54eede7..07ee110bc 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_float_int64_t.cu +++ 
b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_float_int64_t.cu @@ -29,15 +29,17 @@ namespace cuvs::neighbors::ivf_pq { -#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::ivf_pq::detail::search(handle, params, index, queries, neighbors, distances); \ +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) \ + { \ + cuvs::neighbors::ivf_pq::detail::search( \ + handle, params, index, queries, neighbors, distances, sample_filter_ref); \ } CUVS_INST_IVF_PQ_SEARCH(float, int64_t); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu index e5556e593..cf387cb67 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu @@ -29,15 +29,17 @@ namespace cuvs::neighbors::ivf_pq { -#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::ivf_pq::detail::search(handle, params, index, queries, neighbors, distances); \ +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const 
cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) \ + { \ + cuvs::neighbors::ivf_pq::detail::search( \ + handle, params, index, queries, neighbors, distances, sample_filter_ref); \ } CUVS_INST_IVF_PQ_SEARCH(half, int64_t); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu index 297e615d2..5ec9093df 100644 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu @@ -29,15 +29,17 @@ namespace cuvs::neighbors::ivf_pq { -#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::ivf_pq::detail::search(handle, params, index, queries, neighbors, distances); \ +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) \ + { \ + cuvs::neighbors::ivf_pq::detail::search( \ + handle, params, index, queries, neighbors, distances, sample_filter_ref); \ } CUVS_INST_IVF_PQ_SEARCH(int8_t, int64_t); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu index 3cf8bfaff..d2e2f3b00 100644 --- 
a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu @@ -29,15 +29,17 @@ namespace cuvs::neighbors::ivf_pq { -#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - cuvs::neighbors::ivf_pq::detail::search(handle, params, index, queries, neighbors, distances); \ +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) \ + { \ + cuvs::neighbors::ivf_pq::detail::search( \ + handle, params, index, queries, neighbors, distances, sample_filter_ref); \ } CUVS_INST_IVF_PQ_SEARCH(uint8_t, int64_t); diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_float_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_float_int64_t.cu deleted file mode 100644 index 4e7541882..000000000 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_float_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_pq.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_pq.py - * - */ - -#include - -#include "../ivf_pq_search.cuh" - -namespace cuvs::neighbors::ivf_pq { - -#define CUVS_INST_IVF_PQ_SEARCH_FILTER(T, IdxT) \ - void search_with_filtering( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - cuvs::neighbors::filtering::bitset_filter sample_filter) \ - { \ - cuvs::neighbors::ivf_pq::detail::search_with_filtering( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ - } -CUVS_INST_IVF_PQ_SEARCH_FILTER(float, int64_t); - -#undef CUVS_INST_IVF_PQ_SEARCH_FILTER - -} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu deleted file mode 100644 index 5874fba6c..000000000 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/* - * NOTE: this file is generated by generate_ivf_pq.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_pq.py - * - */ - -#include - -#include "../ivf_pq_search.cuh" - -namespace cuvs::neighbors::ivf_pq { - -#define CUVS_INST_IVF_PQ_SEARCH_FILTER(T, IdxT) \ - void search_with_filtering( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - cuvs::neighbors::filtering::bitset_filter sample_filter) \ - { \ - cuvs::neighbors::ivf_pq::detail::search_with_filtering( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ - } -CUVS_INST_IVF_PQ_SEARCH_FILTER(half, int64_t); - -#undef CUVS_INST_IVF_PQ_SEARCH_FILTER - -} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_int8_t_int64_t.cu deleted file mode 100644 index 52b1c68e7..000000000 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_int8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/* - * NOTE: this file is generated by generate_ivf_pq.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_pq.py - * - */ - -#include - -#include "../ivf_pq_search.cuh" - -namespace cuvs::neighbors::ivf_pq { - -#define CUVS_INST_IVF_PQ_SEARCH_FILTER(T, IdxT) \ - void search_with_filtering( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - cuvs::neighbors::filtering::bitset_filter sample_filter) \ - { \ - cuvs::neighbors::ivf_pq::detail::search_with_filtering( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ - } -CUVS_INST_IVF_PQ_SEARCH_FILTER(int8_t, int64_t); - -#undef CUVS_INST_IVF_PQ_SEARCH_FILTER - -} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_uint8_t_int64_t.cu deleted file mode 100644 index e3d936155..000000000 --- a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_uint8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/* - * NOTE: this file is generated by generate_ivf_pq.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_pq.py - * - */ - -#include - -#include "../ivf_pq_search.cuh" - -namespace cuvs::neighbors::ivf_pq { - -#define CUVS_INST_IVF_PQ_SEARCH_FILTER(T, IdxT) \ - void search_with_filtering( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_pq::search_params& params, \ - cuvs::neighbors::ivf_pq::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - cuvs::neighbors::filtering::bitset_filter sample_filter) \ - { \ - cuvs::neighbors::ivf_pq::detail::search_with_filtering( \ - handle, params, index, queries, neighbors, distances, sample_filter); \ - } -CUVS_INST_IVF_PQ_SEARCH_FILTER(uint8_t, int64_t); - -#undef CUVS_INST_IVF_PQ_SEARCH_FILTER - -} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity.cuh index 48e2bf222..37612402c 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../sample_filter.cuh" // none_ivf_sample_filter +#include "../sample_filter.cuh" // none_sample_filter #include "ivf_pq_fp_8bit.cuh" // cuvs::neighbors::ivf_pq::detail::fp_8bit #include // cuvs::distance::DistanceType @@ -177,37 +177,37 @@ instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( half, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( half, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, cuvs::neighbors::filtering::ivf_to_sample_filter< - 
int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( half, half, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, half, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, float, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( float, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); + int64_t COMMA cuvs::neighbors::filtering::none_sample_filter>); instantiate_cuvs_neighbors_ivf_pq_detail_compute_similarity_select( half, cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh index 5fccbb385..8404ca1f9 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh +++ 
b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh @@ -17,7 +17,7 @@ #pragma once #include "../ivf_common.cuh" // dummy_block_sort_t -#include "../sample_filter.cuh" // none_ivf_sample_filter +#include "../sample_filter.cuh" // none_sample_filter #include // cuvs::distance::DistanceType #include // codebook_gen #include // matrix::detail::select::warpsort::warp_sort_distributed @@ -247,7 +247,7 @@ __device__ auto ivfpq_compute_score(uint32_t pq_dim, * query_kths keep the current state of the filtering - atomically updated distances to the * k-th closest neighbors for each query [n_queries]. * @param sample_filter - * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to + * A filter that selects samples for a given query. Use an instance of none_sample_filter to * provide a green light for every sample. * @param lut_scores * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. @@ -513,7 +513,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim, // The signature of the kernel defined by a minimal set of template parameters template + typename IvfSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> using compute_similarity_kernel_t = decltype(&compute_similarity_kernel); @@ -522,7 +522,7 @@ template + typename IvfSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> struct compute_similarity_kernel_config { public: static auto get(uint32_t pq_bits, uint32_t k_max) @@ -572,7 +572,7 @@ template + typename IvfSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t { @@ -617,7 +617,7 @@ struct selected { template + typename IvfSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> void compute_similarity_run(selected s, rmm::cuda_stream_view stream, uint32_t dim, @@ -682,7 +682,7 @@ void compute_similarity_run(selected s, */ template + typename 
IvfSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> auto compute_similarity_select(const cudaDeviceProp& dev_props, bool manage_local_topk, int locality_hint, diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index 5f812dc4f..e185f18dc 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -19,7 +19,7 @@ #include "../../core/nvtx.hpp" #include "../detail/ann_utils.cuh" #include "../ivf_common.cuh" -#include "../sample_filter.cuh" // none_ivf_sample_filter +#include "../sample_filter.cuh" // none_sample_filter #include "ivf_pq_compute_similarity.cuh" #include "ivf_pq_fp_8bit.cuh" @@ -592,7 +592,7 @@ constexpr uint32_t kMaxQueries = 4096; /** See raft::spatial::knn::ivf_pq::search docs */ template + typename IvfSampleFilterT = cuvs::neighbors::filtering::none_sample_filter> inline void search(raft::resources const& handle, const search_params& params, const index& index, @@ -789,14 +789,23 @@ void search(raft::resources const& handle, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) { - search_with_filtering(handle, - params, - idx, - queries, - neighbors, - distances, - cuvs::neighbors::filtering::none_ivf_sample_filter{}); + try { + auto& sample_filter = + dynamic_cast(sample_filter_ref); + return search_with_filtering(handle, params, idx, queries, neighbors, distances, sample_filter); + } catch (const std::bad_cast&) { + } + + try { + auto& sample_filter = + dynamic_cast&>( + sample_filter_ref); + return search_with_filtering(handle, params, idx, queries, neighbors, distances, sample_filter); + } catch (const std::bad_cast&) { + RAFT_FAIL("Unsupported sample filter type"); + } } } // namespace cuvs::neighbors::ivf_pq::detail diff --git 
a/cpp/src/neighbors/refine/refine_device.cuh b/cpp/src/neighbors/refine/refine_device.cuh index 5bf315ae5..6184e540b 100644 --- a/cpp/src/neighbors/refine/refine_device.cuh +++ b/cpp/src/neighbors/refine/refine_device.cuh @@ -126,7 +126,7 @@ void refine_device( 0, chunk_index.data(), cuvs::distance::is_min_close(cuvs::distance::DistanceType(metric)), - cuvs::neighbors::filtering::none_ivf_sample_filter(), + cuvs::neighbors::filtering::none_sample_filter(), neighbors_uint32, distances.data_handle(), grid_dim_x, diff --git a/cpp/src/neighbors/sample_filter.cu b/cpp/src/neighbors/sample_filter.cu index 32a0d3bfb..2da4bea4e 100644 --- a/cpp/src/neighbors/sample_filter.cu +++ b/cpp/src/neighbors/sample_filter.cu @@ -18,6 +18,11 @@ namespace cuvs::neighbors::filtering { +template struct bitmap_filter; +template struct bitmap_filter; +template struct bitmap_filter; +template struct bitmap_filter; + template struct bitset_filter; template struct bitset_filter; template struct bitset_filter; diff --git a/cpp/src/neighbors/sample_filter.cuh b/cpp/src/neighbors/sample_filter.cuh index e49d54920..258116ed3 100644 --- a/cpp/src/neighbors/sample_filter.cuh +++ b/cpp/src/neighbors/sample_filter.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -26,7 +27,7 @@ namespace cuvs::neighbors::filtering { /* A filter that filters nothing. This is the default behavior. */ -inline _RAFT_HOST_DEVICE bool none_ivf_sample_filter::operator()( +inline _RAFT_HOST_DEVICE bool none_sample_filter::operator()( // query index const uint32_t query_ix, // the current inverted list index @@ -38,7 +39,7 @@ inline _RAFT_HOST_DEVICE bool none_ivf_sample_filter::operator()( } /* A filter that filters nothing. This is the default behavior. 
*/ -inline _RAFT_HOST_DEVICE bool none_cagra_sample_filter::operator()( +inline _RAFT_HOST_DEVICE bool none_sample_filter::operator()( // query index const uint32_t query_ix, // the index of the current sample @@ -107,4 +108,20 @@ inline _RAFT_HOST_DEVICE bool bitset_filter::operator()( return bitset_view_.test(sample_ix); } +template +bitmap_filter::bitmap_filter( + const cuvs::core::bitmap_view bitmap_for_filtering) + : bitmap_view_{bitmap_for_filtering} +{ +} + +template +inline _RAFT_HOST_DEVICE bool bitmap_filter::operator()( + // query index + const uint32_t query_ix, + // the index of the current sample + const uint32_t sample_ix) const +{ + return bitmap_view_.test(query_ix, sample_ix); +} } // namespace cuvs::neighbors::filtering diff --git a/cpp/src/stats/detail/trustworthiness_score.cuh b/cpp/src/stats/detail/trustworthiness_score.cuh index f4725a2e8..4d9c3af75 100644 --- a/cpp/src/stats/detail/trustworthiness_score.cuh +++ b/cpp/src/stats/detail/trustworthiness_score.cuh @@ -108,7 +108,7 @@ void run_knn(const raft::resources& h, input_view, raft::make_device_matrix_view(indices, n, n_neighbors), raft::make_device_matrix_view(distances, n, n_neighbors), - std::nullopt); + cuvs::neighbors::filtering::none_sample_filter{}); } /** diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh index 461a202f2..c2afa4e8b 100644 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ b/cpp/test/neighbors/ann_brute_force.cuh @@ -96,8 +96,12 @@ class AnnBruteForceTest : public ::testing::TestWithParam( distances_bruteforce_dev.data(), ps.num_queries, ps.k); - brute_force::search( - handle_, idx, search_queries_view, indices_out_view, dists_out_view, std::nullopt); + brute_force::search(handle_, + idx, + search_queries_view, + indices_out_view, + dists_out_view, + cuvs::neighbors::filtering::none_sample_filter{}); raft::resource::sync_stream(handle_); @@ -110,8 +114,12 @@ class AnnBruteForceTest : public ::testing::TestWithParam= offset; + 
} +}; + /** Xorshift rondem number generator. * * See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference. @@ -663,6 +675,203 @@ class AnnCagraAddNodesTest : public ::testing::TestWithParam { rmm::device_uvector search_queries; }; +template +class AnnCagraFilterTest : public ::testing::TestWithParam { + public: + AnnCagraFilterTest() + : stream_(raft::resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + protected: + void testCagra() + { + if (ps.metric == cuvs::distance::DistanceType::InnerProduct && + ps.build_algo == graph_build_algo::NN_DESCENT) + GTEST_SKIP(); + + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_Cagra(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_Cagra(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; + cuvs::neighbors::naive_knn( + handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - test_cagra_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_cagra_sample_filter::offset), + queries_size, + stream_); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + { + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + { + cagra::index_params index_params; + index_params.metric = ps.metric; // Note: currently ony the 
cagra::index_params metric is + // not used for knn_graph building. + + switch (ps.build_algo) { + case graph_build_algo::IVF_PQ: + index_params.graph_build_params = + graph_build_params::ivf_pq_params(raft::matrix_extent(ps.n_rows, ps.dim)); + if (ps.ivf_pq_search_refine_ratio) { + std::get( + index_params.graph_build_params) + .refinement_rate = *ps.ivf_pq_search_refine_ratio; + } + break; + case graph_build_algo::NN_DESCENT: { + index_params.graph_build_params = + graph_build_params::nn_descent_params(index_params.intermediate_graph_degree); + break; + } + case graph_build_algo::AUTO: + // do nothing + break; + }; + + index_params.compression = ps.compression; + cagra::search_params search_params; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + + // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for + // k>1024 skip these tests until fixed + if (ps.k >= 1024) { GTEST_SKIP(); } + // search_params.itopk_size = ps.itopk_size; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + cagra::index index(handle_); + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + index = cagra::build(handle_, index_params, database_host_view); + } else { + index = cagra::build(handle_, index_params, database_view); + } + + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + 
distances_dev.data(), ps.n_queries, ps.k); + auto removed_indices = + raft::make_device_vector(handle_, test_cagra_sample_filter::offset); + thrust::sequence( + raft::resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + raft::resource::sync_stream(handle_); + cuvs::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.n_rows); + auto bitset_filter_obj = + cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view()); + cagra::search(handle_, + search_params, + index, + search_queries_view, + indices_out_view, + dists_out_view, + bitset_filter_obj); + raft::update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); + raft::update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + // Test search results for nodes marked as filtered + bool unacceptable_node = false; + for (int q = 0; q < ps.n_queries; q++) { + for (int i = 0; i < ps.k; i++) { + const auto n = indices_Cagra[q * ps.k + i]; + unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n); + } + } + EXPECT_FALSE(unacceptable_node); + + double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_Cagra, + distances_naive, + distances_Cagra, + ps.n_queries, + ps.k, + 0.003, + min_recall, + false)); + if (!ps.compression.has_value()) { + // Don't evaluate distances for CAGRA-Q for now as the error can be somewhat large + EXPECT_TRUE(eval_distances(handle_, + database.data(), + search_queries.data(), + indices_dev.data(), + distances_dev.data(), + ps.n_rows, + ps.dim, + ps.n_queries, + ps.k, + ps.metric, + 1.0e-4)); + } + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + search_queries.resize(ps.n_queries * ps.dim, 
stream_); + raft::random::RngState r(1234ULL); + InitDataset(handle_, database.data(), ps.n_rows, ps.dim, ps.metric, r); + InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r); + raft::resource::sync_stream(handle_); + } + + void TearDown() override + { + raft::resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnCagraInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + inline std::vector generate_inputs() { // TODO(tfeher): test MULTI_CTA kernel with search_width > 1 to allow multiple CTA per queries diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index d4e634719..ca188d132 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -26,9 +26,13 @@ TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraAddNodesTest AnnCagraAddNodesTestF_U32; TEST_P(AnnCagraAddNodesTestF_U32, AnnCagraAddNodes) { this->testCagra(); } +typedef AnnCagraFilterTest AnnCagraFilterTestF_U32; +TEST_P(AnnCagraFilterTestF_U32, AnnCagra) { this->testCagra(); } + INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest, AnnCagraAddNodesTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs)); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index 72bdee428..4aa03afd5 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -24,10 +24,13 @@ typedef AnnCagraTest AnnCagraTestI8_U32; TEST_P(AnnCagraTestI8_U32, AnnCagra) { 
this->testCagra(); } typedef AnnCagraAddNodesTest AnnCagraAddNodesTestI8_U32; TEST_P(AnnCagraAddNodesTestI8_U32, AnnCagra) { this->testCagra(); } +typedef AnnCagraFilterTest AnnCagraFilterTestI8_U32; +TEST_P(AnnCagraFilterTestI8_U32, AnnCagra) { this->testCagra(); } INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest, AnnCagraAddNodesTestI8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestI8_U32, ::testing::ValuesIn(inputs)); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index b68bfa574..b8e2a6b77 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -24,10 +24,13 @@ typedef AnnCagraTest AnnCagraTestU8_U32; TEST_P(AnnCagraTestU8_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraAddNodesTest AnnCagraAddNodesTestU8_U32; TEST_P(AnnCagraAddNodesTestU8_U32, AnnCagra) { this->testCagra(); } +typedef AnnCagraFilterTest AnnCagraFilterTestU8_U32; +TEST_P(AnnCagraFilterTestU8_U32, AnnCagra) { this->testCagra(); } INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest, AnnCagraAddNodesTestU8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestU8_U32, ::testing::ValuesIn(inputs)); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 17ec84097..8cc46b2f7 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -304,7 +304,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ivf::resize_list(handle_, lists[label], list_device_spec, list_size, 0); } - idx.recompute_internal_state(handle_); + 
ivf_flat::helpers::recompute_internal_state(handle_, &idx); using interleaved_group = raft::Pow2; @@ -466,18 +466,19 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { cuvs::core::bitset removed_indices_bitset( handle_, removed_indices.view(), ps.num_db_vecs); + auto bitset_filter_obj = + cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view()); // Search with the filter auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); - ivf_flat::search_with_filtering( - handle_, - search_params, - index, - search_queries_view, - indices_ivfflat_dev.view(), - distances_ivfflat_dev.view(), - cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + ivf_flat::search(handle_, + search_params, + index, + search_queries_view, + indices_ivfflat_dev.view(), + distances_ivfflat_dev.view(), + bitset_filter_obj); raft::update_host( distances_ivfflat.data(), distances_ivfflat_dev.data_handle(), queries_size, stream_); diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index 0ce168f5e..6a4a34516 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -21,7 +21,12 @@ namespace cuvs::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_float; -TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { this->testIVFFlat(); } +TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) +{ + this->testIVFFlat(); + this->testPacker(); + this->testFilter(); +} INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_float, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu index 15935fd88..5335b1656 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -21,7 +21,12 @@ namespace cuvs::neighbors::ivf_flat 
{ typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; -TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) { this->testIVFFlat(); } +TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) +{ + this->testIVFFlat(); + this->testPacker(); + this->testFilter(); +} INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_int8, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu index 42a8dab2e..e5573bcbc 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -21,7 +21,12 @@ namespace cuvs::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; -TEST_P(AnnIVFFlatTestF_uint8, AnnIVFFlat) { this->testIVFFlat(); } +TEST_P(AnnIVFFlatTestF_uint8, AnnIVFFlat) +{ + this->testIVFFlat(); + this->testPacker(); + this->testFilter(); +} INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_uint8, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index e6d8efc93..f02568b74 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -18,10 +18,10 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" #include "naive_knn.cuh" +#include #include #include -#include #include #include #include @@ -629,14 +629,10 @@ class ivf_pq_filter_test : public ::testing::TestWithParam { cuvs::core::bitset removed_indices_bitset( handle_, removed_indices.view(), ps.num_db_vecs); - cuvs::neighbors::ivf_pq::search_with_filtering( - handle_, - ps.search_params, - index, - query_view, - inds_view, - dists_view, - cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + auto bitset_filter_obj = + cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset.view()); + cuvs::neighbors::ivf_pq::search( + handle_, ps.search_params, index, query_view, inds_view, dists_view, bitset_filter_obj); 
raft::update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); raft::update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); diff --git a/cpp/test/neighbors/brute_force.cu b/cpp/test/neighbors/brute_force.cu index f1a05e045..8c354baa9 100644 --- a/cpp/test/neighbors/brute_force.cu +++ b/cpp/test/neighbors/brute_force.cu @@ -93,7 +93,8 @@ class KNNTest : public ::testing::TestWithParam> { auto metric = cuvs::distance::DistanceType::L2Unexpanded; auto idx = cuvs::neighbors::brute_force::build(handle, index, metric); - cuvs::neighbors::brute_force::search(handle, idx, search, indices, distances, std::nullopt); + cuvs::neighbors::brute_force::search( + handle, idx, search, indices, distances, cuvs::neighbors::filtering::none_sample_filter{}); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); @@ -401,7 +402,7 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam search_queries.data(), params_.num_queries, params_.dim), indices, distances, - std::nullopt); + cuvs::neighbors::filtering::none_sample_filter{}); } else { auto idx = cuvs::neighbors::brute_force::build( handle_, @@ -417,7 +418,7 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam search_queries.data(), params_.num_queries, params_.dim), indices, distances, - std::nullopt); + cuvs::neighbors::filtering::none_sample_filter{}); } ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(ref_indices_.data(), diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu index ae9111ea1..12b1c529e 100644 --- a/cpp/test/neighbors/brute_force_prefiltered.cu +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -502,7 +502,12 @@ class PrefilteredBruteForceTest auto out_idx = raft::make_device_matrix_view( out_idx_d.data(), params.n_queries, params.top_k); - brute_force::search(handle, dataset, queries, out_idx, out_val, 
std::make_optional(filter)); + brute_force::search(handle, + dataset, + queries, + out_idx, + out_val, + cuvs::neighbors::filtering::bitmap_filter(filter)); std::vector out_val_h(params.n_queries * params.top_k, std::numeric_limits::infinity());