Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve multi-CTA algorithm #492

Open
wants to merge 24 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6223fd2
[Improved Multi-CTA algo] Address low recall issue of multi-CTA algo …
anaruse Sep 26, 2024
8ff6991
Merge branch 'branch-24.12' into improved_multi_cta_algo
anaruse Dec 5, 2024
37e26c1
fix style
tfeher Dec 5, 2024
3665d45
Merge branch 'branch-24.12' into improved_multi_cta_algo
anaruse Dec 5, 2024
018e792
Merge branch 'branch-25.02' into improved_multi_cta_algo
achirkin Dec 9, 2024
ab1130b
Check if CAGRA search returns enough valid indices during add_nodes
achirkin Dec 9, 2024
bedd224
Resolving various issues with the new multi-CTA algorithm
anaruse Dec 20, 2024
ea8c273
Add comments in add_nodes.cuh
anaruse Dec 23, 2024
5025481
Limit tht number of warnings output
anaruse Dec 23, 2024
b61126a
Avoid invalid results in search results as much as possible
anaruse Dec 25, 2024
588bd0c
Improve the accuracy of the new multi-CTA algo by revising the usase …
anaruse Dec 30, 2024
228a1ae
Reduce the number of shared memory access
anaruse Jan 6, 2025
776f2f5
Remove unused code
anaruse Jan 8, 2025
9d262f7
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 8, 2025
192c0a9
Update cpp/src/neighbors/detail/cagra/device_common.hpp
anaruse Jan 10, 2025
d19a6c4
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 16, 2025
b5c31b3
Merge branch 'branch-25.02' into improved_multi_cta_algo
anaruse Jan 17, 2025
81e4b39
Fixed data type issues
anaruse Jan 17, 2025
cdc4bc4
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 23, 2025
dd371dc
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 24, 2025
e769ca7
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 25, 2025
ce93427
Merge branch 'branch-25.02' into improved_multi_cta_algo
achirkin Jan 27, 2025
c133c8b
Merge branch 'branch-25.02' into improved_multi_cta_algo
cjnolet Jan 29, 2025
baa3c0c
Fixed problem of infinite loop when graph degree is small
anaruse Jan 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,31 @@ void add_node_core(
raft::resource::get_cuda_stream(handle));
raft::resource::sync_stream(handle);

// Check search results
constexpr int max_warnings = 3;
int num_warnings = 0;
for (std::size_t vec_i = 0; vec_i < batch.size(); vec_i++) {
std::uint32_t invalid_edges = 0;
for (std::uint32_t i = 0; i < base_degree; i++) {
if (host_neighbor_indices(vec_i, i) >= old_size) { invalid_edges++; }
}
if (invalid_edges > 0) {
if (num_warnings < max_warnings) {
RAFT_LOG_WARN(
"Invalid edges found in search results "
"(vec_i:%lu, invalid_edges:%lu, degree:%lu, base_degree:%lu)",
(uint64_t)vec_i,
(uint64_t)invalid_edges,
(uint64_t)degree,
(uint64_t)base_degree);
}
num_warnings += 1;
}
}
if (num_warnings > max_warnings) {
RAFT_LOG_WARN("The number of queries that contain invalid search results: %d", num_warnings);
}
anaruse marked this conversation as resolved.
Show resolved Hide resolved

// Step 2: rank-based reordering
#pragma omp parallel
{
Expand All @@ -147,9 +172,16 @@ void add_node_core(
for (std::uint32_t i = 0; i < base_degree; i++) {
std::uint32_t detourable_node_count = 0;
const auto a_id = host_neighbor_indices(vec_i, i);
if (a_id >= idx.size()) {
// If the node ID is not valid, the number of detours is increased
// to a value greater than the maximum, so that the edge to that
// node is not selected as much as possible.
detourable_node_count_list[i] = std::make_pair(a_id, base_degree + 1);
anaruse marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
for (std::uint32_t j = 0; j < i; j++) {
const auto b0_id = host_neighbor_indices(vec_i, j);
assert(b0_id < idx.size());
if (b0_id >= idx.size()) { continue; }
for (std::uint32_t k = 0; k < degree; k++) {
const auto b1_id = updated_graph(b0_id, k);
if (a_id == b1_id) {
Expand All @@ -160,6 +192,7 @@ void add_node_core(
}
detourable_node_count_list[i] = std::make_pair(a_id, detourable_node_count);
}

std::sort(detourable_node_count_list.begin(),
detourable_node_count_list.end(),
[&](const std::pair<IdxT, std::size_t> a, const std::pair<IdxT, std::size_t> b) {
Expand All @@ -181,13 +214,18 @@ void add_node_core(
const auto target_new_node_id = old_size + batch.offset() + vec_i;
for (std::size_t i = 0; i < num_rev_edges; i++) {
const auto target_node_id = updated_graph(old_size + batch.offset() + vec_i, i);

if (target_node_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", target_node_id);
}
IdxT replace_id = new_size;
IdxT replace_id_j = 0;
std::size_t replace_num_incoming_edges = 0;
for (std::int32_t j = degree - 1; j >= static_cast<std::int32_t>(rev_edge_search_range);
j--) {
const auto neighbor_id = updated_graph(target_node_id, j);
const auto neighbor_id = updated_graph(target_node_id, j);
if (neighbor_id >= new_size) {
RAFT_FAIL("Invalid node ID found in updated_graph (%u)\n", neighbor_id);
}
const std::size_t num_incoming_edges = host_num_incoming_edges(neighbor_id);
if (num_incoming_edges > replace_num_incoming_edges) {
// Check duplication
Expand All @@ -206,10 +244,6 @@ void add_node_core(
replace_id_j = j;
}
}
if (replace_id >= new_size) {
std::fprintf(stderr, "Invalid rev edge index (%u)\n", replace_id);
return;
}
updated_graph(target_node_id, replace_id_j) = target_new_node_id;
rev_edges[i] = replace_id;
}
Expand All @@ -221,13 +255,15 @@ void add_node_core(
const auto rank_based_list_ptr =
updated_graph.data_handle() + (old_size + batch.offset() + vec_i) * degree;
const auto rev_edges_return_list_ptr = rev_edges.data();
while (num_add < degree) {
while ((num_add < degree) &&
((rank_base_i < degree) || (rev_edges_return_i < num_rev_edges))) {
const auto node_list_ptr =
interleave_switch == 0 ? rank_based_list_ptr : rev_edges_return_list_ptr;
auto& node_list_index = interleave_switch == 0 ? rank_base_i : rev_edges_return_i;
const auto max_node_list_index = interleave_switch == 0 ? degree : num_rev_edges;
for (; node_list_index < max_node_list_index; node_list_index++) {
const auto candidate = node_list_ptr[node_list_index];
if (candidate >= new_size) { continue; }
// Check duplication
bool dup = false;
for (std::uint32_t j = 0; j < num_add; j++) {
Expand All @@ -244,6 +280,12 @@ void add_node_core(
}
interleave_switch = 1 - interleave_switch;
}
if (num_add < degree) {
RAFT_FAIL("Number of edges is not enough (target_new_node_id:%lu, num_add:%lu, degree:%lu)",
(uint64_t)target_new_node_id,
(uint64_t)num_add,
(uint64_t)degree);
}
for (std::uint32_t i = 0; i < degree; i++) {
updated_graph(target_new_node_id, i) = temp[i];
}
Expand All @@ -259,7 +301,9 @@ void add_graph_nodes(
raft::host_matrix_view<IdxT, std::int64_t> updated_graph_view,
const cagra::extend_params& params)
{
assert(input_updated_dataset_view.extent(0) >= index.size());
if (input_updated_dataset_view.extent(0) < index.size()) {
RAFT_FAIL("Updated dataset must be not smaller than the previous index state.");
}

const std::size_t initial_dataset_size = index.size();
const std::size_t new_dataset_size = input_updated_dataset_view.extent(0);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void search_main_core(raft::resources const& res,
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DataT, IndexT, DistanceT, CagraSampleFilterT_s>> plan =
factory<DataT, IndexT, DistanceT, CagraSampleFilterT_s>::create(
res, params, dataset_desc, queries.extent(1), graph.extent(1), topk);
res, params, dataset_desc, queries.extent(1), graph.extent(0), graph.extent(1), topk);

plan->check(topk);

Expand Down
64 changes: 47 additions & 17 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(
const IndexT* __restrict__ seed_ptr, // [num_seeds]
const uint32_t num_seeds,
IndexT* __restrict__ visited_hash_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hash_ptr,
const uint32_t traversed_hash_bitlen,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand Down Expand Up @@ -145,19 +147,29 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes(

const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u);
if (valid_i && lane_id == 0) {
if (best_index_team_local != raft::upper_bound<IndexT>() &&
hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
} else {
result_distances_ptr[i] = raft::upper_bound<DistanceT>();
result_indices_ptr[i] = raft::upper_bound<IndexT>();
if (best_index_team_local != raft::upper_bound<IndexT>()) {
if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
} else if ((traversed_hash_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) {
// Deactivate this entry as it has been already used by others.
best_norm2_team_local = raft::upper_bound<DistanceT>();
best_index_team_local = raft::upper_bound<IndexT>();
}
}
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
}
}
}

template <typename IndexT, typename DistanceT, typename DATASET_DESCRIPTOR_T>
template <typename IndexT,
typename DistanceT,
typename DATASET_DESCRIPTOR_T,
int STATIC_RESULT_POSITION = 1>
RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
IndexT* __restrict__ result_child_indices_ptr,
DistanceT* __restrict__ result_child_distances_ptr,
Expand All @@ -168,13 +180,17 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const uint32_t knn_k,
// hashmap
IndexT* __restrict__ visited_hashmap_ptr,
const uint32_t hash_bitlen,
const uint32_t visited_hash_bitlen,
IndexT* __restrict__ traversed_hashmap_ptr,
const uint32_t traversed_hash_bitlen,
const IndexT* __restrict__ parent_indices,
const IndexT* __restrict__ internal_topk_list,
const uint32_t search_width)
const uint32_t search_width,
int* __restrict__ result_position = nullptr,
const int max_result_position = 0)
{
constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask<IndexT>::value;
constexpr IndexT invalid_index = raft::upper_bound<IndexT>();
constexpr IndexT invalid_index = ~static_cast<IndexT>(0);

// Read child indices of parents from knn graph and check if the distance
// computaiton is necessary.
Expand All @@ -186,11 +202,22 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) {
// Deactivate this entry as insertion into visited hash table has failed.
child_id = invalid_index;
} else if ((traversed_hashmap_ptr != nullptr) &&
hashmap::search<IndexT, 1>(
traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) {
// Deactivate this entry as this has been already used by others.
child_id = invalid_index;
}
}
result_child_indices_ptr[i] = child_id;
if (STATIC_RESULT_POSITION) {
result_child_indices_ptr[i] = child_id;
} else if (child_id != invalid_index) {
int j = atomicSub(result_position, 1) - 1;
result_child_indices_ptr[j] = child_id;
}
}
__syncthreads();

Expand All @@ -201,9 +228,11 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
const auto compute_distance = dataset_desc.compute_distance_impl;
const auto args = dataset_desc.args.load();
const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0;
const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0];
for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) {
const bool valid_i = i < num_k;
const auto child_id = valid_i ? result_child_indices_ptr[i] : invalid_index;
const auto j = i + ofst;
const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position);
const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index;

// We should be calling `dataset_desc.compute_distance(..)` here as follows:
// > const auto child_dist = dataset_desc.compute_distance(child_id, child_id != invalid_index);
Expand All @@ -213,9 +242,10 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes(
(child_id != invalid_index) ? compute_distance(args, child_id)
: (lead_lane ? raft::upper_bound<DistanceT>() : 0),
team_size_bits);
__syncwarp();

// Store the distance
if (valid_i && lead_lane) { result_child_distances_ptr[i] = child_dist; }
if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; }
}
}

Expand Down
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ class factory {
search_params const& params,
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc,
int64_t dim,
int64_t dataset_size,
int64_t graph_degree,
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, dataset_size, graph_degree, topk);
return dispatch_kernel(res, plan, dataset_desc);
}

Expand All @@ -56,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
Loading
Loading