diff --git a/build.sh b/build.sh index eb360ff32..ba71e5f93 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # cuvs build scripts @@ -24,7 +24,7 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=:CUDA::nvToolsExt>) +target_link_libraries(cuvs::cuvs INTERFACE $<$:CUDA::nvtx3>) target_compile_definitions(cuvs::cuvs INTERFACE $<$:NVTX_ENABLED>) ]=] diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 90c4218c5..c846416a4 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,7 @@ #include #include -#include -#include -#include +#include #include #include #include @@ -57,7 +55,7 @@ class RaftCagra : public ANN { using typename ANN::AnnSearchParam; struct SearchParam : public AnnSearchParam { - cuvs::neighbors::experimental::cagra::search_params p; + cuvs::neighbors::cagra::search_params p; AllocatorType graph_mem = AllocatorType::Device; AllocatorType dataset_mem = AllocatorType::Device; auto needs_dataset() const -> bool override { return true; } @@ -209,7 +207,7 @@ void RaftCagra::set_search_param(const AnnSearchParam& param) allocator_to_string(dataset_mem_).c_str()); auto mr = get_mr(dataset_mem_); - cuvs::neighbors::cagra::detail::copy_with_padding(handle_, dataset_, input_dataset_v_, mr); + raft::neighbors::cagra::detail::copy_with_padding(handle_, dataset_, input_dataset_v_, mr); index_->update_dataset(handle_, make_const_mdspan(dataset_.view())); diff --git a/cpp/bench/micro/neighbors/cagra_bench.cuh b/cpp/bench/micro/neighbors/cagra_bench.cuh index 3be664db8..0cc8c9578 100644 --- 
a/cpp/bench/micro/neighbors/cagra_bench.cuh +++ b/cpp/bench/micro/neighbors/cagra_bench.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include #include diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index d45be4aef..d57d27312 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -35,8 +35,8 @@ function(find_and_configure_raft) #----------------------------------------------------- rapids_cpm_find(raft ${PKG_VERSION} GLOBAL_TARGETS raft::raft - BUILD_EXPORT_SET cuvs-template-exports - INSTALL_EXPORT_SET cuvs-template-exports + BUILD_EXPORT_SET cuvs-exports + INSTALL_EXPORT_SET cuvs-exports COMPONENTS ${RAFT_COMPONENTS} CPM_ARGS GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git @@ -46,7 +46,7 @@ function(find_and_configure_raft) "BUILD_TESTS OFF" "BUILD_PRIMS_BENCH OFF" "BUILD_ANN_BENCH OFF" - "RAFT_NVTX ${ENABLE_NVTX}" + "RAFT_NVTX ${PKG_ENABLE_NVTX}" "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" ) endfunction() diff --git a/cpp/include/cuvs/cluster/detail/agglomerative.cuh b/cpp/include/cuvs/cluster/detail/agglomerative.cuh deleted file mode 100644 index e5f1a9ba9..000000000 --- a/cpp/include/cuvs/cluster/detail/agglomerative.cuh +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace cuvs::cluster::detail { -template -class UnionFind { - public: - value_idx next_label; - std::vector parent; - std::vector size; - - value_idx n_indices; - - UnionFind(value_idx N_) - : n_indices(2 * N_ - 1), parent(2 * N_ - 1, -1), size(2 * N_ - 1, 1), next_label(N_) - { - memset(size.data() + N_, 0, (size.size() - N_) * sizeof(value_idx)); - } - - value_idx find(value_idx n) - { - value_idx p; - p = n; - - while (parent[n] != -1) - n = parent[n]; - - // path compression - while (parent[p] != n) { - p = parent[p == -1 ? n_indices - 1 : p]; - parent[p == -1 ? n_indices - 1 : p] = n; - } - return n; - } - - void perform_union(value_idx m, value_idx n) - { - size[next_label] = size[m] + size[n]; - parent[m] = next_label; - parent[n] = next_label; - - next_label += 1; - } -}; - -/** - * Agglomerative labeling on host. This has not been found to be a bottleneck - * in the algorithm. 
A parallel version of this can be done using a parallel - * variant of Kruskal's MST algorithm - * (ref http://cucis.ece.northwestern.edu/publications/pdf/HenPat12.pdf), - * which breaks apart the sorted MST results into overlapping subsets and - * independently runs Kruskal's algorithm on each subset, merging them back - * together into a single hierarchy when complete. Unfortunately, - * this is nontrivial and the speedup wouldn't be useful until this - * becomes a bottleneck. - * - * @tparam value_idx - * @tparam value_t - * @param[in] handle the raft handle - * @param[in] rows src edges of the sorted MST - * @param[in] cols dst edges of the sorted MST - * @param[in] nnz the number of edges in the sorted MST - * @param[out] out_src parents of output - * @param[out] out_dst children of output - * @param[out] out_delta distances of output - * @param[out] out_size cluster sizes of output - */ -template -void build_dendrogram_host(raft::resources const& handle, - const value_idx* rows, - const value_idx* cols, - const value_t* data, - size_t nnz, - value_idx* children, - value_t* out_delta, - value_idx* out_size) -{ - auto stream = resource::get_cuda_stream(handle); - - value_idx n_edges = nnz; - - std::vector mst_src_h(n_edges); - std::vector mst_dst_h(n_edges); - std::vector mst_weights_h(n_edges); - - update_host(mst_src_h.data(), rows, n_edges, stream); - update_host(mst_dst_h.data(), cols, n_edges, stream); - update_host(mst_weights_h.data(), data, n_edges, stream); - - resource::sync_stream(handle, stream); - - std::vector children_h(n_edges * 2); - std::vector out_size_h(n_edges); - std::vector out_delta_h(n_edges); - - UnionFind U(nnz + 1); - - for (std::size_t i = 0; i < nnz; i++) { - value_idx a = mst_src_h[i]; - value_idx b = mst_dst_h[i]; - value_t delta = mst_weights_h[i]; - - value_idx aa = U.find(a); - value_idx bb = U.find(b); - - value_idx children_idx = i * 2; - - children_h[children_idx] = aa; - children_h[children_idx + 1] = bb; - out_delta_h[i] 
= delta; - out_size_h[i] = U.size[aa] + U.size[bb]; - - U.perform_union(aa, bb); - } - - raft::update_device(children, children_h.data(), n_edges * 2, stream); - raft::update_device(out_size, out_size_h.data(), n_edges, stream); - raft::update_device(out_delta, out_delta_h.data(), n_edges, stream); -} - -template -RAFT_KERNEL write_levels_kernel(const value_idx* children, value_idx* parents, value_idx n_vertices) -{ - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < n_vertices) { - value_idx level = tid / 2; - value_idx child = children[tid]; - parents[child] = level; - } -} - -/** - * Instead of propagating a label from roots to children, - * the children each iterate up the tree until they find - * the label of their parent. This increases the potential - * parallelism. - * @tparam value_idx - * @param children - * @param parents - * @param n_leaves - * @param labels - */ -template -RAFT_KERNEL inherit_labels(const value_idx* children, - const value_idx* levels, - std::size_t n_leaves, - value_idx* labels, - int cut_level, - value_idx n_vertices) -{ - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; - - if (tid < n_vertices) { - value_idx node = children[tid]; - value_idx cur_level = tid / 2; - - /** - * Any roots above the cut level should be ignored. 
- * Any leaves at the cut level should already be labeled - */ - if (cur_level > cut_level) return; - - value_idx cur_parent = node; - value_idx label = labels[cur_parent]; - - while (label == -1) { - cur_parent = cur_level + n_leaves; - cur_level = levels[cur_parent]; - label = labels[cur_parent]; - } - - labels[node] = label; - } -} - -template -struct init_label_roots { - init_label_roots(value_idx* labels_) : labels(labels_) {} - - template - __host__ __device__ void operator()(Tuple t) - { - labels[thrust::get<1>(t)] = thrust::get<0>(t); - } - - private: - value_idx* labels; -}; - -/** - * Cuts the dendrogram at a particular level where the number of nodes - * is equal to n_clusters, then propagates the resulting labels - * to all the children. - * - * @tparam value_idx - * @param handle - * @param labels - * @param children - * @param n_clusters - * @param n_leaves - */ -template -void extract_flattened_clusters(raft::resources const& handle, - value_idx* labels, - const value_idx* children, - size_t n_clusters, - size_t n_leaves) -{ - auto stream = resource::get_cuda_stream(handle); - auto thrust_policy = resource::get_thrust_policy(handle); - - // Handle special case where n_clusters == 1 - if (n_clusters == 1) { - thrust::fill(thrust_policy, labels, labels + n_leaves, 0); - } else { - /** - * Compute levels for each node - * - * 1. Initialize "levels" array of size n_leaves * 2 - * - * 2. For each entry in children, write parent - * out for each of the children - */ - - auto n_edges = (n_leaves - 1) * 2; - - thrust::device_ptr d_ptr = thrust::device_pointer_cast(children); - value_idx n_vertices = *(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1; - - // Prevent potential infinite loop from labeling disconnected - // connectivities graph. - RAFT_EXPECTS(n_leaves > 0, "n_leaves must be positive"); - RAFT_EXPECTS( - static_cast(n_vertices) == static_cast((n_leaves - 1) * 2), - "Multiple components found in MST or MST is invalid. 
" - "Cannot find single-linkage solution."); - - rmm::device_uvector levels(n_vertices, stream); - - value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb); - write_levels_kernel<<>>(children, levels.data(), n_vertices); - /** - * Step 1: Find label roots: - * - * 1. Copying children[children.size()-(n_clusters-1):] entries to - * separate arrayo - * 2. sort array - * 3. take first n_clusters entries - */ - - value_idx child_size = (n_clusters - 1) * 2; - rmm::device_uvector label_roots(child_size, stream); - - value_idx children_cpy_start = n_edges - child_size; - raft::copy_async(label_roots.data(), children + children_cpy_start, child_size, stream); - - thrust::sort(thrust_policy, - label_roots.data(), - label_roots.data() + (child_size), - thrust::greater()); - - rmm::device_uvector tmp_labels(n_vertices, stream); - - // Init labels to -1 - thrust::fill(thrust_policy, tmp_labels.data(), tmp_labels.data() + n_vertices, -1); - - // Write labels for cluster roots to "labels" - thrust::counting_iterator first(0); - - auto z_iter = thrust::make_zip_iterator( - thrust::make_tuple(first, label_roots.data() + (label_roots.size() - n_clusters))); - - thrust::for_each( - thrust_policy, z_iter, z_iter + n_clusters, init_label_roots(tmp_labels.data())); - - /** - * Step 2: Propagate labels by having children iterate through their parents - * 1. Initialize labels to -1 - * 2. 
For each element in levels array, propagate until parent's - * label is !=-1 - */ - value_idx cut_level = (n_edges / 2) - (n_clusters - 1); - - inherit_labels<<>>( - children, levels.data(), n_leaves, tmp_labels.data(), cut_level, n_vertices); - - // copy tmp labels to actual labels - raft::copy_async(labels, tmp_labels.data(), n_leaves, stream); - } -} - -}; // namespace cuvs::cluster::detail diff --git a/cpp/include/cuvs/cluster/detail/connectivities.cuh b/cpp/include/cuvs/cluster/detail/connectivities.cuh deleted file mode 100644 index 165058dbd..000000000 --- a/cpp/include/cuvs/cluster/detail/connectivities.cuh +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include - -namespace cuvs::cluster::detail { - -template -struct distance_graph_impl { - void run(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - cuvs::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c); -}; - -/** - * Connectivities specialization to build a knn graph - * @tparam value_idx - * @tparam value_t - */ -template -struct distance_graph_impl { - void run(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - cuvs::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) - { - auto stream = resource::get_cuda_stream(handle); - auto thrust_policy = resource::get_thrust_policy(handle); - - // Need to symmetrize knn into undirected graph - raft::sparse::COO knn_graph_coo(stream); - - raft::sparse::neighbors::knn_graph(handle, X, m, n, metric, knn_graph_coo, c); - - indices.resize(knn_graph_coo.nnz, stream); - data.resize(knn_graph_coo.nnz, stream); - - // self-loops get max distance - auto transform_in = thrust::make_zip_iterator( - thrust::make_tuple(knn_graph_coo.rows(), knn_graph_coo.cols(), knn_graph_coo.vals())); - - thrust::transform(thrust_policy, - transform_in, - transform_in + knn_graph_coo.nnz, - knn_graph_coo.vals(), - [=] __device__(const thrust::tuple& tup) { - bool self_loop = thrust::get<0>(tup) == thrust::get<1>(tup); - return (self_loop * std::numeric_limits::max()) + - (!self_loop * thrust::get<2>(tup)); - }); - - raft::sparse::convert::sorted_coo_to_csr( - knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), m + 1, stream); - - // TODO: Wouldn't need to copy here if we could compute knn - // graph directly on the device uvectors - // ref: 
https://github.com/rapidsai/raft/issues/227 - raft::copy_async(indices.data(), knn_graph_coo.cols(), knn_graph_coo.nnz, stream); - raft::copy_async(data.data(), knn_graph_coo.vals(), knn_graph_coo.nnz, stream); - } -}; - -template -RAFT_KERNEL fill_indices2(value_idx* indices, size_t m, size_t nnz) -{ - value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x; - if (tid >= nnz) return; - value_idx v = tid % m; - indices[tid] = v; -} - -/** - * Compute connected CSR of pairwise distances - * @tparam value_idx - * @tparam value_t - * @param handle - * @param X - * @param m - * @param n - * @param metric - * @param[out] indptr - * @param[out] indices - * @param[out] data - */ -template -void pairwise_distances(const raft::resources& handle, - const value_t* X, - size_t m, - size_t n, - cuvs::distance::DistanceType metric, - value_idx* indptr, - value_idx* indices, - value_t* data) -{ - auto stream = resource::get_cuda_stream(handle); - auto exec_policy = resource::get_thrust_policy(handle); - - value_idx nnz = m * m; - - value_idx blocks = raft::ceildiv(nnz, (value_idx)256); - fill_indices2<<>>(indices, m, nnz); - - thrust::sequence(exec_policy, indptr, indptr + m, 0, (int)m); - - raft::update_device(indptr + m, &nnz, 1, stream); - - // TODO: It would ultimately be nice if the MST could accept - // dense inputs directly so we don't need to double the memory - // usage to hand it a sparse array here. 
- distance::pairwise_distance(handle, X, X, data, m, m, n, metric); - // self-loops get max distance - auto transform_in = - thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data)); - - thrust::transform(exec_policy, - transform_in, - transform_in + nnz, - data, - [=] __device__(const thrust::tuple& tup) { - value_idx idx = thrust::get<0>(tup); - bool self_loop = idx % m == idx / m; - return (self_loop * std::numeric_limits::max()) + - (!self_loop * thrust::get<1>(tup)); - }); -} - -/** - * Connectivities specialization for pairwise distances - * @tparam value_idx - * @tparam value_t - */ -template -struct distance_graph_impl { - void run(const raft::resources& handle, - const value_t* X, - size_t m, - size_t n, - cuvs::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) - { - auto stream = resource::get_cuda_stream(handle); - - size_t nnz = m * m; - - indices.resize(nnz, stream); - data.resize(nnz, stream); - - pairwise_distances(handle, X, m, n, metric, indptr.data(), indices.data(), data.data()); - } -}; - -/** - * Returns a CSR connectivities graph based on the given linkage distance. 
- * @tparam value_idx - * @tparam value_t - * @tparam dist_type - * @param[in] handle raft handle - * @param[in] X dense data for which to construct connectivites - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metric to use - * @param[out] indptr indptr array of connectivities graph - * @param[out] indices column indices array of connectivities graph - * @param[out] data distances array of connectivities graph - * @param[out] c constant 'c' used for nearest neighbors-based distances - * which will guarantee k <= log(n) + c - */ -template -void get_distance_graph(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - cuvs::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) -{ - auto stream = resource::get_cuda_stream(handle); - - indptr.resize(m + 1, stream); - - distance_graph_impl dist_graph; - dist_graph.run(handle, X, m, n, metric, indptr, indices, data, c); -} - -}; // namespace cuvs::cluster::detail diff --git a/cpp/include/cuvs/cluster/detail/kmeans.cuh b/cpp/include/cuvs/cluster/detail/kmeans.cuh deleted file mode 100644 index 1ed9f4ccd..000000000 --- a/cpp/include/cuvs/cluster/detail/kmeans.cuh +++ /dev/null @@ -1,1255 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs { -namespace cluster { -namespace detail { - -// ========================================================= -// Init functions -// ========================================================= - -// Selects 'n_clusters' samples randomly from X -template -void initRandom(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids) -{ - raft::common::nvtx::range fun_scope("initRandom"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_clusters = params.n_clusters; - detail::shuffleAndGather(handle, X, centroids, n_clusters, params.rng_state.seed); -} - -/* - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. 
- * - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: while |C| < k - * 3: Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - * 4: C = C U {x} - * 5: end for - */ -template -void kmeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - raft::common::nvtx::range fun_scope("kmeansPlusPlus"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // number of seeding trials for each center (except the first) - auto n_trials = 2 + static_cast(std::ceil(log(n_clusters))); - - RAFT_LOG_DEBUG( - "Run sequential k-means++ to select %d centroids from %d input samples " - "(%d seeding trials per iterations)", - n_clusters, - n_samples, - n_trials); - - auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); - - // temporary buffers - auto indices = raft::make_device_vector(handle, n_trials); - auto centroidCandidates = raft::make_device_matrix(handle, n_trials, n_features); - auto costPerCandidate = raft::make_device_vector(handle, n_trials); - auto minClusterDistance = raft::make_device_vector(handle, n_samples); - auto distBuffer = raft::make_device_matrix(handle, n_trials, n_samples); - - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - rmm::device_scalar clusterCost(stream); - rmm::device_scalar> minClusterIndexAndDistance(stream); - - // Device and matrix views - raft::device_vector_view indices_view(indices.data_handle(), n_trials); - auto const_weights_view = - raft::make_device_vector_view(minClusterDistance.data_handle(), n_samples); - auto const_indices_view = - raft::make_device_vector_view(indices.data_handle(), n_trials); - auto const_X_view = - raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); - 
raft::device_matrix_view candidates_view( - centroidCandidates.data_handle(), n_trials, n_features); - - // L2 norm of X: ||c||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - // <<< Step-1 >>>: C <-- sample a point uniformly at random from X - auto initialCentroid = raft::make_device_matrix_view( - X.data_handle() + dis(gen) * n_features, 1, n_features); - int n_clusters_picked = 1; - - // store the chosen centroid in the buffer - raft::copy( - centroidsRawData.data_handle(), initialCentroid.data_handle(), initialCentroid.size(), stream); - - // C = initial set of centroids - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), initialCentroid.extent(0), initialCentroid.extent(1)); - // <<< End of Step-1 >>> - - // Calculate cluster distance, d^2(x, C), for all the points x in X to the nearest centroid - detail::minClusterDistanceCompute(handle, - X, - centroids, - minClusterDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - - // <<<< Step-2 >>> : while |C| < k - while (n_clusters_picked < n_clusters) { - // <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - // Choose 'n_trials' centroid candidates from X with probability proportional to the squared - // distance to the nearest existing cluster - - raft::random::discrete(handle, rng, indices_view, const_weights_view); - 
raft::matrix::gather(handle, const_X_view, const_indices_view, candidates_view); - - // Calculate pairwise distance between X and the centroid candidates - // Output - pwd [n_trials x n_samples] - auto pwd = distBuffer.view(); - detail::pairwise_distance_kmeans( - handle, centroidCandidates.view(), X, pwd, workspace, metric); - - // Update nearest cluster distance for each centroid candidate - // Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values. - // Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated - // minClusterDistance that includes candidate-i - auto minDistBuf = distBuffer.view(); - raft::linalg::matrixVectorOp(minDistBuf.data_handle(), - pwd.data_handle(), - minClusterDistance.data_handle(), - pwd.extent(1), - pwd.extent(0), - true, - true, - raft::min_op{}, - stream); - - // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using - // centroid candidate-i - raft::linalg::reduce(costPerCandidate.data_handle(), - minDistBuf.data_handle(), - minDistBuf.extent(1), - minDistBuf.extent(0), - static_cast(0), - true, - true, - stream); - - // Greedy Choice - Choose the candidate that has minimum cluster cost - // ArgMin operation below identifies the index of minimum cost in costPerCandidate - { - // Determine temporary device storage requirements - size_t temp_storage_bytes = 0; - cub::DeviceReduce::ArgMin(nullptr, - temp_storage_bytes, - costPerCandidate.data_handle(), - minClusterIndexAndDistance.data(), - costPerCandidate.extent(0), - stream); - - // Allocate temporary storage - workspace.resize(temp_storage_bytes, stream); - - // Run argmin-reduction - cub::DeviceReduce::ArgMin(workspace.data(), - temp_storage_bytes, - costPerCandidate.data_handle(), - minClusterIndexAndDistance.data(), - costPerCandidate.extent(0), - stream); - - int bestCandidateIdx = -1; - raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); - 
resource::sync_stream(handle); - /// <<< End of Step-3 >>> - - /// <<< Step-4 >>>: C = C U {x} - // Update minimum cluster distance corresponding to the chosen centroid candidate - raft::copy(minClusterDistance.data_handle(), - minDistBuf.data_handle() + bestCandidateIdx * n_samples, - n_samples, - stream); - - raft::copy(centroidsRawData.data_handle() + n_clusters_picked * n_features, - centroidCandidates.data_handle() + bestCandidateIdx * n_features, - n_features, - stream); - - ++n_clusters_picked; - /// <<< End of Step-4 >>> - } - - RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - } /// <<<< Step-5 >>> -} - -/** - * - * @tparam DataT - * @tparam IndexT - * @param handle - * @param[in] X input matrix (size n_samples, n_features) - * @param[in] weight number of samples currently assigned to each centroid - * @param[in] cur_centroids matrix of current centroids (size n_clusters, n_features) - * @param[in] l2norm_x - * @param[out] min_cluster_and_dist - * @param[out] new_centroids - * @param[out] new_weight - * @param[inout] workspace - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - - // TODO: Figure out how to best wrap iterator types in mdspan - LabelsIterator cluster_labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids, - rmm::device_uvector& workspace) -{ - auto n_clusters = centroids.extent(0); - auto n_samples = X.extent(0); - - workspace.resize(n_samples, resource::get_cuda_stream(handle)); - - // Calculates weighted sum of all the samples assigned to cluster-i and stores the - // result in new_centroids[i] - raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), - X.extent(1), - cluster_labels, - sample_weights.data_handle(), - workspace.data(), - X.extent(0), - X.extent(1), - n_clusters, - new_centroids.data_handle(), - 
resource::get_cuda_stream(handle)); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), - cluster_labels, - weight_per_cluster.data_handle(), - (IndexT)1, - (IndexT)sample_weights.extent(0), - (IndexT)n_clusters, - resource::get_cuda_stream(handle)); - - // Computes new_centroids[i] = new_centroids[i]/weight_per_cluster[i] where - // new_centroids[n_clusters x n_features] - 2D array, new_centroids[i] has sum of all the - // samples assigned to cluster-i - // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in - // cluster-i. - // Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrixVectorOp(new_centroids.data_handle(), - new_centroids.data_handle(), - weight_per_cluster.data_handle(), - new_centroids.extent(1), - new_centroids.extent(0), - true, - false, - raft::div_checkzero_op{}, - resource::get_cuda_stream(handle)); - - // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); - raft::matrix::gather_if( - const_cast(centroids.data_handle()), - static_cast(centroids.extent(1)), - static_cast(centroids.extent(0)), - itr_wt, - itr_wt, - static_cast(weight_per_cluster.size()), - new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { // predicate - // copy when the sum of weights in the cluster is 0 - return map.value == 0; - }, - raft::key_op{}, - resource::get_cuda_stream(handle)); -} - -// TODO: Resizing is needed to use mdarray instead of rmm::device_uvector -template -void kmeans_fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - raft::common::nvtx::range 
fun_scope("kmeans_fit_main"); - logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // temporary buffer to store weights per cluster, destructor releases the - // resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); - - rmm::device_scalar clusterCostD(stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - RAFT_LOG_DEBUG( - "Calling KMeans.fit with %d samples of input data and the initialized " - "cluster centers", - n_samples); - - DataT priorClusteringCost = 0; - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG( - "KMeans.fit: Iteration-%d: fitting the model using the initialized " - "cluster centers", - n_iter[0]); - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - // computes 
minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Using TransformInputIteratorT to dereference an array of - // raft::KeyValuePair and converting them to just return the Key to be used - // in reduce_rows_by_key prims - detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - update_centroids(handle, - X, - weight, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - itr, - wtInCluster.view(), - newCentroids.view(), - workspace); - - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - newCentroids.size(), - raft::sqdiff_op{}, - stream, - centroids.data_handle(), - newCentroids.data_handle()); - - DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); - - raft::copy( - centroidsRawData.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); - - bool done = false; - if (params.inertia_check) { - // calculate cluster cost phi_x(C) - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - DataT curClusteringCost = clusterCostD.value(stream); - - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 
cost from " - "centers"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - resource::sync_stream(handle, stream); - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - // calculate cluster cost phi_x(C) - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - inertia[0] = clusterCostD.value(stream); - - RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", - n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], - inertia[0]); -} - -/* - * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. 
- - * @note This is the algorithm described in - * "Scalable K-Means++", 2012, Bahman Bahmani, Benjamin Moseley, - * Andrea Vattani, Ravi Kumar, Sergei Vassilvitskii, - * https://arxiv.org/abs/1203.6402 - - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: psi = phi_X (C) - * 3: for O( log(psi) ) times do - * 4: C' = sample each point x in X independently with probability - * p_x = l * (d^2(x, C) / phi_X (C) ) - * 5: C = C U C' - * 6: end for - * 7: For x in C, set w_x to be the number of points in X closer to x than any - * other point in C - * 8: Recluster the weighted points in C into k clusters - - * TODO: Resizing is needed to use mdarray instead of rmm::device_uvector - - */ -template -void initScalableKMeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - raft::common::nvtx::range fun_scope( - "initScalableKMeansPlusPlus"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - - // <<<< Step-1 >>> : C <- sample a point uniformly at random from X - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - auto cIdx = dis(gen); - auto initialCentroid = raft::make_device_matrix_view( - X.data_handle() + cIdx * n_features, 1, n_features); - - // flag the sample that is chosen as initial centroid - std::vector h_isSampleCentroid(n_samples); - std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); - h_isSampleCentroid[cIdx] = 1; - - // device buffer to flag the sample that is chosen as initial centroid - auto isSampleCentroid = raft::make_device_vector(handle, n_samples); - - raft::copy( - 
isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); - - rmm::device_uvector centroidsBuf(initialCentroid.size(), stream); - - // reset buffer to store the chosen centroid - raft::copy(centroidsBuf.data(), initialCentroid.data_handle(), initialCentroid.size(), stream); - - auto potentialCentroids = raft::make_device_matrix_view( - centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); - // <<< End of Step-1 >>> - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - auto minClusterDistanceVec = raft::make_device_vector(handle, n_samples); - auto uniformRands = raft::make_device_vector(handle, n_samples); - rmm::device_scalar clusterCost(stream); - - // <<< Step-2 >>>: psi <- phi_X (C) - detail::minClusterDistanceCompute(handle, - X, - potentialCentroids, - minClusterDistanceVec.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // compute partial cluster cost from the samples in rank - detail::computeClusterCost(handle, - minClusterDistanceVec.view(), - workspace, - raft::make_device_scalar_view(clusterCost.data()), - raft::identity_op{}, - raft::add_op{}); - - auto psi = clusterCost.value(stream); - - // <<< End of Step-2 >>> - - // Scalable kmeans++ paper claims 8 rounds is sufficient - resource::sync_stream(handle, stream); - int niter = std::min(8, (int)ceil(log(psi))); - RAFT_LOG_DEBUG("KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter); 
- - // <<<< Step-3 >>> : for O( log(psi) ) times do - for (int iter = 0; iter < niter; ++iter) { - RAFT_LOG_DEBUG("KMeans|| - Iteration %d: # potential centroids sampled - %d", - iter, - potentialCentroids.extent(0)); - - detail::minClusterDistanceCompute(handle, - X, - potentialCentroids, - minClusterDistanceVec.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - detail::computeClusterCost(handle, - minClusterDistanceVec.view(), - workspace, - raft::make_device_scalar_view(clusterCost.data()), - raft::identity_op{}, - raft::add_op{}); - - psi = clusterCost.value(stream); - - // <<<< Step-4 >>> : Sample each point x in X independently and identify new - // potentialCentroids - raft::random::uniform( - handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1); - - detail::SamplingOp select_op(psi, - params.oversampling_factor, - n_clusters, - uniformRands.data_handle(), - isSampleCentroid.data_handle()); - - rmm::device_uvector CpRaw(0, stream); - detail::sampleCentroids(handle, - X, - minClusterDistanceVec.view(), - isSampleCentroid.view(), - select_op, - CpRaw, - workspace); - auto Cp = raft::make_device_matrix_view( - CpRaw.data(), CpRaw.size() / n_features, n_features); - /// <<<< End of Step-4 >>>> - - /// <<<< Step-5 >>> : C = C U C' - // append the data in Cp to the buffer holding the potentialCentroids - centroidsBuf.resize(centroidsBuf.size() + Cp.size(), stream); - raft::copy( - centroidsBuf.data() + centroidsBuf.size() - Cp.size(), Cp.data_handle(), Cp.size(), stream); - - IndexT tot_centroids = potentialCentroids.extent(0) + Cp.extent(0); - potentialCentroids = - raft::make_device_matrix_view(centroidsBuf.data(), tot_centroids, n_features); - /// <<<< End of Step-5 >>> - } /// <<<< Step-6 >>> - - RAFT_LOG_DEBUG("KMeans||: total # potential centroids sampled - %d", - potentialCentroids.extent(0)); - - if ((int)potentialCentroids.extent(0) > 
n_clusters) { - // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X - // temporary buffer to store the sample count per cluster, destructor - // releases the resource - auto weight = raft::make_device_vector(handle, potentialCentroids.extent(0)); - - detail::countSamplesInCluster( - handle, params, X, L2NormX.view(), potentialCentroids, workspace, weight.view()); - - // <<< end of Step-7 >>> - - // Step-8: Recluster the weighted points in C into k clusters - detail::kmeansPlusPlus( - handle, params, potentialCentroids, centroidsRawData, workspace); - - auto inertia = make_host_scalar(0); - auto n_iter = make_host_scalar(0); - KMeansParams default_params; - default_params.n_clusters = params.n_clusters; - - detail::kmeans_fit_main(handle, - default_params, - potentialCentroids, - weight.view(), - centroidsRawData, - inertia.view(), - n_iter.view(), - workspace); - - } else if ((int)potentialCentroids.extent(0) < n_clusters) { - // supplement with random - auto n_random_clusters = n_clusters - potentialCentroids.extent(0); - - RAFT_LOG_DEBUG( - "[Warning!] 
KMeans||: found fewer than %d centroids during " - "initialization (found %d centroids, remaining %d centroids will be " - "chosen randomly from input samples)", - n_clusters, - potentialCentroids.extent(0), - n_random_clusters); - - // generate `n_random_clusters` centroids - KMeansParams rand_params; - rand_params.init = KMeansParams::InitMethod::Random; - rand_params.n_clusters = n_random_clusters; - initRandom(handle, rand_params, X, centroidsRawData); - - // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - } else { - // found the required n_clusters - raft::copy(centroidsRawData.data_handle(), - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - } -} - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. It must be noted - * that the data must be in row-major format and stored in device accessible - * location. - * @param[in] n_samples Number of samples in the input X. - * @param[in] n_features Number of features or the dimensions of each - * sample. - * @param[in] sample_weight Optional weights for each observation in X. - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] Otherwise, generated centroids from the - * kmeans algorithm is stored at the address pointed by 'centroids'. - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. 
- * @param[out] n_iter Number of iterations run. - */ -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - raft::common::nvtx::range fun_scope("kmeans_fit"); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - cudaStream_t stream = resource::get_cuda_stream(handle); - // Check that parameters are valid - if (sample_weight.has_value()) - RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, - "invalid parameter (sample_weight!=n_samples)"); - RAFT_EXPECTS(n_clusters > 0, "invalid parameter (n_clusters<=0)"); - RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, - "invalid parameter (centroids.extent(0) != n_clusters)"); - RAFT_EXPECTS(centroids.extent(1) == n_features, - "invalid parameter (centroids.extent(1) != n_features)"); - - // Display a message if the batch size is smaller than n_samples but will be ignored - if (params.batch_samples < (int)n_samples && - (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_samples=%d was passed, but batch_samples=%d will be used (reason: " - "batch_samples has no impact on the memory footprint when FusedL2NN can be used)", - params.batch_samples, - (int)n_samples); - } - // Display a message if batch_centroids is set and a fusedL2NN-compatible metric is used - if (params.batch_centroids != 0 && params.batch_centroids != params.n_clusters && - (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - 
"batch_centroids=%d was passed, but batch_centroids=%d will be used (reason: " - "batch_centroids has no impact on the memory footprint when FusedL2NN can be used)", - params.batch_centroids, - params.n_clusters); - } - - logger::get(RAFT_NAME).set_level(params.verbosity); - - // Allocate memory - rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); - else - thrust::fill(raft::resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); - - // check if weights sum up to n_samples - checkWeight(handle, weight.view(), workspace); - - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); - - auto n_init = params.n_init; - if (params.init == KMeansParams::InitMethod::Array && n_init != 1) { - RAFT_LOG_DEBUG( - "Explicit initial center position passed: performing only one init in " - "k-means instead of n_init=%d", - n_init); - n_init = 1; - } - - std::mt19937 gen(params.rng_state.seed); - inertia[0] = std::numeric_limits::max(); - - for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { - KMeansParams iter_params = params; - iter_params.rng_state.seed = gen(); - - DataT iter_inertia = std::numeric_limits::max(); - IndexT n_current_iter = 0; - if (iter_params.init == KMeansParams::InitMethod::Random) { - // initializing with random samples from input dataset - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers by " - "randomly choosing from the " - "input data.", - seed_iter + 1, - n_init); - initRandom(handle, iter_params, X, centroidsRawData.view()); - } else if (iter_params.init == KMeansParams::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers using " - "k-means++ algorithm.", - seed_iter + 
1, - n_init); - if (iter_params.oversampling_factor == 0) - detail::kmeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - else - detail::initScalableKMeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - } else if (iter_params.init == KMeansParams::InitMethod::Array) { - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers from " - "the ndarray array input " - "passed to init argument.", - seed_iter + 1, - n_init); - raft::copy( - centroidsRawData.data_handle(), centroids.data_handle(), n_clusters * n_features, stream); - } else { - THROW("unknown initialization method to select initial centers"); - } - - detail::kmeans_fit_main(handle, - iter_params, - X, - weight.view(), - centroidsRawData.view(), - raft::make_host_scalar_view(&iter_inertia), - raft::make_host_scalar_view(&n_current_iter), - workspace); - if (iter_inertia < inertia[0]) { - inertia[0] = iter_inertia; - n_iter[0] = n_current_iter; - raft::copy( - centroids.data_handle(), centroidsRawData.data_handle(), n_clusters * n_features, stream); - } - RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", - seed_iter + 1, - n_init, - inertia[0], - n_iter[0]); - } - RAFT_LOG_DEBUG("KMeans.fit: async call returned (fit could still be running on the device)"); -} - -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT& inertia, - IndexT& n_iter) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - std::optional> sample_weightView = std::nullopt; - if (sample_weight) - sample_weightView = - raft::make_device_vector_view(sample_weight, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = 
raft::make_host_scalar_view(&n_iter); - - detail::kmeans_fit( - handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - raft::common::nvtx::range fun_scope("kmeans_predict"); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - cudaStream_t stream = resource::get_cuda_stream(handle); - // Check that parameters are valid - if (sample_weight.has_value()) - RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, - "invalid parameter (sample_weight!=n_samples)"); - RAFT_EXPECTS(params.n_clusters > 0, "invalid parameter (n_clusters<=0)"); - RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, - "invalid parameter (centroids.extent(0) != n_clusters)"); - RAFT_EXPECTS(centroids.extent(1) == n_features, - "invalid parameter (centroids.extent(1) != n_features)"); - - logger::get(RAFT_NAME).set_level(params.verbosity); - auto metric = params.metric; - - // Allocate memory - // Device-accessible allocation of expandable storage used as temporary buffers - rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); - else - thrust::fill(raft::resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); - - // check if weights sum up to n_samples - if (normalize_weight) checkWeight(handle, weight.view(), workspace); - - auto minClusterAndDistance = - 
raft::make_device_vector, IndexT>(handle, n_samples); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // calculate cluster cost phi_x(C) - rmm::device_scalar clusterCostD(stream); - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - labels.data_handle(), - raft::key_op{}); - - 
inertia[0] = clusterCostD.value(stream); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - bool normalize_weight, - DataT& inertia) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - std::optional> sample_weightView{std::nullopt}; - if (sample_weight) - sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - - detail::kmeans_predict(handle, - params, - XView, - sample_weightView, - centroidsView, - labelsView, - normalize_weight, - inertiaView); -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - raft::common::nvtx::range fun_scope("kmeans_fit_predict"); - if (!centroids.has_value()) { - auto n_features = X.extent(1); - auto centroids_matrix = - raft::make_device_matrix(handle, params.n_clusters, n_features); - detail::kmeans_fit( - handle, params, X, sample_weight, centroids_matrix.view(), inertia, n_iter); - detail::kmeans_predict( - handle, params, X, sample_weight, centroids_matrix.view(), labels, true, inertia); - } else { - detail::kmeans_fit( - handle, params, X, sample_weight, centroids.value(), inertia, n_iter); - detail::kmeans_predict( - handle, params, X, sample_weight, centroids.value(), labels, true, inertia); - } -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* 
sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - DataT& inertia, - IndexT& n_iter) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - std::optional> sample_weightView{std::nullopt}; - if (sample_weight) - sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - std::optional> centroidsView{std::nullopt}; - if (centroids) - centroidsView.emplace( - raft::make_device_matrix_view(centroids, params.n_clusters, n_features)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); - - detail::kmeans_fit_predict( - handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @param[in] handle The handle to the cuML library context that - * manages the CUDA resources. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * @param[out] X_new X transformed in the new space.. 
- */ -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - raft::common::nvtx::range fun_scope("kmeans_transform"); - logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // Device-accessible allocation of expandable storage used as temporary buffers - rmm::device_uvector workspace(0, stream); - auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < (IndexT)n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min(static_cast(dataBatchSize), static_cast(n_samples - dIdx)); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + n_features * dIdx, ns, n_features); - - // pairwiseDistanceView [ns x n_clusters] - auto pairwiseDistanceView = raft::make_device_matrix_view( - X_new.data_handle() + n_clusters * dIdx, ns, n_clusters); - - // calculate pairwise distance between cluster centroids and current batch - // of input dataset - pairwise_distance_kmeans( - handle, datasetView, centroids, pairwiseDistanceView, workspace, metric); - } -} - -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - auto X_newView = 
raft::make_device_matrix_view(X_new, n_samples, n_features); - - detail::kmeans_transform(handle, params, XView, centroidsView, X_newView); -} -} // namespace detail -} // namespace cluster -} // namespace cuvs diff --git a/cpp/include/cuvs/cluster/detail/kmeans_auto_find_k.cuh b/cpp/include/cuvs/cluster/detail/kmeans_auto_find_k.cuh deleted file mode 100644 index 78566bb06..000000000 --- a/cpp/include/cuvs/cluster/detail/kmeans_auto_find_k.cuh +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include - -#include - -#include - -#include - -#include -#include - -namespace cuvs::cluster::detail { - -template -void compute_dispersion(raft::resources const& handle, - raft::device_matrix_view X, - KMeansParams& params, - raft::device_matrix_view centroids_view, - raft::device_vector_view labels, - raft::device_vector_view clusterSizes, - rmm::device_uvector& workspace, - raft::host_vector_view clusterDispertionView, - raft::host_vector_view resultsView, - raft::host_scalar_view residual, - raft::host_scalar_view n_iter, - int val, - idx_t n, - idx_t d) -{ - auto centroids_const_view = - raft::make_device_matrix_view(centroids_view.data_handle(), val, d); - - idx_t* clusterSizes_ptr = clusterSizes.data_handle(); - auto cluster_sizes_view = - raft::make_device_vector_view(clusterSizes_ptr, val); - - params.n_clusters = val; - - cuvs::cluster::detail::kmeans_fit_predict( - handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter); - - detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace); - - resultsView[val] = residual[0]; - clusterDispertionView[val] = raft::stats::cluster_dispersion( - handle, centroids_const_view, cluster_sizes_view, std::nullopt, n); -} - -template -void find_k(raft::resources const& handle, - raft::device_matrix_view X, - raft::host_scalar_view best_k, - raft::host_scalar_view residual, - raft::host_scalar_view n_iter, - idx_t kmax, - idx_t kmin = 1, - idx_t maxiter = 100, - value_t tol = 1e-2) -{ - idx_t n = X.extent(0); - idx_t d = X.extent(1); - - RAFT_EXPECTS(n >= 1, "n must be >= 1"); - RAFT_EXPECTS(d >= 1, "d must be >= 1"); - RAFT_EXPECTS(kmin >= 1, "kmin must be >= 1"); - RAFT_EXPECTS(kmax <= n, "kmax must be <= number of data samples in X"); - RAFT_EXPECTS(tol >= 0, "tolerance must be >= 0"); - RAFT_EXPECTS(maxiter >= 0, "maxiter must be >= 0"); - // Allocate memory - // Device memory - - auto centroids 
= raft::make_device_matrix(handle, kmax, X.extent(1)); - auto clusterSizes = raft::make_device_vector(handle, kmax); - auto labels = raft::make_device_vector(handle, n); - - rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); - - idx_t* clusterSizes_ptr = clusterSizes.data_handle(); - - // Host memory - auto results = raft::make_host_vector(kmax + 1); - auto clusterDispersion = raft::make_host_vector(kmax + 1); - - auto clusterDispertionView = clusterDispersion.view(); - auto resultsView = results.view(); - - // Loop to find *best* k - // Perform k-means in binary search - int left = kmin; // must be at least 2 - int right = kmax; // int(floor(len(data)/2)) #assumption of clusters of size 2 at least - int mid = ((unsigned int)left + (unsigned int)right) >> 1; - int oldmid = mid; - int tests = 0; - double objective[3]; // 0= left of mid, 1= right of mid - if (left == 1) left = 2; // at least do 2 clusters - - KMeansParams params; - params.max_iter = maxiter; - params.tol = tol; - - auto centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), left, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - left, - n, - d); - - // eval right edge0 - resultsView[right] = 1e20; - while (resultsView[right] > resultsView[left] && tests < 3) { - centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), right, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - right, - n, - d); - - tests += 1; - } - - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; - while (left < right - 1) { - resultsView[mid] = 1e20; - tests = 0; - while 
(resultsView[mid] > resultsView[left] && tests < 3) { - centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), mid, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - mid, - n, - d); - - if (resultsView[mid] > resultsView[left] && (mid + 1) < right) { - mid += 1; - resultsView[mid] = 1e20; - } else if (resultsView[mid] > resultsView[left] && (mid - 1) > left) { - mid -= 1; - resultsView[mid] = 1e20; - } - tests += 1; - } - - // maximize Calinski-Harabasz Index, minimize resid/ cluster - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; - objective[2] = (n - mid) / (mid - 1) * clusterDispertionView[mid] / resultsView[mid]; - objective[0] = (objective[2] - objective[0]) / (mid - left); - objective[1] = (objective[1] - objective[2]) / (right - mid); - - if (objective[0] > 0 && objective[1] < 0) { - // our point is in the left-of-mid side - right = mid; - } else { - left = mid; - } - oldmid = mid; - mid = ((unsigned int)right + (unsigned int)left) >> 1; - } - - best_k[0] = right; - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - oldmid) / (oldmid - 1) * clusterDispertionView[oldmid] / resultsView[oldmid]; - if (objective[1] < objective[0]) { best_k[0] = left; } - - // if best_k isn't what we just ran, re-run to get correct centroids and dist data on return-> - // this saves memory - if (best_k[0] != oldmid) { - auto centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), best_k[0], d); - - params.n_clusters = best_k[0]; - cuvs::cluster::detail::kmeans_fit_predict(handle, - params, - X, - std::nullopt, - std::make_optional(centroids_view), - labels.view(), - residual, - n_iter); - } -} -} // namespace 
cuvs::cluster::detail \ No newline at end of file diff --git a/cpp/include/cuvs/cluster/detail/kmeans_balanced.cuh b/cpp/include/cuvs/cluster/detail/kmeans_balanced.cuh deleted file mode 100644 index 1b946cc1e..000000000 --- a/cpp/include/cuvs/cluster/detail/kmeans_balanced.cuh +++ /dev/null @@ -1,1097 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -namespace cuvs::cluster::detail { - -constexpr static inline float kAdjustCentersWeight = 7.0f; - -/** - * @brief Predict labels for the dataset; floating-point types only. - * - * NB: no minibatch splitting is done here, it may require large amount of temporary memory (n_rows - * * n_cluster * sizeof(MathT)). - * - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * - * @param[in] handle The raft handle. 
- * @param[in] params Structure containing the hyper-parameters - * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] - * @param[in] n_rows Number samples in the `dataset` - * @param[out] labels Output predictions [n_rows] - * @param[inout] mr (optional) Memory resource to use for temporary allocations - */ -template -inline std::enable_if_t> predict_core( - const raft::resources& handle, - const kmeans_balanced_params& params, - const MathT* centers, - IdxT n_clusters, - IdxT dim, - const MathT* dataset, - const MathT* dataset_norm, - IdxT n_rows, - LabelT* labels, - rmm::mr::device_memory_resource* mr) -{ - auto stream = resource::get_cuda_stream(handle); - switch (params.metric) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: { - auto workspace = raft::make_device_mdarray( - handle, mr, make_extents((sizeof(int)) * n_rows)); - - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( - handle, mr, make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); - - auto centroidsNorm = - raft::make_device_mdarray(handle, mr, make_extents(n_clusters)); - raft::linalg::rowNorm( - centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream); - - cuvs::distance::fusedL2NNMinReduce, IdxT>( - minClusterAndDistance.data_handle(), - dataset, - centers, - dataset_norm, - centroidsNorm.data_handle(), - n_rows, - n_clusters, - dim, - (void*)workspace.data_handle(), - (params.metric == 
cuvs::distance::DistanceType::L2Expanded) ? false : true, - false, - stream); - - // todo(lsugy): use KVP + iterator in caller. - // Copy keys to output labels - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + n_rows, - labels, - raft::compose_op, raft::key_op>()); - break; - } - case cuvs::distance::DistanceType::InnerProduct: { - // TODO: pass buffer - rmm::device_uvector distances(n_rows * n_clusters, stream, mr); - - MathT alpha = -1.0; - MathT beta = 0.0; - - linalg::gemm(handle, - true, - false, - n_clusters, - n_rows, - dim, - &alpha, - centers, - dim, - dataset, - dim, - &beta, - distances.data(), - n_clusters, - stream); - - auto distances_const_view = raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters); - auto labels_view = raft::make_device_vector_view(labels, n_rows); - raft::matrix::argmin(handle, distances_const_view, labels_view); - break; - } - default: { - RAFT_FAIL("The chosen distance metric is not supported (%d)", int(params.metric)); - } - } -} - -/** - * @brief Suggest a minibatch size for kmeans prediction. - * - * This function is used as a heuristic to split the work over a large dataset - * to reduce the size of temporary memory allocations. 
- * - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * - * @param[in] n_clusters number of clusters in kmeans clustering - * @param[in] n_rows Number of samples in the dataset - * @param[in] dim Number of features in the dataset - * @param[in] metric Distance metric - * @param[in] needs_conversion Whether the data needs to be converted to MathT - * @return A suggested minibatch size and the expected memory cost per-row (in bytes) - */ -template -constexpr auto calc_minibatch_size(IdxT n_clusters, - IdxT n_rows, - IdxT dim, - cuvs::distance::DistanceType metric, - bool needs_conversion) -> std::tuple -{ - n_clusters = std::max(1, n_clusters); - - // Estimate memory needs per row (i.e element of the batch). - size_t mem_per_row = 0; - switch (metric) { - // fusedL2NN needs a mutex and a key-value pair for each row. - case distance::DistanceType::L2Expanded: - case distance::DistanceType::L2SqrtExpanded: { - mem_per_row += sizeof(int); - mem_per_row += sizeof(raft::KeyValuePair); - } break; - // Other metrics require storing a distance matrix. - default: { - mem_per_row += sizeof(MathT) * n_clusters; - } - } - - // If we need to convert to MathT, space required for the converted batch. - if (!needs_conversion) { mem_per_row += sizeof(MathT) * dim; } - - // Heuristic: calculate the minibatch size in order to use at most 1GB of memory. - IdxT minibatch_size = (1 << 30) / mem_per_row; - minibatch_size = 64 * div_rounding_up_safe(minibatch_size, IdxT{64}); - minibatch_size = std::min(minibatch_size, n_rows); - return std::make_tuple(minibatch_size, mem_per_row); -} - -/** - * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. - * - * @note all pointers must be accessible on the device. 
- * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle. - * @param[inout] centers Pointer to the output [n_clusters, dim] - * @param[inout] cluster_sizes Number of rows in each cluster [n_clusters] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] n_rows Number of samples in the `dataset` - * @param[in] labels Output predictions [n_rows] - * @param[in] reset_counters Whether to clear the output arrays before calculating. - * When set to `false`, this function may be used to update existing centers and sizes using - * the weighted average principle. - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] mr (optional) Memory resource to use for temporary allocations on the device - */ -template -void calc_centers_and_sizes(const raft::resources& handle, - MathT* centers, - CounterT* cluster_sizes, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - const LabelT* labels, - bool reset_counters, - MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr) -{ - auto stream = resource::get_cuda_stream(handle); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } - - if (!reset_counters) { - raft::linalg::matrixVectorOp( - centers, centers, cluster_sizes, dim, n_clusters, true, false, raft::mul_op(), stream); - } - - rmm::device_uvector workspace(0, stream, mr); - - // If we reset the counters, we can compute directly the new sizes in cluster_sizes. - // If we don't reset, we compute in a temporary buffer and add in a separate step. 
- rmm::device_uvector temp_cluster_sizes(0, stream, mr); - CounterT* temp_sizes = cluster_sizes; - if (!reset_counters) { - temp_cluster_sizes.resize(n_clusters, stream); - temp_sizes = temp_cluster_sizes.data(); - } - - // Apply mapping only when the data and math types are different. - if constexpr (std::is_same_v) { - raft::linalg::reduce_rows_by_key( - dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); - } else { - // todo(lsugy): use iterator from KV output of fusedL2NN - cub::TransformInputIterator mapping_itr(dataset, mapping_op); - raft::linalg::reduce_rows_by_key( - mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); - } - - // Compute weight of each cluster - cuvs::cluster::detail::countLabels(handle, labels, temp_sizes, n_rows, n_clusters, workspace); - - // Add previous sizes if necessary - if (!reset_counters) { - raft::linalg::add(cluster_sizes, cluster_sizes, temp_sizes, n_clusters, stream); - } - - raft::linalg::matrixVectorOp(centers, - centers, - cluster_sizes, - dim, - n_clusters, - true, - false, - raft::div_checkzero_op(), - stream); -} - -/** Computes the L2 norm of the dataset, converting to MathT if necessary */ -template -void compute_norm(const raft::resources& handle, - MathT* dataset_norm, - const T* dataset, - IdxT dim, - IdxT n_rows, - MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr) -{ - raft::common::nvtx::range fun_scope("compute_norm"); - auto stream = resource::get_cuda_stream(handle); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } - rmm::device_uvector mapped_dataset(0, stream, mr); - - const MathT* dataset_ptr = nullptr; - - if (std::is_same_v) { - dataset_ptr = reinterpret_cast(dataset); - } else { - mapped_dataset.resize(n_rows * dim, stream); - - linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); - - dataset_ptr = (const MathT*)mapped_dataset.data(); - } - - 
raft::linalg::rowNorm( - dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); -} - -/** - * @brief Predict labels for the dataset. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle - * @param[in] params Structure containing the hyper-parameters - * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] n_rows Number samples in the `dataset` - * @param[out] labels Output predictions [n_rows] - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] mr (optional) memory resource to use for temporary allocations - * @param[in] dataset_norm (optional) Pre-computed norms of each row in the dataset [n_rows] - */ -template -void predict(const raft::resources& handle, - const kmeans_balanced_params& params, - const MathT* centers, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - LabelT* labels, - MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr, - const MathT* dataset_norm = nullptr) -{ - auto stream = resource::get_cuda_stream(handle); - raft::common::nvtx::range fun_scope( - "predict(%zu, %u)", static_cast(n_rows), n_clusters); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } - auto [max_minibatch_size, _mem_per_row] = - calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - rmm::device_uvector cur_dataset( - std::is_same_v ? 
0 : max_minibatch_size * dim, stream, mr); - bool need_compute_norm = - dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded); - rmm::device_uvector cur_dataset_norm( - need_compute_norm ? max_minibatch_size : 0, stream, mr); - const MathT* dataset_norm_ptr = nullptr; - auto cur_dataset_ptr = cur_dataset.data(); - for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { - IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - - if constexpr (std::is_same_v) { - cur_dataset_ptr = const_cast(dataset + offset * dim); - } else { - linalg::unaryOp( - cur_dataset_ptr, dataset + offset * dim, minibatch_size * dim, mapping_op, stream); - } - - // Compute the norm now if it hasn't been pre-computed. - if (need_compute_norm) { - compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr); - dataset_norm_ptr = cur_dataset_norm.data(); - } else if (dataset_norm != nullptr) { - dataset_norm_ptr = dataset_norm + offset; - } - - predict_core(handle, - params, - centers, - n_clusters, - dim, - cur_dataset_ptr, - dataset_norm_ptr, - minibatch_size, - labels + offset, - mr); - } -} - -template -__launch_bounds__((WarpSize * BlockDimY)) RAFT_KERNEL - adjust_centers_kernel(MathT* centers, // [n_clusters, dim] - IdxT n_clusters, - IdxT dim, - const T* dataset, // [n_rows, dim] - IdxT n_rows, - const LabelT* labels, // [n_rows] - const CounterT* cluster_sizes, // [n_clusters] - MathT threshold, - IdxT average, - IdxT seed, - IdxT* count, - MappingOpT mapping_op) -{ - IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); - if (l >= n_clusters) return; - auto csize = static_cast(cluster_sizes[l]); - // skip big clusters - if (csize > static_cast(average * threshold)) return; - - // choose a "random" i that belongs to a rather large cluster - IdxT i; - IdxT j = laneId(); - if (j == 0) { - do { - auto old = 
atomicAdd(count, IdxT{1}); - i = (seed * (old + 1)) % n_rows; - } while (static_cast(cluster_sizes[labels[i]]) < average); - } - i = raft::shfl(i, 0); - - // Adjust the center of the selected smaller cluster to gravitate towards - // a sample from the selected larger cluster. - const IdxT li = static_cast(labels[i]); - // Weight of the current center for the weighted average. - // We dump it for anomalously small clusters, but keep constant otherwise. - const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); - // Weight for the datapoint used to shift the center. - const MathT wd = 1.0; - for (; j < dim; j += raft::WarpSize) { - MathT val = 0; - val += wc * centers[j + dim * li]; - val += wd * mapping_op(dataset[j + dim * i]); - val /= wc + wd; - centers[j + dim * l] = val; - } -} - -/** - * @brief Adjust centers for clusters that have small number of entries. - * - * For each cluster, where the cluster size is not bigger than a threshold, the center is moved - * towards a data point that belongs to a large cluster. - * - * NB: if this function returns `true`, you should update the labels. - * - * NB: all pointers must be on the device side. 
- * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[inout] centers cluster centers [n_clusters, dim] - * @param[in] n_clusters number of rows in `centers` - * @param[in] dim number of columns in `centers` and `dataset` - * @param[in] dataset a host pointer to the row-major data matrix [n_rows, dim] - * @param[in] n_rows number of rows in `dataset` - * @param[in] labels a host pointer to the cluster indices [n_rows] - * @param[in] cluster_sizes number of rows in each cluster [n_clusters] - * @param[in] threshold defines a criterion for adjusting a cluster - * (cluster_sizes <= average_size * threshold) - * 0 <= threshold < 1 - * @param[in] mapping_op Mapping operation from T to MathT - * @param[in] stream CUDA stream - * @param[inout] device_memory memory resource to use for temporary allocations - * - * @return whether any of the centers has been updated (and thus, `labels` need to be recalculated). 
- */ -template -auto adjust_centers(MathT* centers, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - const LabelT* labels, - const CounterT* cluster_sizes, - MathT threshold, - MappingOpT mapping_op, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* device_memory) -> bool -{ - raft::common::nvtx::range fun_scope( - "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); - if (n_clusters == 0) { return false; } - constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, - 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, - 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, - 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; - static IdxT i = 0; - static IdxT i_primes = 0; - - bool adjusted = false; - IdxT average = n_rows / n_clusters; - IdxT ofst; - do { - i_primes = (i_primes + 1) % kPrimes.size(); - ofst = kPrimes[i_primes]; - } while (n_rows % ofst == 0); - - constexpr uint32_t kBlockDimY = 4; - const dim3 block_dim(WarpSize, kBlockDimY, 1); - const dim3 grid_dim(1, raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1); - rmm::device_scalar update_count(0, stream, device_memory); - adjust_centers_kernel<<>>(centers, - n_clusters, - dim, - dataset, - n_rows, - labels, - cluster_sizes, - threshold, - average, - ofst, - update_count.data(), - mapping_op); - adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync - - return adjusted; -} - -/** - * @brief Expectation-maximization-balancing combined in an iterative process. - * - * Note, the `cluster_centers` is assumed to be already initialized here. - * Thus, this function can be used for fine-tuning existing clusters; - * to train from scratch, use `build_clusters` function below. 
- * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle - * @param[in] params Structure containing the hyper-parameters - * @param[in] n_iters Requested number of iterations (can differ from params.n_iter!) - * @param[in] dim Dimensionality of the dataset - * @param[in] dataset Pointer to a managed row-major array [n_rows, dim] - * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] - * @param[in] n_rows Number of rows in the dataset - * @param[in] n_cluster Requested number of clusters - * @param[inout] cluster_centers Pointer to a managed row-major array [n_clusters, dim] - * @param[out] cluster_labels Pointer to a managed row-major array [n_rows] - * @param[out] cluster_sizes Pointer to a managed row-major array [n_clusters] - * @param[in] balancing_pullback - * if the cluster centers are rebalanced on this number of iterations, - * one extra iteration is performed (this could happen several times) (default should be `2`). - * In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds - * one more iteration to the main cycle. - * @param[in] balancing_threshold - * the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold` - * on a given iteration (default should be `~ 0.25`). 
- * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] device_memory - * A memory resource for device allocations (makes sense to provide a memory pool here) - */ -template -void balancing_em_iters(const raft::resources& handle, - const kmeans_balanced_params& params, - uint32_t n_iters, - IdxT dim, - const T* dataset, - const MathT* dataset_norm, - IdxT n_rows, - IdxT n_clusters, - MathT* cluster_centers, - LabelT* cluster_labels, - CounterT* cluster_sizes, - uint32_t balancing_pullback, - MathT balancing_threshold, - MappingOpT mapping_op, - rmm::mr::device_memory_resource* device_memory) -{ - auto stream = resource::get_cuda_stream(handle); - uint32_t balancing_counter = balancing_pullback; - for (uint32_t iter = 0; iter < n_iters; iter++) { - // Balancing step - move the centers around to equalize cluster sizes - // (but not on the first iteration) - if (iter > 0 && adjust_centers(cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - cluster_sizes, - balancing_threshold, - mapping_op, - stream, - device_memory)) { - if (balancing_counter++ >= balancing_pullback) { - balancing_counter -= balancing_pullback; - n_iters++; - } - } - switch (params.metric) { - // For some metrics, cluster calculation and adjustment tends to favor zero center vectors. - // To avoid converging to zero, we normalize the center vectors on every iteration. 
- case cuvs::distance::DistanceType::InnerProduct: - case cuvs::distance::DistanceType::CosineExpanded: - case cuvs::distance::DistanceType::CorrelationExpanded: { - auto clusters_in_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); - auto clusters_out_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); - raft::linalg::row_normalize( - handle, clusters_in_view, clusters_out_view, raft::linalg::L2Norm); - break; - } - default: break; - } - // E: Expectation step - predict labels - predict(handle, - params, - cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - mapping_op, - device_memory, - dataset_norm); - // M: Maximization step - calculate optimal cluster centers - calc_centers_and_sizes(handle, - cluster_centers, - cluster_sizes, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - true, - mapping_op, - device_memory); - } -} - -/** Randomly initialize cluster centers and then call `balancing_em_iters`. */ -template -void build_clusters(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset, - IdxT n_rows, - IdxT n_clusters, - MathT* cluster_centers, - LabelT* cluster_labels, - CounterT* cluster_sizes, - MappingOpT mapping_op, - rmm::mr::device_memory_resource* device_memory, - const MathT* dataset_norm = nullptr) -{ - auto stream = resource::get_cuda_stream(handle); - - // "randomly" initialize labels - auto labels_view = raft::make_device_vector_view(cluster_labels, n_rows); - linalg::map_offset( - handle, - labels_view, - raft::compose_op(raft::cast_op(), raft::mod_const_op(n_clusters))); - - // update centers to match the initialized labels. 
- calc_centers_and_sizes(handle, - cluster_centers, - cluster_sizes, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - true, - mapping_op, - device_memory); - - // run EM - balancing_em_iters(handle, - params, - params.n_iters, - dim, - dataset, - dataset_norm, - n_rows, - n_clusters, - cluster_centers, - cluster_labels, - cluster_sizes, - 2, - MathT{0.25}, - mapping_op, - device_memory); -} - -/** Calculate how many fine clusters should belong to each mesocluster. */ -template -inline auto arrange_fine_clusters(IdxT n_clusters, - IdxT n_mesoclusters, - IdxT n_rows, - const CounterT* mesocluster_sizes) -{ - std::vector fine_clusters_nums(n_mesoclusters); - std::vector fine_clusters_csum(n_mesoclusters + 1); - fine_clusters_csum[0] = 0; - - IdxT n_lists_rem = n_clusters; - IdxT n_nonempty_ms_rem = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - n_nonempty_ms_rem += mesocluster_sizes[i] > CounterT{0} ? 1 : 0; - } - IdxT n_rows_rem = n_rows; - CounterT mesocluster_size_sum = 0; - CounterT mesocluster_size_max = 0; - IdxT fine_clusters_nums_max = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - if (i < n_mesoclusters - 1) { - // Although the algorithm is meant to produce balanced clusters, when something - // goes wrong, we may get empty clusters (e.g. during development/debugging). - // The code below ensures a proportional arrangement of fine cluster numbers - // per mesocluster, even if some clusters are empty. 
- if (mesocluster_sizes[i] == 0) { - fine_clusters_nums[i] = 0; - } else { - n_nonempty_ms_rem--; - auto s = static_cast( - static_cast(n_lists_rem * mesocluster_sizes[i]) / n_rows_rem + .5); - s = std::min(s, n_lists_rem - n_nonempty_ms_rem); - fine_clusters_nums[i] = std::max(s, IdxT{1}); - } - } else { - fine_clusters_nums[i] = n_lists_rem; - } - n_lists_rem -= fine_clusters_nums[i]; - n_rows_rem -= mesocluster_sizes[i]; - mesocluster_size_max = max(mesocluster_size_max, mesocluster_sizes[i]); - mesocluster_size_sum += mesocluster_sizes[i]; - fine_clusters_nums_max = max(fine_clusters_nums_max, fine_clusters_nums[i]); - fine_clusters_csum[i + 1] = fine_clusters_csum[i] + fine_clusters_nums[i]; - } - - RAFT_EXPECTS(static_cast(mesocluster_size_sum) == n_rows, - "mesocluster sizes do not add up (%zu) to the total trainset size (%zu)", - static_cast(mesocluster_size_sum), - static_cast(n_rows)); - RAFT_EXPECTS(fine_clusters_csum[n_mesoclusters] == n_clusters, - "fine cluster numbers do not add up (%zu) to the total number of clusters (%zu)", - static_cast(fine_clusters_csum[n_mesoclusters]), - static_cast(n_clusters)); - - return std::make_tuple(static_cast(mesocluster_size_max), - fine_clusters_nums_max, - std::move(fine_clusters_nums), - std::move(fine_clusters_csum)); -} - -/** - * Given the (coarse) mesoclusters and the distribution of fine clusters within them, - * build the fine clusters. - * - * Processing one mesocluster at a time: - * 1. Copy mesocluster data into a separate buffer - * 2. Predict fine cluster - * 3. Refince the fine cluster centers - * - * As a result, the fine clusters are what is returned by `build_hierarchical`; - * this function returns the total number of fine clusters, which can be checked to be - * the same as the requested number of clusters. 
- * - * Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training; - * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data - * is ignored and a warning is reported. - */ -template -auto build_fine_clusters(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset_mptr, - const MathT* dataset_norm_mptr, - const LabelT* labels_mptr, - IdxT n_rows, - const IdxT* fine_clusters_nums, - const IdxT* fine_clusters_csum, - const CounterT* mesocluster_sizes, - IdxT n_mesoclusters, - IdxT mesocluster_size_max, - IdxT fine_clusters_nums_max, - MathT* cluster_centers, - MappingOpT mapping_op, - rmm::mr::device_memory_resource* managed_memory, - rmm::mr::device_memory_resource* device_memory) -> IdxT -{ - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); - rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); - rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); - auto mc_trainset_ids = mc_trainset_ids_buf.data(); - auto mc_trainset = mc_trainset_buf.data(); - auto mc_trainset_norm = mc_trainset_norm_buf.data(); - - // label (cluster ID) of each vector - rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); - - rmm::device_uvector mc_trainset_ccenters( - fine_clusters_nums_max * dim, stream, device_memory); - // number of vectors in each cluster - rmm::device_uvector mc_trainset_csizes_tmp( - fine_clusters_nums_max, stream, device_memory); - - // Training clusters in each meso-cluster - IdxT n_clusters_done = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - IdxT k = 0; - for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { - if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } - } - if (k != static_cast(mesocluster_sizes[i])) - RAFT_LOG_WARN("Incorrect 
mesocluster size at %d. %zu vs %zu", - static_cast(i), - static_cast(k), - static_cast(mesocluster_sizes[i])); - if (k == 0) { - RAFT_LOG_DEBUG("Empty cluster %d", i); - RAFT_EXPECTS(fine_clusters_nums[i] == 0, - "Number of fine clusters must be zero for the empty mesocluster (got %d)", - static_cast(fine_clusters_nums[i])); - continue; - } else { - RAFT_EXPECTS(fine_clusters_nums[i] > 0, - "Number of fine clusters must be non-zero for a non-empty mesocluster"); - } - - cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); - raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); - if (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - thrust::gather(raft::resource::get_thrust_policy(handle), - mc_trainset_ids, - mc_trainset_ids + k, - dataset_norm_mptr, - mc_trainset_norm); - } - - build_clusters(handle, - params, - dim, - mc_trainset, - k, - fine_clusters_nums[i], - mc_trainset_ccenters.data(), - mc_trainset_labels.data(), - mc_trainset_csizes_tmp.data(), - mapping_op, - device_memory, - mc_trainset_norm); - - raft::copy(cluster_centers + (dim * fine_clusters_csum[i]), - mc_trainset_ccenters.data(), - fine_clusters_nums[i] * dim, - stream); - resource::sync_stream(handle, stream); - n_clusters_done += fine_clusters_nums[i]; - } - return n_clusters_done; -} - -/** - * @brief Hierarchical balanced k-means - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle. 
- * @param[in] params Structure containing the hyper-parameters - * @param dim number of columns in `centers` and `dataset` - * @param[in] dataset a device pointer to the source dataset [n_rows, dim] - * @param n_rows number of rows in the input - * @param[out] cluster_centers a device pointer to the found cluster centers [n_cluster, dim] - * @param n_cluster - * @param metric the distance type - * @param mapping_op Mapping operation from T to MathT - * @param stream - */ -template -void build_hierarchical(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset, - IdxT n_rows, - MathT* cluster_centers, - IdxT n_clusters, - MappingOpT mapping_op) -{ - auto stream = resource::get_cuda_stream(handle); - using LabelT = uint32_t; - - raft::common::nvtx::range fun_scope( - "build_hierarchical(%zu, %u)", static_cast(n_rows), n_clusters); - - IdxT n_mesoclusters = std::min(n_clusters, static_cast(std::sqrt(n_clusters) + 0.5)); - RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters); - - rmm::mr::managed_memory_resource managed_memory; - rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle); - auto [max_minibatch_size, mem_per_row] = - calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - auto pool_guard = - raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size)); - if (pool_guard) { - RAFT_LOG_DEBUG("build_hierarchical: using pool memory resource with initial size %zu bytes", - mem_per_row * size_t(max_minibatch_size)); - } - - // Precompute the L2 norm of the dataset if relevant. 
- const MathT* dataset_norm = nullptr; - rmm::device_uvector dataset_norm_buf(0, stream, device_memory); - if (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - dataset_norm_buf.resize(n_rows, stream); - for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { - IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - compute_norm(handle, - dataset_norm_buf.data() + offset, - dataset + dim * offset, - dim, - minibatch_size, - mapping_op, - device_memory); - } - dataset_norm = (const MathT*)dataset_norm_buf.data(); - } - - /* Temporary workaround to cub::DeviceHistogram not supporting any type that isn't natively - * supported by atomicAdd: find a supported CounterT based on the IdxT. */ - typedef typename std::conditional_t - CounterT; - - // build coarse clusters (mesoclusters) - rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); - rmm::device_uvector mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory); - { - rmm::device_uvector mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory); - build_clusters(handle, - params, - dim, - dataset, - n_rows, - n_mesoclusters, - mesocluster_centers_buf.data(), - mesocluster_labels_buf.data(), - mesocluster_sizes_buf.data(), - mapping_op, - device_memory, - dataset_norm); - } - - auto mesocluster_sizes = mesocluster_sizes_buf.data(); - auto mesocluster_labels = mesocluster_labels_buf.data(); - - resource::sync_stream(handle, stream); - - // build fine clusters - auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = - arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - - const IdxT mesocluster_size_max_balanced = div_rounding_up_safe( - 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu)); - if (mesocluster_size_max > mesocluster_size_max_balanced) { - RAFT_LOG_WARN( - "build_hierarchical: built 
unbalanced mesoclusters (max_mesocluster_size == %u > %u). " - "At most %u points will be used for training within each mesocluster. " - "Consider increasing the number of training iterations `n_iters`.", - mesocluster_size_max, - mesocluster_size_max_balanced, - mesocluster_size_max_balanced); - RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); - RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); - mesocluster_size_max = mesocluster_size_max_balanced; - } - - auto n_clusters_done = build_fine_clusters(handle, - params, - dim, - dataset, - dataset_norm, - mesocluster_labels, - n_rows, - fine_clusters_nums.data(), - fine_clusters_csum.data(), - mesocluster_sizes, - n_mesoclusters, - mesocluster_size_max, - fine_clusters_nums_max, - cluster_centers, - mapping_op, - &managed_memory, - device_memory); - RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); - - rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); - rmm::device_uvector labels(n_rows, stream, device_memory); - - // Fine-tuning k-means for all clusters - // - // (*) Since the likely cluster centroids have been calculated hierarchically already, the number - // of iterations for fine-tuning kmeans for whole clusters should be reduced. However, there is a - // possibility that the clusters could be unbalanced here, in which case the actual number of - // iterations would be increased. 
- // - balancing_em_iters(handle, - params, - std::max(params.n_iters / 10, 2), - dim, - dataset, - dataset_norm, - n_rows, - n_clusters, - cluster_centers, - labels.data(), - cluster_sizes.data(), - 5, - MathT{0.2}, - mapping_op, - device_memory); -} - -} // namespace cuvs::cluster::detail diff --git a/cpp/include/cuvs/cluster/detail/kmeans_common.cuh b/cpp/include/cuvs/cluster/detail/kmeans_common.cuh deleted file mode 100644 index d4f6a43a2..000000000 --- a/cpp/include/cuvs/cluster/detail/kmeans_common.cuh +++ /dev/null @@ -1,663 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace cuvs { -namespace cluster { -namespace detail { - -template -struct SamplingOp { - DataT* rnd; - uint8_t* flag; - DataT cluster_cost; - double oversampling_factor; - IndexT n_clusters; - - CUB_RUNTIME_FUNCTION __forceinline__ - SamplingOp(DataT c, double l, IndexT k, DataT* rand, uint8_t* ptr) - : cluster_cost(c), oversampling_factor(l), n_clusters(k), rnd(rand), flag(ptr) - { - } - - __host__ __device__ __forceinline__ bool operator()( - const raft::KeyValuePair& a) const - { - DataT prob_threshold = (DataT)rnd[a.key]; - - DataT prob_x = ((oversampling_factor * n_clusters * a.value) / cluster_cost); - - return !flag[a.key] && (prob_x > prob_threshold); - } -}; - -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -// Computes the intensity histogram from a sequence of labels -template -void countLabels(raft::resources const& handle, - SampleIteratorT labels, - CounterT* count, - IndexT n_samples, - IndexT n_clusters, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - - // CUB::DeviceHistogram requires a signed index type - typedef typename std::make_signed_t CubIndexT; - - CubIndexT num_levels = n_clusters + 1; - CubIndexT lower_level = 0; - CubIndexT upper_level = n_clusters; - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - static_cast(n_samples), - stream)); - - workspace.resize(temp_storage_bytes, stream); - - 
RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(), - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - static_cast(n_samples), - stream)); -} - -template -void checkWeight(raft::resources const& handle, - raft::device_vector_view weight, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto wt_aggr = raft::make_device_scalar(handle, 0); - auto n_samples = weight.extent(0); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(), - temp_storage_bytes, - weight.data_handle(), - wt_aggr.data_handle(), - n_samples, - stream)); - DataT wt_sum = 0; - raft::copy(&wt_sum, wt_aggr.data_handle(), 1, stream); - resource::sync_stream(handle, stream); - - if (wt_sum != n_samples) { - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %d samples", - n_samples); - - auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::unaryOp(weight.data_handle(), - weight.data_handle(), - n_samples, - raft::mul_const_op{scale}, - stream); - } -} - -template -IndexT getDataBatchSize(int batch_samples, IndexT n_samples) -{ - auto minVal = std::min(static_cast(batch_samples), n_samples); - return (minVal == 0) ? n_samples : minVal; -} - -template -IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters) -{ - auto minVal = std::min(static_cast(batch_centroids), n_local_clusters); - return (minVal == 0) ? 
n_local_clusters : minVal; -} - -template -void computeClusterCost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - MainOpT main_op, - ReductionOpT reduction_op) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - - cub::TransformInputIterator itr(minClusterDistance.data_handle(), - main_op); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, - temp_storage_bytes, - itr, - clusterCost.data_handle(), - minClusterDistance.size(), - reduction_op, - OutputT(), - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), - temp_storage_bytes, - itr, - clusterCost.data_handle(), - minClusterDistance.size(), - reduction_op, - OutputT(), - stream)); -} - -template -void sampleCentroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_local_samples = X.extent(0); - auto n_features = X.extent(1); - - auto nSelected = raft::make_device_scalar(handle, 0); - cub::ArgIndexInputIterator ip_itr(minClusterDistance.data_handle()); - auto sampledMinClusterDistance = - raft::make_device_vector, IndexT>(handle, n_local_samples); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr, - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data_handle(), - nSelected.data_handle(), - n_local_samples, - select_op, - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceSelect::If(workspace.data(), - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data_handle(), - nSelected.data_handle(), - n_local_samples, - select_op, - stream)); - - IndexT 
nPtsSampledInRank = 0; - raft::copy(&nPtsSampledInRank, nSelected.data_handle(), 1, stream); - resource::sync_stream(handle, stream); - - uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); - thrust::for_each_n(raft::resource::get_thrust_policy(handle), - sampledMinClusterDistance.data_handle(), - nPtsSampledInRank, - [=] __device__(raft::KeyValuePair val) { - rawPtr_isSampleCentroid[val.key] = 1; - }); - - inRankCp.resize(nPtsSampledInRank * n_features, stream); - - raft::matrix::gather((DataT*)X.data_handle(), - X.extent(1), - X.extent(0), - sampledMinClusterDistance.data_handle(), - nPtsSampledInRank, - inRankCp.data(), - raft::key_op{}, - stream); -} - -// calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', -// result will be stored in 'pairwiseDistance[n x k]' -template -void pairwise_distance_kmeans(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view pairwiseDistance, - rmm::device_uvector& workspace, - cuvs::distance::DistanceType metric) -{ - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - ASSERT(X.extent(1) == centroids.extent(1), - "# features in dataset and centroids are different (must be same)"); - - cuvs::distance::pairwise_distance(handle, - X.data_handle(), - centroids.data_handle(), - pairwiseDistance.data_handle(), - n_samples, - n_clusters, - n_features, - workspace, - metric); -} - -// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores -// in 'out' does not modify the input -template -void shuffleAndGather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = in.extent(0); - auto n_features = in.extent(1); - - auto indices = raft::make_device_vector(handle, n_samples); - - // 
shuffle indices on device - raft::random::permute(indices.data_handle(), - nullptr, - nullptr, - (IndexT)in.extent(1), - (IndexT)in.extent(0), - true, - stream); - - raft::matrix::gather((DataT*)in.data_handle(), - in.extent(1), - in.extent(0), - indices.data_handle(), - static_cast(n_samples_to_gather), - out.data_handle(), - stream); -} - -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' -template -void minClusterAndDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - - if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - raft::linalg::L2Norm, - true, - stream); - } else { - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. 
ref https://github.com/rapidsai/raft/issues/930 - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - - thrust::fill(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); - - // tile over the input dataset - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + (dIdx * n_features), ns, n_features); - - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(int)) * ns, stream); - - // todo(lsugy): remove cIdx - cuvs::distance::fusedL2NNMinReduce, IndexT>( - minClusterAndDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - stream); - } else { - // tile over the centroids - for (IndexT 
cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + (cIdx * n_features), nc, n_features); - - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = - raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - - // argmin reduction returning pair - // calculates the closest centroid and the distance to the closest - // centroid - raft::linalg::coalescedReduction( - minClusterAndDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - initial_value, - stream, - true, - [=] __device__(const DataT val, const IndexT i) { - raft::KeyValuePair pair; - pair.key = cIdx + i; - pair.value = val; - return pair; - }, - raft::argmin_op{}, - raft::identity_op{}); - } - } - } -} - -template -void minClusterDistanceCompute(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == 
cuvs::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - - if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - raft::linalg::L2Norm, - true, - stream); - } else { - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - - thrust::fill(raft::resource::get_thrust_policy(handle), - minClusterDistance.data_handle(), - minClusterDistance.data_handle() + minClusterDistance.size(), - std::numeric_limits::max()); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + dIdx * n_features, ns, n_features); - - // minClusterDistanceView [ns x n_clusters] - auto minClusterDistanceView = - raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(IndexT)) * ns, stream); - - 
cuvs::distance::fusedL2NNMinReduce( - minClusterDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - stream); - } else { - // tile over the centroids - for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + cIdx * n_features, nc, n_features); - - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = - raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - - raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - std::numeric_limits::max(), - stream, - true, - raft::identity_op{}, - raft::min_op{}, - raft::identity_op{}); - } - } - } -} - -template -void countSamplesInCluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - // stores (key, value) pair corresponding to each 
sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store distance matrix, destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - detail::minClusterAndDistanceCompute(handle, - X, - (raft::device_matrix_view)centroids, - minClusterAndDistance.view(), - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Using TransformInputIteratorT to dereference an array of raft::KeyValuePair - // and converting them to just return the Key to be used in reduce_rows_by_key - // prims - detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - // count # of samples in each cluster - countLabels(handle, - itr, - sampleCountInCluster.data_handle(), - (IndexT)n_samples, - (IndexT)n_clusters, - workspace); -} -} // namespace detail -} // namespace cluster -} // namespace cuvs diff --git a/cpp/include/cuvs/cluster/detail/mst.cuh b/cpp/include/cuvs/cluster/detail/mst.cuh deleted file mode 100644 index 6d304d64c..000000000 --- a/cpp/include/cuvs/cluster/detail/mst.cuh +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include - -namespace cuvs::cluster::detail { - -template -void merge_msts(sparse::solver::Graph_COO& coo1, - sparse::solver::Graph_COO& coo2, - cudaStream_t stream) -{ - /** Add edges to existing mst **/ - int final_nnz = coo2.n_edges + coo1.n_edges; - - coo1.src.resize(final_nnz, stream); - coo1.dst.resize(final_nnz, stream); - coo1.weights.resize(final_nnz, stream); - - /** - * Construct final edge list - */ - raft::copy_async(coo1.src.data() + coo1.n_edges, coo2.src.data(), coo2.n_edges, stream); - raft::copy_async(coo1.dst.data() + coo1.n_edges, coo2.dst.data(), coo2.n_edges, stream); - raft::copy_async(coo1.weights.data() + coo1.n_edges, coo2.weights.data(), coo2.n_edges, stream); - - coo1.n_edges = final_nnz; -} - -/** - * Connect an unconnected knn graph (one in which mst returns an msf). The - * device buffers underlying the Graph_COO object are modified in-place. 
- * @tparam value_idx index type - * @tparam value_t floating-point value type - * @param[in] handle raft handle - * @param[in] X original dense data from which knn grpah was constructed - * @param[inout] msf edge list containing the mst result - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[inout] color the color labels array returned from the mst invocation - * @return updated MST edge list - */ -template -void connect_knn_graph( - raft::resources const& handle, - const value_t* X, - sparse::solver::Graph_COO& msf, - size_t m, - size_t n, - value_idx* color, - red_op reduction_op, - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded) -{ - auto stream = resource::get_cuda_stream(handle); - - raft::sparse::COO connected_edges(stream); - - // default row and column batch sizes are chosen for computing cross component nearest neighbors. - // Reference: PR #1445 - static constexpr size_t default_row_batch_size = 4096; - static constexpr size_t default_col_batch_size = 16; - - raft::sparse::neighbors::cross_component_nn(handle, - connected_edges, - X, - color, - m, - n, - reduction_op, - min(m, default_row_batch_size), - min(n, default_col_batch_size)); - - rmm::device_uvector indptr2(m + 1, stream); - raft::sparse::convert::sorted_coo_to_csr( - connected_edges.rows(), connected_edges.nnz, indptr2.data(), m + 1, stream); - - // On the second call, we hand the MST the original colors - // and the new set of edges and let it restart the optimization process - auto new_mst = - raft::sparse::solver::mst(handle, - indptr2.data(), - connected_edges.cols(), - connected_edges.vals(), - m, - connected_edges.nnz, - color, - stream, - false, - false); - - merge_msts(msf, new_mst, stream); -} - -/** - * Constructs an MST and sorts the resulting edges in ascending - * order by their weight. - * - * Hierarchical clustering heavily relies upon the ordering - * and vertices returned in the MST. 
If the result of the - * MST was actually a minimum-spanning forest, the CSR - * being passed into the MST is not connected. In such a - * case, this graph will be connected by performing a - * KNN across the components. - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[in] indptr CSR indptr of connectivities graph - * @param[in] indices CSR indices array of connectivities graph - * @param[in] pw_dists CSR weights array of connectivities graph - * @param[in] m number of rows in X / src vertices in connectivities graph - * @param[in] n number of columns in X - * @param[out] mst_src output src edges - * @param[out] mst_dst output dst edges - * @param[out] mst_weight output weights (distances) - * @param[in] max_iter maximum iterations to run knn graph connection. This - * argument is really just a safeguard against the potential for infinite loops. - */ -template -void build_sorted_mst( - raft::resources const& handle, - const value_t* X, - const value_idx* indptr, - const value_idx* indices, - const value_t* pw_dists, - size_t m, - size_t n, - value_idx* mst_src, - value_idx* mst_dst, - value_t* mst_weight, - value_idx* color, - size_t nnz, - red_op reduction_op, - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded, - int max_iter = 10) -{ - auto stream = resource::get_cuda_stream(handle); - - // We want to have MST initialize colors on first call. - auto mst_coo = raft::sparse::solver::mst( - handle, indptr, indices, pw_dists, (value_idx)m, nnz, color, stream, false, true); - - int iters = 1; - int n_components = raft::sparse::neighbors::get_n_components(color, m, stream); - - while (n_components > 1 && iters < max_iter) { - connect_knn_graph(handle, X, mst_coo, m, n, color, reduction_op); - - iters++; - - n_components = raft::sparse::neighbors::get_n_components(color, m, stream); - } - - /** - * The `max_iter` argument was introduced only to prevent the potential for an infinite loop. 
- * Ideally the log2(n) guarantees of the MST should be enough to connect KNN graphs with a - * massive number of data samples in very few iterations. If it does not, there are 3 likely - * reasons why (in order of their likelihood): - * 1. There is a bug in this code somewhere - * 2. Either the given KNN graph wasn't generated from X or the same metric is not being used - * to generate the 1-nn (currently only L2SqrtExpanded is supported). - * 3. max_iter was not large enough to connect the graph (less likely). - * - * Note that a KNN graph generated from 50 random isotropic balls (with significant overlap) - * was able to be connected in a single iteration. - */ - RAFT_EXPECTS(n_components == 1, - "KNN graph could not be connected in %d iterations. " - "Please verify that the input knn graph is generated from X " - "(and the same distance metric used)," - " or increase 'max_iter'", - max_iter); - - raft::sparse::op::coo_sort_by_weight( - mst_coo.src.data(), mst_coo.dst.data(), mst_coo.weights.data(), mst_coo.n_edges, stream); - - raft::copy_async(mst_src, mst_coo.src.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_dst, mst_coo.dst.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_weight, mst_coo.weights.data(), mst_coo.n_edges, stream); -} - -}; // namespace cuvs::cluster::detail \ No newline at end of file diff --git a/cpp/include/cuvs/cluster/detail/single_linkage.cuh b/cpp/include/cuvs/cluster/detail/single_linkage.cuh deleted file mode 100644 index 5eb5ffb61..000000000 --- a/cpp/include/cuvs/cluster/detail/single_linkage.cuh +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include - -namespace cuvs::cluster::detail { - -static const size_t EMPTY = 0; - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control - * of k. 
The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples - */ -template -void single_linkage(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - cuvs::distance::DistanceType metric, - linkage_output* out, - int c, - size_t n_clusters) -{ - ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points"); - - auto stream = resource::get_cuda_stream(handle); - - rmm::device_uvector indptr(EMPTY, stream); - rmm::device_uvector indices(EMPTY, stream); - rmm::device_uvector pw_dists(EMPTY, stream); - - /** - * 1. Construct distance graph - */ - detail::get_distance_graph( - handle, X, m, n, metric, indptr, indices, pw_dists, c); - - rmm::device_uvector mst_rows(m - 1, stream); - rmm::device_uvector mst_cols(m - 1, stream); - rmm::device_uvector mst_data(m - 1, stream); - - /** - * 2. Construct MST, sorted by weights - */ - rmm::device_uvector color(m, stream); - raft::sparse::neighbors::FixConnectivitiesRedOp op(m); - detail::build_sorted_mst(handle, - X, - indptr.data(), - indices.data(), - pw_dists.data(), - m, - n, - mst_rows.data(), - mst_cols.data(), - mst_data.data(), - color.data(), - indices.size(), - op, - metric); - - pw_dists.release(); - - /** - * Perform hierarchical labeling - */ - size_t n_edges = mst_rows.size(); - - rmm::device_uvector out_delta(n_edges, stream); - rmm::device_uvector out_size(n_edges, stream); - // Create dendrogram - detail::build_dendrogram_host(handle, - mst_rows.data(), - mst_cols.data(), - mst_data.data(), - n_edges, - out->children, - out_delta.data(), - out_size.data()); - detail::extract_flattened_clusters(handle, out->labels, out->children, n_clusters, m); - - out->m = m; - out->n_clusters = n_clusters; - out->n_leaves = m; - out->n_connected_components = 1; -} -}; // namespace cuvs::cluster::detail \ No newline at end of file diff --git a/cpp/include/cuvs/cluster/kmeans.cuh b/cpp/include/cuvs/cluster/kmeans.cuh deleted 
file mode 100644 index e773a09ea..000000000 --- a/cpp/include/cuvs/cluster/kmeans.cuh +++ /dev/null @@ -1,1116 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs::cluster::kmeans { - -/** - * Functor used for sampling centroids - */ -template -using SamplingOp = detail::SamplingOp; - -/** - * Functor used to extract the index from a KeyValue pair - * storing both index and a distance. - */ -template -using KeyValueIndexOp = detail::KeyValueIndexOp; - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::raft::resources handle; - * cuvs::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids, - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. 
- * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void fit(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::raft::resources handle; - * cuvs::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * ... - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * false, - * labels.view(), - * raft::make_scalar_view(&ineratia)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. 
- * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. - */ -template -void predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - detail::kmeans_predict( - handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); -} - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::raft::resources handle; - * cuvs::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. 
- * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void fit_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - detail::kmeans_fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * [dim = n_clusters x n_features] - * @param[out] X_new X transformed in the new space. 
- * [dim = n_samples x n_features] - */ -template -void transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - detail::kmeans_transform(handle, params, X, centroids, X_new); -} - -template -void transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - detail::kmeans_transform( - handle, params, X, centroids, n_samples, n_features, X_new); -} - -/** - * Automatically find the optimal value of k using a binary search. - * This method maximizes the Calinski-Harabasz Index while minimizing the per-cluster inertia. - * - * @code{.cpp} - * #include - * #include - * #include - * - * #include - * - * using namespace cuvs::cluster; - * - * raft::handle_t handle; - * int n_samples = 100, n_features = 15, n_clusters = 10; - * auto X = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * - * raft::random::make_blobs(handle, X, labels, n_clusters); - * - * auto best_k = raft::make_host_scalar(0); - * auto n_iter = raft::make_host_scalar(0); - * auto inertia = raft::make_host_scalar(0); - * - * kmeans::find_k(handle, X, best_k.view(), inertia.view(), n_iter.view(), n_clusters+1); - * - * @endcode - * - * @tparam idx_t indexing type (should be integral) - * @tparam value_t value type (should be floating point) - * @param handle raft handle - * @param X input observations (shape n_samples, n_dims) - * @param best_k best k found from binary search - * @param inertia inertia of best k found - * @param n_iter number of iterations used to find best k - * @param kmax maximum k to try in search - * @param kmin minimum k to try in search (should be >= 1) - * @param maxiter maximum number of iterations to run - * @param tol tolerance for early stopping convergence - */ -template 
-void find_k(raft::resources const& handle, - raft::device_matrix_view X, - raft::host_scalar_view best_k, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - idx_t kmax, - idx_t kmin = 1, - idx_t maxiter = 100, - value_t tol = 1e-3) -{ - detail::find_k(handle, X, best_k, inertia, n_iter, kmax, kmin, maxiter, tol); -} - -/** - * @brief Select centroids according to a sampling operation - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] isSampleCentroid Flag the sample chosen as initial centroid - * [dim = n_samples] - * @param[in] select_op The sampling operation used to select the centroids - * @param[out] inRankCp The sampled centroids - * [dim = n_selected_centroids x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void sample_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - detail::sampleCentroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); -} - -/** - * @brief Compute cluster cost - * - * @tparam DataT the type of data used for weights, distances. - * @tparam ReductionOpT the type of data used for the reduction operation. 
- * - * @param[in] handle The raft handle - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] clusterCost Resulting cluster cost - * @param[in] reduction_op The reduction operation used for the cost - * - */ -template -void cluster_cost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - ReductionOpT reduction_op) -{ - detail::computeClusterCost( - handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); -} - -/** - * @brief Update centroids given current centroids and number of points assigned to each centroid. - * This function also produces a vector of RAFT key/value pairs containing the cluster assignment - * for each point and its distance. - * - * @tparam DataT - * @tparam IndexT - * @param[in] handle: Raft handle to use for managing library resources - * @param[in] X: input matrix (size n_samples, n_features) - * @param[in] sample_weights: number of samples currently assigned to each centroid (size n_samples) - * @param[in] centroids: matrix of current centroids (size n_clusters, n_features) - * @param[in] labels: Iterator of labels (can also be a raw pointer) - * @param[out] weight_per_cluster: sum of sample weights per cluster (size n_clusters) - * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - LabelsIterator labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids) -{ - // TODO: Passing these into the algorithm doesn't really present much of a benefit - // because they are being resized anyways. 
- // ref https://github.com/rapidsai/raft/issues/930 - rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); - - detail::update_centroids( - handle, X, sample_weights, centroids, labels, weight_per_cluster, new_centroids, workspace); -} - -/** - * @brief Compute distance for every sample to it's nearest centroid - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] metric Distance metric to use - * @param[in] batch_samples batch size for input data samples - * @param[in] batch_centroids batch size for input centroids - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void min_cluster_distance(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - detail::minClusterDistanceCompute(handle, - X, - centroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - metric, - batch_samples, - batch_centroids, - workspace); -} - -/** - * @brief Calculates a pair for every sample in input 'X' where key is an - * index of one of the 'centroids' (index of the nearest centroid) and 'value' - * is the distance between the sample and the 'centroid[key]' - * - * @tparam DataT 
the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest - * centroid and it's distance - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] metric distance metric - * @param[in] batch_samples batch size of data samples - * @param[in] batch_centroids batch size of centroids - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void min_cluster_and_distance( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - metric, - batch_samples, - batch_centroids, - workspace); -} - -/** - * @brief Shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores - * in 'out' does not modify the input - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. 
- * - * @param[in] handle The raft handle - * @param[in] in The data to shuffle and gather - * [dim = n_samples x n_features] - * @param[out] out The sampled data - * [dim = n_samples_to_gather x n_features] - * @param[in] n_samples_to_gather Number of sample to gather - * @param[in] seed Seed for the shuffle - * - */ -template -void shuffle_and_gather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed); -} - -/** - * @brief Count the number of samples in each cluster - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] sampleCountInCluster The count for each centroid - * [dim = n_cluster] - * - */ -template -void count_samples_in_cluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - detail::countSamplesInCluster( - handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); -} - -/** - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - * - * @see "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * @tparam DataT the type of data used for weights, distances. 
- * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[out] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void init_plus_plus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace) -{ - detail::kmeansPlusPlus(handle, params, X, centroids, workspace); -} - -/* - * @brief Main function used to fit KMeans (after cluster initialization) - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] Initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. 
- * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - detail::kmeans_fit_main( - handle, params, X, sample_weights, centroids, inertia, n_iter, workspace); -} - -}; // namespace cuvs::cluster::kmeans - -namespace cuvs::cluster { - -/** - * Note: All of the functions below in cuvs::cluster are deprecated and will - * be removed in a future release. Please use cuvs::cluster::kmeans instead. - */ - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. 
- */ -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - kmeans::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT& inertia, - IndexT& n_iter) -{ - kmeans::fit( - handle, params, X, sample_weight, centroids, n_samples, n_features, inertia, n_iter); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. 
- */ -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - kmeans::predict( - handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - bool normalize_weight, - DataT& inertia) -{ - kmeans::predict(handle, - params, - X, - sample_weight, - centroids, - n_samples, - n_features, - labels, - normalize_weight, - inertia); -} - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. 
- */ -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - DataT& inertia, - IndexT& n_iter) -{ - kmeans::fit_predict( - handle, params, X, sample_weight, centroids, n_samples, n_features, labels, inertia, n_iter); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * [dim = n_clusters x n_features] - * @param[out] X_new X transformed in the new space. 
- * [dim = n_samples x n_features] - */ -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - kmeans::transform(handle, params, X, centroids, X_new); -} - -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - kmeans::transform(handle, params, X, centroids, n_samples, n_features, X_new); -} - -template -using SamplingOp = kmeans::SamplingOp; - -template -using KeyValueIndexOp = kmeans::KeyValueIndexOp; - -/** - * @brief Select centroids according to a sampling operation - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] isSampleCentroid Flag the sample chosen as initial centroid - * [dim = n_samples] - * @param[in] select_op The sampling operation used to select the centroids - * @param[out] inRankCp The sampled centroids - * [dim = n_selected_centroids x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void sampleCentroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - kmeans::sample_centroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); -} - -/** - * @brief Compute cluster cost - * - * @tparam DataT the type of data used for weights, distances. 
- * @tparam ReductionOpT the type of data used for the reduction operation. - * - * @param[in] handle The raft handle - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] clusterCost Resulting cluster cost - * @param[in] reduction_op The reduction operation used for the cost - * - */ -template -void computeClusterCost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - ReductionOpT reduction_op) -{ - kmeans::cluster_cost(handle, minClusterDistance, workspace, clusterCost, reduction_op); -} - -/** - * @brief Compute distance for every sample to it's nearest centroid - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void minClusterDistanceCompute(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace) -{ - kmeans::min_cluster_distance(handle, - X, - centroids, - minClusterDistance, - L2NormX, - 
L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); -} - -/** - * @brief Calculates a pair for every sample in input 'X' where key is an - * index of one of the 'centroids' (index of the nearest centroid) and 'value' - * is the distance between the sample and the 'centroid[key]' - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest - * centroid and it's distance - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void minClusterAndDistanceCompute( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace) -{ - kmeans::min_cluster_and_distance(handle, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); -} - -/** - * @brief Shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores - * in 'out' does not modify the input - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. 
- * - * @param[in] handle The raft handle - * @param[in] in The data to shuffle and gather - * [dim = n_samples x n_features] - * @param[out] out The sampled data - * [dim = n_samples_to_gather x n_features] - * @param[in] n_samples_to_gather Number of sample to gather - * @param[in] seed Seed for the shuffle - * - */ -template -void shuffleAndGather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - kmeans::shuffle_and_gather(handle, in, out, n_samples_to_gather, seed); -} - -/** - * @brief Count the number of samples in each cluster - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] sampleCountInCluster The count for each centroid - * [dim = n_cluster] - * - */ -template -void countSamplesInCluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - kmeans::count_samples_in_cluster( - handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); -} - -/* - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * @tparam DataT the type of data used for weights, distances. 
- * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[out] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void kmeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - kmeans::init_plus_plus(handle, params, X, centroidsRawData, workspace); -} - -/* - * @brief Main function used to fit KMeans (after cluster initialization) - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] Initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. 
- * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void kmeans_fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - kmeans::fit_main( - handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); -} -}; // namespace cuvs::cluster diff --git a/cpp/include/cuvs/cluster/kmeans_balanced.cuh b/cpp/include/cuvs/cluster/kmeans_balanced.cuh deleted file mode 100644 index 7735587e7..000000000 --- a/cpp/include/cuvs/cluster/kmeans_balanced.cuh +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include - -namespace cuvs::cluster::kmeans_balanced { - -/** - * @brief Find clusters of balanced sizes with a hierarchical k-means algorithm. - * - * This variant of the k-means algorithm first clusters the dataset in mesoclusters, then clusters - * the subsets associated to each mesocluster into fine clusters, and finally runs a few k-means - * iterations over the whole dataset and with all the centroids to obtain the final clusters. 
- * - * Each k-means iteration applies expectation-maximization-balancing: - * - Balancing: adjust centers for clusters that have a small number of entries. If the size of a - * cluster is below a threshold, the center is moved towards a bigger cluster. - * - Expectation: predict the labels (i.e find closest cluster centroid to each point) - * - Maximization: calculate optimal centroids (i.e find the center of gravity of each cluster) - * - * The number of mesoclusters is chosen by rounding the square root of the number of clusters. E.g - * for 512 clusters, we would have 23 mesoclusters. The number of fine clusters per mesocluster is - * chosen proportionally to the number of points in each mesocluster. - * - * This variant of k-means uses random initialization and a fixed number of iterations, though - * iterations can be repeated if the balancing step moved the centroids. - * - * Additionally, this algorithm supports quantized datasets in arbitrary types but the core part of - * the algorithm will work with a floating-point type, hence a conversion function can be provided - * to map the data type to the math type. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * cuvs::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * cuvs::cluster::kmeans_balanced::fit(handle, params, X, centroids.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. 
- * [dim = n_samples x n_features] - * @param[out] centroids The generated centroids [dim = n_clusters x n_features] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT == MathT, this must be the identity. - */ -template -void fit(const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= - static_cast(std::numeric_limits::max()), - "The chosen index type cannot represent all indices for the given dataset"); - RAFT_EXPECTS(centroids.extent(0) > IndexT{0} && centroids.extent(0) <= X.extent(0), - "The number of centroids must be strictly positive and cannot exceed the number of " - "points in the training dataset."); - - detail::build_hierarchical(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.data_handle(), - centroids.extent(0), - mapping_op); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * cuvs::cluster::kmeans_balanced_params params; - * auto labels = raft::make_device_vector(handle, n_rows); - * cuvs::cluster::kmeans_balanced::predict(handle, params, X, centroids, labels); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Dataset for which to infer the closest clusters. 
- * [dim = n_samples x n_features] - * @param[in] centroids The input centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT == MathT, this must be the identity. - */ -template -void predict(const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= - static_cast(std::numeric_limits::max()), - "The chosen index type cannot represent all indices for the given dataset"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) <= - static_cast(std::numeric_limits::max()), - "The chosen label type cannot represent all cluster labels"); - - detail::predict(handle, - params, - centroids.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - mapping_op); -} - -/** - * @brief Compute hierarchical balanced k-means clustering and predict cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * cuvs::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, n_rows); - * cuvs::cluster::kmeans_balanced::fit_predict( - * handle, params, X, centroids.view(), labels.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. 
- * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT and MathT are the same, this must be the identity. - */ -template -void fit_predict(const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) -{ - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - cuvs::cluster::kmeans_balanced::fit(handle, params, X, centroids, mapping_op); - cuvs::cluster::kmeans_balanced::predict(handle, params, X, centroids_const, labels, mapping_op); -} - -namespace helpers { - -/** - * @brief Randomly initialize centers and apply expectation-maximization-balancing iterations - * - * This is essentially the non-hierarchical balanced k-means algorithm which is used by the - * hierarchical algorithm once to build the mesoclusters and once per mesocluster to build the fine - * clusters. - * - * @code{.cpp} - * #include - * #include - * #include - * ... 
- * raft::handle_t handle; - * cuvs::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto sizes = raft::make_device_vector(handle, n_clusters); - * cuvs::cluster::kmeans_balanced::build_clusters( - * handle, params, X, centroids.view(), labels.view(), sizes.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam CounterT Counter type supported by CUDA's native atomicAdd. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the - * arithmetic datatype. If DataT == MathT, this must be the identity. 
- * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] - */ -template -void build_clusters(const raft::resources& handle, - const kmeans_balanced_params& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - raft::device_vector_view cluster_sizes, - MappingOpT mapping_op = raft::identity_op(), - std::optional> X_norm = std::nullopt) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), - "Number of rows in centroids and clusyer_sizes are different"); - - detail::build_clusters(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.extent(0), - centroids.data_handle(), - labels.data_handle(), - cluster_sizes.data_handle(), - mapping_op, - resource::get_workspace_resource(handle), - X_norm.has_value() ? X_norm.value().data_handle() : nullptr); -} - -/** - * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. - * - * Let `S_i = {x_k | x_k \in X & labels[k] == i}` be the vectors in the dataset with label i. - * - * On exit, - * `centers_i = (\sum_{x \in S_i} x + w_i * center_i) / (|S_i| + w_i)`, - * where `w_i = reset_counters ? 0 : cluster_size[i]`. - * - * In other words, the updated cluster centers are a weighted average of the existing cluster - * center, and the coordinates of the points labeled with i. _This allows calling this function - * multiple times with different datasets with the same effect as if calling this function once - * on the combined dataset_. - * - * @code{.cpp} - * #include - * #include - * ... 
- * raft::handle_t handle; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto sizes = raft::make_device_vector(handle, n_clusters); - * cuvs::cluster::kmeans_balanced::calc_centers_and_sizes( - * handle, X, labels, centroids.view(), sizes.view(), true); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam CounterT Counter type supported by CUDA's native atomicAdd. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] X Dataset for which to calculate cluster centers. The data must be in - * row-major format. [dim = n_samples x n_features] - * @param[in] labels The input labels [dim = n_samples] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] - * @param[in] reset_counters Whether to clear the output arrays before calculating. - * When set to `false`, this function may be used to update existing - * centers and sizes using the weighted average principle. - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the - * arithmetic datatype. If DataT == MathT, this must be the identity. 
- */ -template -void calc_centers_and_sizes(const raft::resources& handle, - raft::device_matrix_view X, - raft::device_vector_view labels, - raft::device_matrix_view centroids, - raft::device_vector_view cluster_sizes, - bool reset_counters = true, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), - "Number of rows in centroids and clusyer_sizes are different"); - - detail::calc_centers_and_sizes(handle, - centroids.data_handle(), - cluster_sizes.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - reset_counters, - mapping_op); -} - -} // namespace helpers - -} // namespace cuvs::cluster::kmeans_balanced diff --git a/cpp/include/cuvs/cluster/kmeans_balanced_types.hpp b/cpp/include/cuvs/cluster/kmeans_balanced_types.hpp deleted file mode 100644 index 5a4793fbe..000000000 --- a/cpp/include/cuvs/cluster/kmeans_balanced_types.hpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include - -namespace cuvs::cluster::kmeans_balanced { - -/** - * Simple object to specify hyper-parameters to the balanced k-means algorithm. - * - * The following metrics are currently supported in k-means balanced: - * - InnerProduct - * - L2Expanded - * - L2SqrtExpanded - */ -struct kmeans_balanced_params : kmeans_base_params { - /** - * Number of training iterations - */ - uint32_t n_iters = 20; -}; - -} // namespace cuvs::cluster::kmeans_balanced - -namespace cuvs::cluster { - -using kmeans_balanced::kmeans_balanced_params; - -} // namespace cuvs::cluster diff --git a/cpp/include/cuvs/cluster/kmeans_deprecated.cuh b/cpp/include/cuvs/cluster/kmeans_deprecated.cuh deleted file mode 100644 index c31f7e686..000000000 --- a/cpp/include/cuvs/cluster/kmeans_deprecated.cuh +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -namespace cuvs::cluster { - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. 
- * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param residual On exit, residual sum of squares (sum of squares - * of distances between observation vectors and centroids). - * @param iters on exit, number of k-means iterations. - * @param seed random seed to be used. - * @return error flag - */ -template -int kmeans(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - value_type_t& residual, - index_type_t& iters, - unsigned long long seed = 123456) -{ - return detail::kmeans( - handle, n, d, k, tol, maxiter, obs, codes, residual, iters, seed); -} -} // namespace cuvs::cluster diff --git a/cpp/include/cuvs/cluster/kmeans_types.hpp b/cpp/include/cuvs/cluster/kmeans_types.hpp deleted file mode 100644 index c9090166d..000000000 --- a/cpp/include/cuvs/cluster/kmeans_types.hpp +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include - -namespace cuvs::cluster { - -/** Base structure for parameters that are common to all k-means algorithms */ -struct kmeans_base_params { - /** - * Metric to use for distance computation. The supported metrics can vary per algorithm. - */ - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; -}; - -} // namespace cuvs::cluster - -namespace cuvs::cluster::kmeans { - -/** - * Simple object to specify hyper-parameters to the kmeans algorithm. - */ -struct KMeansParams : kmeans_base_params { - enum InitMethod { - - /** - * Sample the centroids using the kmeans++ strategy - */ - KMeansPlusPlus, - - /** - * Sample the centroids uniformly at random - */ - Random, - - /** - * User provides the array of initial centroids - */ - Array - }; - - /** - * The number of clusters to form as well as the number of centroids to generate (default:8). - */ - int n_clusters = 8; - - /** - * Method for initialization, defaults to k-means++: - * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm - * to select the initial cluster centers. - * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at - * random from the input data for the initial centroids. - * - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. - */ - InitMethod init = KMeansPlusPlus; - - /** - * Maximum number of iterations of the k-means algorithm for a single run. - */ - int max_iter = 300; - - /** - * Relative tolerance with regards to inertia to declare convergence. - */ - double tol = 1e-4; - - /** - * verbosity level. - */ - int verbosity = RAFT_LEVEL_INFO; - - /** - * Seed to the random number generator. - */ - raft::random::RngState rng_state{0}; - - /** - * Number of instance k-means algorithm will be run with different seeds. 
- */ - int n_init = 1; - - /** - * Oversampling factor for use in the k-means|| algorithm - */ - double oversampling_factor = 2.0; - - // batch_samples and batch_centroids are used to tile 1NN computation which is - // useful to optimize/control the memory footprint - // Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 - // then don't tile the centroids - int batch_samples = 1 << 15; - - /** - * if 0 then batch_centroids = n_clusters - */ - int batch_centroids = 0; // - - bool inertia_check = false; -}; - -} // namespace cuvs::cluster::kmeans - -namespace cuvs::cluster { - -using kmeans::KMeansParams; - -} // namespace cuvs::cluster diff --git a/cpp/include/cuvs/cluster/single_linkage.cuh b/cpp/include/cuvs/cluster/single_linkage.cuh deleted file mode 100644 index 88c964678..000000000 --- a/cpp/include/cuvs/cluster/single_linkage.cuh +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -namespace cuvs::cluster { - -/** - * Note: All of the functions below in the cuvs::cluster namespace are deprecated - * and will be removed in a future release. Please use cuvs::cluster::hierarchy - * instead. 
- */ - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control - * of k. The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples - */ -template -void single_linkage(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - cuvs::distance::DistanceType metric, - linkage_output* out, - int c, - size_t n_clusters) -{ - detail::single_linkage( - handle, X, m, n, metric, out, c, n_clusters); -} -}; // namespace cuvs::cluster - -namespace cuvs::cluster::hierarchy { - -constexpr int DEFAULT_CONST_C = 15; - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. 
- - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2) - * @param[out] labels output labels vector (size n_rows) - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[in] n_clusters number of clusters to assign data samples - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control of k. The algorithm will set `k = log(n) + c` - */ -template -void single_linkage(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view dendrogram, - raft::device_vector_view labels, - cuvs::distance::DistanceType metric, - size_t n_clusters, - std::optional c = std::make_optional(DEFAULT_CONST_C)) -{ - linkage_output out_arrs; - out_arrs.children = dendrogram.data_handle(); - out_arrs.labels = labels.data_handle(); - - cuvs::cluster::single_linkage( - handle, - X.data_handle(), - static_cast(X.extent(0)), - static_cast(X.extent(1)), - metric, - &out_arrs, - c.has_value() ? c.value() : DEFAULT_CONST_C, - n_clusters); -} -}; // namespace cuvs::cluster::hierarchy diff --git a/cpp/include/cuvs/cluster/single_linkage_types.hpp b/cpp/include/cuvs/cluster/single_linkage_types.hpp deleted file mode 100644 index 8da65a01f..000000000 --- a/cpp/include/cuvs/cluster/single_linkage_types.hpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace cuvs::cluster::hierarchy { - -/** - * Determines the method for computing the minimum spanning tree (MST) - */ -enum LinkageDistance { - - /** - * Use a pairwise distance matrix as input to the mst. This - * is very fast and the best option for fairly small datasets (~50k data points) - */ - PAIRWISE = 0, - - /** - * Construct a KNN graph as input to the mst and provide additional - * edges if the mst does not converge. This is slower but scales - * to very large datasets. - */ - KNN_GRAPH = 1 -}; - -}; // namespace cuvs::cluster::hierarchy - -// The code below is now considered legacy -namespace cuvs::cluster { - -using hierarchy::LinkageDistance; - -/** - * Simple container object for consolidating linkage results. This closely - * mirrors the trained instance variables populated in - * Scikit-learn's AgglomerativeClustering estimator. 
- * @tparam value_idx - * @tparam value_t - */ -template -class linkage_output { - public: - idx_t m; - idx_t n_clusters; - - idx_t n_leaves; - idx_t n_connected_components; - - // TODO: These will be made private in a future release - idx_t* labels; // size: m - idx_t* children; // size: (m-1, 2) - - raft::device_vector_view get_labels() - { - return raft::make_device_vector_view(labels, m); - } - - raft::device_matrix_view get_children() - { - return raft::make_device_matrix_view(children, m - 1, 2); - } -}; - -class linkage_output_int : public linkage_output {}; -class linkage_output_int64 : public linkage_output {}; - -}; // namespace cuvs::cluster diff --git a/cpp/include/cuvs/distance/detail/compress_to_bits.cuh b/cpp/include/cuvs/distance/detail/compress_to_bits.cuh deleted file mode 100644 index 9ce47774a..000000000 --- a/cpp/include/cuvs/distance/detail/compress_to_bits.cuh +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -namespace cuvs::distance::detail { - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for masked_l2_nn. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. 
- */ -template ::value>> -RAFT_KERNEL compress_to_bits_kernel( - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - constexpr int bits_per_element = 8 * sizeof(T); - constexpr int tile_dim_m = bits_per_element; - constexpr int nthreads = 128; - constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector - - // Tile in shared memory is transposed - __shared__ bool smem[tile_dim_n][tile_dim_m]; - - const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); - const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); - - for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { - const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); - const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); - - if (in.extent(0) <= tile_idx_m) { break; } - // Fill shared memory tile - bool reg_buf[tile_dim_m]; -#pragma unroll - for (int i = 0; i < tile_dim_m; ++i) { - const int in_m = tile_idx_m + i; - const int in_n = tile_idx_n + threadIdx.x; - bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); - reg_buf[i] = in_bounds ? in(in_m, in_n) : false; - smem[threadIdx.x][i] = reg_buf[i]; - } - __syncthreads(); - - // Drain memory tile into single output element out_elem. - T out_elem{0}; -#pragma unroll - for (int j = 0; j < tile_dim_n; ++j) { - if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } - } - __syncthreads(); - - // Write output. - int out_m = tile_idx_m / bits_per_element; - int out_n = tile_idx_n + threadIdx.x; - - if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } - } -} - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for masked_l2_nn. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. 
- */ -template ::value>> -void compress_to_bits(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - auto stream = resource::get_cuda_stream(handle); - constexpr int bits_per_element = 8 * sizeof(T); - - RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), - "Number of output rows must be ceildiv(input rows, bits_per_elem)"); - RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); - - const int num_SMs = raft::getMultiProcessorCount(); - int blocks_per_sm = 0; - constexpr int num_threads = 128; - constexpr int dyn_smem_size = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, compress_to_bits_kernel, num_threads, dyn_smem_size)); - - dim3 grid(num_SMs * blocks_per_sm); - dim3 block(128); - compress_to_bits_kernel<<>>(in, out); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -}; // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/distance/detail/distance.cuh b/cpp/include/cuvs/distance/detail/distance.cuh deleted file mode 100644 index ea935bdcb..000000000 --- a/cpp/include/cuvs/distance/detail/distance.cuh +++ /dev/null @@ -1,814 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs { -namespace distance { -namespace detail { - -/** - * @brief: A tag type for overload resolution based on DistanceType - * - * It is not possible to partially specialize function templates on a single - * parameter. Instead, it is often easier to use a combination of conventional - * method overloading and a parameter with a specific tag type. The following - * type is used to help method overloading based on the DistanceType enum. - */ -template -using distance_tag = std::integral_constant; - -/** - * @brief Implement pairwise_matrix for specific distance - * - * There are multiple overloads for this function, one for each distance type. - * They are implemented below. The documentation of this function serves as - * documentation for all functions. The following overloads are defined: - * - * - DistanceType::Canberra: - * - DistanceType::CorrelationExpanded: - * - DistanceType::CosineExpanded: - * - DistanceType::HammingUnexpanded: - * - DistanceType::HellingerExpanded: - * - DistanceType::JensenShannon: - * - DistanceType::KLDivergence: - * - DistanceType::L1: - * - DistanceType::L2Expanded: - * - DistanceType::L2SqrtExpanded: - * - DistanceType::L2Unexpanded: - * - DistanceType::L2SqrtUnexpanded: - * - DistanceType::Linf: - * - DistanceType::LpUnexpanded: - * - DistanceType::RusselRaoExpanded: - * - * @tparam DataT Input data type - * @tparam AccT Accumulation data type - * @tparam OutT Output data type - * @tparam FinOpT Type of final operation - * @tparam IdxT Index type - * - * @param handle RAFT resources handle - * @param distance_type A tag type to indicate which distance is calculated. 
- * @param x First set of points - * @param y Second set of points - * @param out Output distance matrix - * @param m Number of points in x - * @param n Number of points in y - * @param k Dimensionality of points in x, y - * @param workspace Temporary workspace needed for computations - * @param worksize Number of bytes of the workspace - * @param is_row_major Whether the matrices are row-major or col-major - * @param metric_arg The `p` argument for Lp. - */ -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, // unused - size_t worksize, // unused - FinOpT fin_op, - bool is_row_major, - DataT metric_arg) // unused -{ - ops::canberra_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // unused -{ - ASSERT(!(worksize < 2 * (m + n) * sizeof(AccT)), "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - AccT* x_norm = workspace; - AccT* y_norm = workspace; - AccT* sq_x_norm = workspace; - AccT* sq_y_norm = workspace; - // TODO: Column major case looks to have lower accuracy for X == Y, - // perhaps the use of stridedSummationKernel could be causing this, - // need to investigate and fix. 
- if (x == y && is_row_major) { - raft::linalg::reduce(x_norm, - x, - k, - std::max(m, n), - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - sq_x_norm += std::max(m, n); - sq_y_norm = sq_x_norm; - raft::linalg::rowNorm( - sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream); - } else { - y_norm += m; - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - raft::linalg::reduce(y_norm, - y, - k, - n, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - - sq_x_norm += (m + n); - sq_y_norm = sq_x_norm + m; - raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); - raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); - } - - using OpT = ops::correlation_distance_op; - OpT corr_op(is_row_major, sq_x_norm, sq_y_norm, m, n, k); - pairwise_matrix_dispatch( - corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // unused -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. 
- static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), - "OutT can be uint8_t, float, double," - "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - - ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - DataT* x_norm = workspace; - DataT* y_norm = workspace; - // TODO: Column major case looks to have lower accuracy for X == Y, - // perhaps the use of stridedSummationKernel could be causing this, - // need to investigate and fix. - if (x == y && is_row_major) { - raft::linalg::rowNorm( - x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - } else { - y_norm += m; - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - raft::linalg::rowNorm( - y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - } - - ops::cosine_distance_op distance_op{}; - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::hamming_distance_op distance_op{k}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool 
is_row_major, - DataT) // metric_arg unused -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - raft::linalg::gemm(handle, - out, - const_cast(x), - const_cast(y), - m, - n, - k, - !is_row_major, - !is_row_major, - is_row_major, - stream); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - // First sqrt x and y - const auto raft_sqrt = raft::linalg::unaryOp; - - raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } - - // Then calculate Hellinger distance - ops::hellinger_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); - - // Finally revert sqrt of x and y - raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::jensen_shannon_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag 
distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto unaryOp_lambda = [] __device__(DataT input) { - const bool x_zero = (input == 0); - return (!x_zero) * raft::log(input + x_zero); - }; - - auto unaryOp_lambda_reverse = [] __device__(DataT input) { - // reverse previous log (x) back to x using (e ^ log(x)) - const bool x_zero = (input == 0); - return (!x_zero) * raft::exp(input); - }; - - if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); - } - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - // This op takes some shortcuts when x equals y. So its behavior changes based - // on this. - ops::kl_divergence_op distance_op{is_row_major, x == y}; - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); - - if (x != y) { - // Now reverse previous log (x) back to x using (e ^ log(x)) - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); - } -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::l1_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl_l2_expanded( // NOTE: different name - bool perform_sqrt, // dispatch on sqrt - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, 
- IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), - "OutT can be uint8_t, float, double," - "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - - ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - DataT* x_norm = workspace; - DataT* y_norm = workspace; - // TODO: Column major case looks to have lower accuracy for X == Y, - // perhaps the use of stridedSummationKernel could be causing this, - // need to investigate and fix. - if ((x == y) && is_row_major) { - raft::linalg::rowNorm(x_norm, - x, - k, - std::max(m, n), - raft::linalg::L2Norm, - is_row_major, - stream, - raft::identity_op{}); - } else { - y_norm += m; - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - raft::linalg::rowNorm( - y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - } - - ops::l2_exp_distance_op distance_op{perform_sqrt}; - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = false; - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_impl_l2_expanded( - perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* 
workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = true; - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_impl_l2_expanded( - perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = false; - ops::l2_unexp_distance_op l2_op(perform_sqrt); - - // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = true; - ops::l2_unexp_distance_op l2_op(perform_sqrt); - - // The unexpanded L2 does not require the norms of a and b to be calculated. 
- const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::l_inf_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT metric_arg) -{ - ops::lp_unexp_distance_op distance_op{metric_arg}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::russel_rao_distance_op distance_op{k}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, 
stream, is_row_major); -} - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * - * @param x first set of points - * @param y second set of points - * @param out output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* out, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), - "OutType can be uint8_t, float, double," - "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); - - distance_impl( - handle, - distance_tag{}, - x, - y, - out, - m, - n, - k, - reinterpret_cast(workspace), - worksize, - fin_op, - isRowMajor, - metric_arg); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* out, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - auto fin_op = raft::identity_op(); - - distance( - handle, x, y, out, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType 
which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) -{ - size_t worksize = 0; - constexpr bool is_allocated = (distanceType <= cuvs::distance::DistanceType::CosineExpanded) || - (distanceType == cuvs::distance::DistanceType::CorrelationExpanded); - constexpr int numOfBuffers = - (distanceType == cuvs::distance::DistanceType::CorrelationExpanded) ? 2 : 1; - - if (is_allocated) { - // TODO : when X == Y allocate std::max(m, n) instead of m + n when column major input - // accuracy issue is resolved until then we allocate as m + n. - worksize += numOfBuffers * m * sizeof(AccType); - worksize += numOfBuffers * n * sizeof(AccType); - } - - return worksize; -} - -}; // namespace detail -}; // namespace distance -}; // namespace cuvs diff --git a/cpp/include/cuvs/distance/detail/distance_ops/all_ops.cuh b/cpp/include/cuvs/distance/detail/distance_ops/all_ops.cuh deleted file mode 100644 index ecbede398..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/all_ops.cuh +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Defines a named requirement "has_cutlass_op" -#include - -// The distance operations: -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include diff --git a/cpp/include/cuvs/distance/detail/distance_ops/canberra.cuh b/cpp/include/cuvs/distance/detail/distance_ops/canberra.cuh deleted file mode 100644 index 8bbdc9945..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/canberra.cuh +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // raft::abs -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief The canberra distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| ) - */ -template -struct canberra_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = raft::abs(x - y); - const auto add = raft::abs(x) + raft::abs(y); - // deal with potential for 0 in denominator by - // forcing 0/1 instead - acc += ((add != 0) * diff / (add + (add == 0))); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - return; - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/correlation.cuh b/cpp/include/cuvs/distance/detail/distance_ops/correlation.cuh deleted file mode 100644 index f033f3dfa..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/correlation.cuh +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace cuvs::distance::detail::ops { - -/** @brief The correlation distance - * - * It computes the following equation: - * - * d(x, y) = ((x - mean(x)) ⋅ (y - mean(y))) - * / - * (|| x - mean(x) ||_2 || y - mean(y) ||_2) - */ -template -struct correlation_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - const DataT* x2n; - const DataT* y2n; - IdxT m; - IdxT n; - IdxT k; - - correlation_distance_op( - bool is_row_major, const DataT* x2n_, const DataT* y2n_, IdxT m_, IdxT n_, IdxT k_) noexcept - : x2n(x2n_), y2n(y2n_), m(m_), n(n_), k(k_) - { - // The distance op is typically created before the row-major/col-major - // swapping has been done. So we do it here. - if (!is_row_major) { - std::swap(x2n, y2n); - std::swap(m, n); - } - } - - // Load norms of input data - static constexpr bool use_norms = true; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more.
- template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - // Note how we can sneakily get a pointer to shared memory here, to store - // more data. If the implementation of PairwiseDistanceMatKernel ever - // changes, this will be where we find the bugs. - extern __shared__ char smem[]; - - DataT regx2n[Policy::AccRowsPerTh], regy2n[Policy::AccColsPerTh]; - - DataT* sx2Norm = - (DataT*)(&smem[Policy::SmemSize + (Policy::Mblk + Policy::Nblk) * sizeof(DataT)]); - DataT* sy2Norm = (&sx2Norm[Policy::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * Policy::Nblk) { - for (int i = threadIdx.x; i < Policy::Mblk; i += Policy::Nthreads) { - auto idx = gridStrideY + i; - sx2Norm[i] = idx < m ? x2n[idx] : 0; - } - } - - for (int i = threadIdx.x; i < Policy::Nblk; i += Policy::Nthreads) { - auto idx = gridStrideX + i; - sy2Norm[i] = idx < n ? 
y2n[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - regx2n[i] = sx2Norm[i * Policy::AccThRows + (threadIdx.x / Policy::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < Policy::AccColsPerTh; ++i) { - regy2n[i] = sy2Norm[i * Policy::AccThCols + (threadIdx.x % Policy::AccThCols)]; - } - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - auto numer = k * acc[i][j] - (regxn[i] * regyn[j]); - auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); - auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); - - acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); - } - } - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/cosine.cuh b/cpp/include/cuvs/distance/detail/distance_ops/cosine.cuh deleted file mode 100644 index d48731651..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/cosine.cuh +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // DI - -namespace cuvs::distance::detail::ops { - -// Epilogue operator for CUTLASS based kernel -template -struct cosine_cutlass_op { - __device__ cosine_cutlass_op() noexcept {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } -}; - -/** - * @brief the expanded cosine distance matrix calculation - * - * It computes the following equation: - * - * d(x, y) = 1 - (x ⋅ y) / ( ||x||_2 ||y||_2) - */ -template -struct cosine_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = true; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more.
- template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); - } - } - } - - constexpr cosine_cutlass_op get_cutlass_op() const - { - return cosine_cutlass_op(); - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/cutlass.cuh b/cpp/include/cuvs/distance/detail/distance_ops/cutlass.cuh deleted file mode 100644 index 6d928314d..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/cutlass.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // std::false_type -#include // std::declval - -namespace cuvs::distance::detail::ops { - -// This file defines the named requirement "has_cutlass_op" that can be used to -// determine if a distance operation has a CUTLASS op that can be used to pass -// to CUTLASS. 
Examples of distance operations that satisfy this requirement are -// cosine_distance_op and l2_exp_distance_op. - -// Primary template handles types that do not support CUTLASS. -// This pattern is described in: -// https://en.cppreference.com/w/cpp/types/void_t -template -struct has_cutlass_op : std::false_type {}; - -// Specialization recognizes types that do support CUTLASS -template -struct has_cutlass_op().get_cutlass_op())>> - : std::true_type {}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/hamming.cuh b/cpp/include/cuvs/distance/detail/distance_ops/hamming.cuh deleted file mode 100644 index 7c6553f38..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/hamming.cuh +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the Hamming Unexpanded distance matrix calculation - * It computes the following equation: - * - * c_ij = sum_k (x_ik != y_kj) / k - */ -template -struct hamming_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - IdxT k; - - hamming_distance_op(IdxT k_) noexcept : k(k_) {} - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += (x != y); }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - const DataT one_over_k = DataT(1.0) / k; -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] *= one_over_k; - } - } - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/hellinger.cuh b/cpp/include/cuvs/distance/detail/distance_ops/hellinger.cuh deleted file mode 100644 index ad5ca3156..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/hellinger.cuh +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the Hellinger distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj)) - * - */ -template -struct hellinger_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - // This is sqrt(x) * sqrt(y). 
- const auto product = x * y; - acc += product; - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - const auto finalVal = (1 - acc[i][j]); - const auto rectifier = (!signbit(finalVal)); - acc[i][j] = raft::sqrt(rectifier * finalVal); - } - } - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/cuvs/distance/detail/distance_ops/jensen_shannon.cuh deleted file mode 100644 index 216639494..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/jensen_shannon.cuh +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include // raft::log -#include // DI - -namespace cuvs::distance::detail::ops { - -// Describes the computation the jensen_shannon distance - -/** - * @brief the Jensen Shannon distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) - * + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) - */ -template -struct jensen_shannon_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const DataT m = 0.5f * (x + y); - const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::log(m + m_zero); - - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(0.5 * acc[i][j]); - } - } - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/cuvs/distance/detail/distance_ops/kl_divergence.cuh deleted file mode 100644 index 929c3a559..000000000 --- 
a/cpp/include/cuvs/distance/detail/distance_ops/kl_divergence.cuh +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // raft::log -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the KL Divergence distance matrix calculation - * - * It computes the following equation: - * - * c_ij = 0.5 * sum(x * log (x / y)); - */ -template -struct kl_divergence_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - const bool is_row_major; - const bool x_equal_y; - - kl_divergence_op(bool row_major_, bool x_equal_y_ = false) noexcept - : is_row_major(row_major_), x_equal_y(x_equal_y_) - { - } - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - // TODO: make sure that these branches get hoisted out of main loop.. Could - // be quite expensive otherwise. 
- if (x_equal_y) { - if (is_row_major) { - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); - } else { - const bool y_zero = (y == 0); - const bool x_zero = (x == 0); - acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); - } - } else { - if (is_row_major) { - const bool x_zero = (x == 0); - acc += x * (raft::log(x + x_zero) - y); - } else { - const bool y_zero = (y == 0); - acc += y * (raft::log(y + y_zero) - x); - } - } - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = (0.5f * acc[i][j]); - } - } - } -}; -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/l1.cuh b/cpp/include/cuvs/distance/detail/distance_ops/l1.cuh deleted file mode 100644 index 76eaffaf3..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/l1.cuh +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the L1 distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sum_k abs(x_ik - y_kj) - */ -template -struct l1_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Do not load norms of data, the computation of L1 distance does not use them. - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += raft::abs(x - y); }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - return; - }; -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/cuvs/distance/detail/distance_ops/l2_exp.cuh deleted file mode 100644 index f45c41206..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/l2_exp.cuh +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * Reserve 1 digit of precision from each floating-point type - * for round-off error tolerance. - * @tparam DataT - */ -template -__device__ constexpr DataT get_clamp_precision() -{ - switch (sizeof(DataT)) { - case 2: return 1e-3; - case 4: return 1e-6; - case 8: return 1e-15; - default: return 0; - } -} - -// Epilogue operator for CUTLASS based kernel -template -struct l2_exp_cutlass_op { - bool sqrt; - - __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} - __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept - { - AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - - /** - * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal) - * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead. - */ - outVal = outVal * !((outVal * outVal < get_clamp_precision()) * (aNorm == bNorm)); - return sqrt ? raft::sqrt(outVal * (outVal > 0)) : outVal; - } - - __device__ AccT operator()(DataT aData) const noexcept { return aData; } -}; - -/** - * @brief the expanded euclidean distance matrix calculation - * - * It computes the following equation: - * - * c_ij = - 2 sum_k x_ik * y_kj + ||x_i.||_2 + ||y_.j||_2 - * - */ -template -struct l2_exp_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - const bool sqrt; - - l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} - - // Load norms of input data - static constexpr bool use_norms = true; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. 
- static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - DataT accVal = acc[i][j]; - DataT val = regxn[i] + regyn[j] - (DataT)2.0 * accVal; - - /** - * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product - * (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal - * instead. - */ - acc[i][j] = - val * (val > 0) * !((val * val < get_clamp_precision()) * (regxn[i] == regyn[j])); - } - } - if (sqrt) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - } - - constexpr l2_exp_cutlass_op get_cutlass_op() const - { - return l2_exp_cutlass_op(sqrt); - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/cuvs/distance/detail/distance_ops/l2_unexp.cuh deleted file mode 100644 index aa6cc27f3..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/l2_unexp.cuh +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the unexpanded euclidean distance matrix calculation - * - * It computes the following equation: - * - * c_ij = optional_sqrt ( sum_k (x_ik - y_kj)^2 ) - */ -template -struct l2_unexp_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - bool sqrt; - - l2_unexp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} - - // Do not load norms of data, the computation of L1 distance does not use them. - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. 
- template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = x - y; - acc += diff * diff; - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - if (sqrt) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - }; -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/l_inf.cuh b/cpp/include/cuvs/distance/detail/distance_ops/l_inf.cuh deleted file mode 100644 index d8f9384d7..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/l_inf.cuh +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the L_inf (Chebyshev) distance matrix calculation - * - * It computes the following equation: - * - * c_ij = max_k | x_ik - y_kj | - */ -template -struct l_inf_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = raft::abs(x - y); - acc = raft::max(acc, diff); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - return; - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/cuvs/distance/detail/distance_ops/lp_unexp.cuh deleted file mode 100644 index 6136f9f3e..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/lp_unexp.cuh +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // raft::pow, raft::abs -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the unexpanded Lp (Minkowski) distance matrix calculation - * - * It computes the following equation: - * - * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) - */ -template -struct lp_unexp_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - DataT p; - - lp_unexp_distance_op(DataT p_) noexcept : p(p_) {} - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. 
- template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = raft::abs(x - y); - acc += raft::pow(diff, p); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - const auto one_over_p = 1.0f / p; -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::pow(acc[i][j], one_over_p); - } - } - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/cuvs/distance/detail/distance_ops/russel_rao.cuh deleted file mode 100644 index 5dffdcdb8..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/russel_rao.cuh +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // DI - -namespace cuvs::distance::detail::ops { - -/** - * @brief the Russell Rao distance matrix calculation - * - * It computes the following equation: - * - * c_ij = (k - (sum_k x_ik * y_kj)) / k - */ -template -struct russel_rao_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - IdxT k; - const float one_over_k; - - russel_rao_distance_op(IdxT k_) noexcept : k(k_), one_over_k(1.0f / k_) {} - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = (k - acc[i][j]) * one_over_k; - } - } - } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/distance_ops/template.cuh b/cpp/include/cuvs/distance/detail/distance_ops/template.cuh deleted file mode 100644 index bdb933237..000000000 --- a/cpp/include/cuvs/distance/detail/distance_ops/template.cuh +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace cuvs::distance::detail::ops { - -// Describes the computation the template distance -// -// Fill in the TODO items. - -template -struct template_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - TODO member; - - template_distance_op(TODO member_) noexcept : member(member_) {} - - // Load norms of input data - static constexpr bool use_norms = TODO; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + TODO; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { TODO; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - TODO; - } - - // If exist, returns a cutlass op that performs the same operation. - // See cosine and l2_exp distance ops for an example. 
- constexpr l2_exp_cutlass_op get_cutlass_op() const { TODO; } -}; - -} // namespace cuvs::distance::detail::ops diff --git a/cpp/include/cuvs/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/cuvs/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h deleted file mode 100644 index f659ed256..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ /dev/null @@ -1,671 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. 
- -This file contains a customized version of EpilogueWithBroadcast from CUTLASS 2.9.1 -(https://github.com/NVIDIA/cutlass/blob/v2.9.1/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h) - -Changes: -- customized the compute_source_needed_() and apply_output_operator_() to suit the needs of per row -reduction -*/ - -#pragma once - -#if defined(__CUDACC_RTC__) -#include -#include -#else -#include -#include -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include - -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template -struct EpilogueWithBroadcastOpBaseCustom { - using ElementOutput = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using FragmentT = Array; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = StoreZ; - - /// If true, the 'T' tensor is stored - static bool const kStoreT = StoreT; - - /// Parameters structure - required - struct Params {}; - - // - // Methods - // - - /// Constructor from Params - EpilogueWithBroadcastOpBaseCustom(Params const& params_) {} - - /// Determine if the source is needed. 
May return false if - bool is_source_needed() const { return true; } - - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) {} - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentC const& frag_C, - FragmentCompute const& V) const - { - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentCompute const& V) const - { - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Epilogue operator with bias vector broadcast over columns. -/// -/// Computes the following: -/// -/// -/// Z, T = OutputOp(AB, C, Broadcast) -/// -/// if (ElementwiseOp::kStoreZ) { -/// store(converted_u); -/// } -/// -/// if (ElementwiseOp::kStoreT) { -/// store(v); -/// } -/// -template < - typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) - typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) - int PartitionsK, ///< Number of partitions of the K dimension - typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) - typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) - typename ElementVector_, ///< Pointer to broadcast vector - typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators - typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM - typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM - typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp - typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: - ///< MatrixShape) - int FragmentsPerPartition = 1, ///< Used to coarsten the 
epilogue granularity - int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large - (!IsEpilogueFunctorHeavy::value)> -class EpilogueWithBroadcastCustom : public EpilogueBase { - public: - using Base = EpilogueBase; - - using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; - static int const kPartitionsK = PartitionsK; - using OutputTileIterator = OutputTileIterator_; - using TensorTileIterator = TensorTileIterator_; - using ElementVector = ElementVector_; - using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; - using WarpTileIterator = WarpTileIterator_; - using SharedLoadIterator = SharedLoadIterator_; - using OutputOp = OutputOp_; - using Padding = Padding_; - - using Layout = layout::RowMajor; - using LongIndex = typename Layout::LongIndex; - - /// The complete warp-level accumulator tile - using AccumulatorTile = typename Base::AccumulatorTile; - - /// Accumulator element - using ElementAccumulator = typename WarpTileIterator::Element; - - /// Compute data type produced by the output op - using ElementCompute = typename OutputOp::ElementCompute; - - /// Compute fragment - using FragmentCompute = Array; - - /// Thread map used by output tile iterators - using ThreadMap = typename OutputTileIterator::ThreadMap; - - /// Fragment object used to store the broadcast values - using BroadcastFragment = - Array; - - /// Output element - using ElementOutput = typename OutputTileIterator::Element; - - /// Data type of additional tensor - using ElementTensor = typename TensorTileIterator::Element; - - /// Output access size - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - /// Tensor reference to destination tensor - using TensorRef = typename OutputTileIterator::TensorRef; - - /// Tensor reference to sync tensor - using SyncTensorRef = typename cutlass::TensorRef; - - /// Const tensor reference to source tensor - using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; - - /// 
Array type used to output - using OutputAccessType = - Array; - - /// Array type used by output functor - using AccumulatorAccessType = - Array; - - /// Array type used by output functor - using ComputeAccessType = Array; - - /// Tensor access type - using TensorAccessType = Array; - - /// Number of warps - using WarpCount = typename Base::WarpCount; - - /// Shared memory allocation from epilogue base class - using BaseSharedStorage = typename Base::SharedStorage; - - static int constexpr kSmemTiles = - Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; - static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; - - /// Used for the broadcast - struct BroadcastDetail { - /// Number of threads per warp - static int const kWarpSize = 32; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - /// Number of distinct scalar column indices handled by each thread - static int const kColumnsPerThread = - ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; - - /// Number of distinct scalar row indices handled by each thread - static int const kRowsPerThread = - ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; - - /// Number of threads per threadblock - static int const kThreadCount = kWarpSize * WarpCount::kCount; - - /// Number of distinct threads per row of output tile - static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); - - /// Number of distinct threads which must be reduced during the final reduction phase within the - /// threadblock. - static int const kThreadRows = kThreadCount / kThreadsPerRow; - - /// I'm not sure what I meant here. 
- static int const kThreadAccessesPerRow = - const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); - - /// Shape of the shared memory allocation for the epilogue - using StorageShape = MatrixShape; - - /// Debug printing - CUTLASS_DEVICE - static void print() - { -#if 0 - printf("BroadcastDetail {\n"); - printf( - " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" - "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", - kColumnsPerThread, - kRowsPerThread, - kThreadCount, - kThreadsPerRow, - kThreadRows, - kThreadAccessesPerRow, - StorageShape::kRow, - StorageShape::kColumn, - StorageShape::kCount - ); - printf("};\n"); -#endif - } - }; - - /// Shared storage structure (shadows base) with additional SMEM buffer for reduction - struct SharedStorage { - union { - BaseSharedStorage base; - }; - - CUTLASS_HOST_DEVICE - SharedStorage() {} - }; - - public: - static_assert(SharedLoadIterator::Fragment::kElements == TensorTileIterator::Fragment::kElements, - "Mismatch between shared load iterator and output tile iterator."); - - static_assert(OutputTileIterator::kElementsPerAccess, - "OutputTileIterator::kElementsPerAccess must not be zero."); - - static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), - "Divisibility"); - - private: - /// Loads fragment from shared memory aligned with output tensor - SharedLoadIterator shared_load_iterator_; - - /// Thread index within the threadblock - int thread_idx_; - - public: - /// Constructor - CUTLASS_DEVICE - EpilogueWithBroadcastCustom(SharedStorage& shared_storage, ///< Shared storage object - int thread_idx, ///< ID of a thread within the threadblock - int warp_idx, ///< ID of warp within threadblock - int lane_idx ///< Id of thread within warp - ) - : Base(shared_storage.base, thread_idx, warp_idx, lane_idx), - shared_load_iterator_(shared_storage.base.reference(), thread_idx), - thread_idx_(thread_idx) - { - } - 
- /// Streams the result to global memory - CUTLASS_DEVICE - void operator()( - OutputOp const& output_op, ///< Output operator - ElementVector const* broadcast_ptr, ///< Broadcast vector - AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix - TensorTileIterator - tensor_iterator, ///< Threadblock tile iterator for additional tensor operand - MatrixCoord const& - problem_size = ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord(Shape::kM, Shape::kN), - MatrixCoord const& - threadblock_offset = ///< Threadblock's initial offset within the problem size space - MatrixCoord()) - { - BroadcastFragment broadcast_fragment; - - load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); - - compute_source_needed_( - output_op, broadcast_fragment, accumulators, source_iterator, tensor_iterator); - } - - private: - CUTLASS_DEVICE - void load_broadcast_fragment_( - BroadcastFragment& - broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - ElementVector const* broadcast_ptr, ///< Broadcast vector - MatrixCoord const& - problem_size, ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord const& - threadblock_offset ///< Threadblock's initial offset within the problem size space - ) - { - broadcast_fragment.clear(); - - // If no pointer is supplied, set with all zeros and avoid memory accesses - if (!broadcast_ptr) { return; } - - int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); - - int thread_column_idx = threadblock_offset.column() + thread_initial_column; - broadcast_ptr += thread_initial_column; - - NumericArrayConverter - converter; - using AccessType = AlignedArray; - using ComputeFragmentType = Array; - - ComputeFragmentType* frag_ptr = reinterpret_cast(&broadcast_fragment); - - CUTLASS_PRAGMA_UNROLL - for 
(int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { - AccessType loaded; - - loaded.clear(); - - if (thread_column_idx < problem_size.column()) { - loaded = *reinterpret_cast(broadcast_ptr); - } - - ComputeFragmentType cvt = converter(loaded); - frag_ptr[j] = cvt; - - thread_column_idx += ThreadMap::Delta::kColumn; - broadcast_ptr += ThreadMap::Delta::kColumn; - } - } - - template - struct acc2smem_source_not_needed; - - template - struct acc2smem_source_not_needed> { - template - CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator& warp_tile_iterator) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - typename AccumulatorFragmentIterator::Fragment accum_fragment; - - accum_fragment_iterator.load(accum_fragment); - ++accum_fragment_iterator; - - warp_tile_iterator.store(accum_fragment); - if (p < Base::kFragmentsPerIteration - 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); - } - } - - if (Base::kFragmentsPerIteration > 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * - (1 - Base::kFragmentsPerIteration)); - } - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const& iterator_begin, - WarpTileIterator& warp_tile_iterator) - { - int dummy[] = { - (pos == (Seq * Base::kFragmentsPerIteration)) && - (helper(iterator_begin, warp_tile_iterator), 0)...}; - - CUTLASS_UNUSED(dummy[0]); - } - }; - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_not_needed_( - OutputOp const& output_op, ///< Output operator - BroadcastFragment const& - broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile - 
TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) - { - } - - template - struct acc2smem_source_needed; - - template - struct acc2smem_source_needed> { - template - CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator& warp_tile_iterator) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - typename AccumulatorFragmentIterator::Fragment accum_fragment; - accum_fragment_iterator.load(accum_fragment); - warp_tile_iterator.store(accum_fragment); - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const& iterator_begin, - WarpTileIterator& warp_tile_iterator) - { - int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; - } - }; - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_needed_( - OutputOp const& output_op, ///< Output operator - BroadcastFragment const& - broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator - source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) - { - typename OutputTileIterator::Fragment source_fragment; - source_fragment.clear(); - - // - // Iterator over warp-level accumulator fragment - // - - AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - - // - // Iterate over accumulator tile - // - -#pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - // - // Convert and store fragment - // - - //__syncthreads(); - - acc2smem_source_needed>::push( - iter, accum_fragment_iterator, this->warp_tile_iterator_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; - - shared_load_iterator_.load(aligned_accum_fragment[0]); - - // - // Apply output operation - // - - typename TensorTileIterator::Fragment frag_T; - - // - // Load the source - // - - source_iterator.load(source_fragment); - ++source_iterator; - - apply_output_operator_( - frag_T, output_op, aligned_accum_fragment[0], source_fragment, broadcast_fragment); - - // - // Conditionally store fragments - // - if (OutputOp::kStoreT) { - tensor_iterator.store(frag_T); - ++tensor_iterator; - } - } - } - - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_(typename TensorTileIterator::Fragment& frag_T, - OutputOp const& output_op, - typename SharedLoadIterator::Fragment const& frag_AB, - typename OutputTileIterator::Fragment const& frag_C, - BroadcastFragment const& frag_Broadcast) - { - using AccessTypeT = Array; - using AccessTypeBroadcast = Array; - - AccessTypeT* frag_T_ptr = reinterpret_cast(&frag_T); - - AccumulatorAccessType const* frag_AB_ptr = - reinterpret_cast(&frag_AB); - - OutputAccessType const* frag_C_ptr = reinterpret_cast(&frag_C); - - AccessTypeBroadcast const* frag_Broadcast_ptr = - reinterpret_cast(&frag_Broadcast); - - int const kOutputOpIterations = - TensorTileIterator::Fragment::kElements / TensorTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - output_op(frag_T_ptr[i], - frag_AB_ptr[i], - frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], - frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); - } - } - - /// 
Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_source_not_needed_( - typename OutputTileIterator::Fragment& frag_Z, - typename TensorTileIterator::Fragment& frag_T, - OutputOp const& output_op, - typename SharedLoadIterator::Fragment const& frag_AB, - BroadcastFragment const& frag_Broadcast) - { - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/cuvs/distance/detail/fused_distance_nn/cutlass_base.cuh deleted file mode 100644 index 7c0b5d127..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/cutlass_base.cuh +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#pragma GCC diagnostic ignored "-Wtautological-compare" - -// We define CUTLASS_NAMESPACE in case -// RAFT cmake is not used -#ifndef CUTLASS_NAMESPACE -#define cutlass raft_cutlass -#endif - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include // FusedDistanceNNEpilogueElementwise -#include // FusedDistanceNNGemm -#include // getMultiProcessorCount -#include // RAFT_CUTLASS_TRY - -namespace cuvs { -namespace distance { -namespace detail { - -template -void cutlassFusedDistanceNN(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - int* mutexes, - CGReduceOpT cg_reduce_op, - DistanceFn dist_op, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - cudaStream_t stream) -{ - using EpilogueOutputOp = cutlass::epilogue::thread::FusedDistanceNNEpilogueElementwise< - DataT, // ElementC_ - AccT, // ElementAccumulator_ - DataT, // ElementCompute_ - AccT, // ElementZ_ - OutT, // ElementT_ - // 128 / cutlass::sizeof_bits::value, - 1, // Elements per access 1 - DistanceFn, - CGReduceOpT, - ReduceOpT, - KVPReduceOpT>; - constexpr int batch_count = 1; - - typename EpilogueOutputOp::Params epilog_op_param( - dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); - - // Number of pipelines you want to use - constexpr int NumStages = 3; - // Alignment - constexpr int Alignment = VecLen; - - // default initialize problem size with row major inputs - auto problem_size = cutlass::gemm::GemmCoord(m, n, k); - - constexpr bool isRowMajor = true; - - using fusedDistanceNNKernel = - typename cutlass::gemm::kernel::FusedDistanceNNGemm::GemmKernel; - - using fusedDistanceNN = cutlass::gemm::device::GemmGrouped; - - int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); - int num_sms = raft::getMultiProcessorCount(); - int full_wave = 
num_blocks_per_sm * num_sms; - constexpr int mmaShapeM = fusedDistanceNNKernel::Mma::Shape::kM; - constexpr int mmaShapeN = fusedDistanceNNKernel::Mma::Shape::kN; - int columnTiles = (problem_size.n() - 1 + mmaShapeN) / mmaShapeN; - int rowTiles = (problem_size.m() - 1 + mmaShapeM) / mmaShapeM; - int totalTiles = columnTiles * rowTiles; - int thread_blocks = - rowTiles < full_wave ? (totalTiles < full_wave ? totalTiles : full_wave) : rowTiles; - - typename fusedDistanceNN::Arguments arguments{ - problem_size, - batch_count, // num of problems. - thread_blocks, - epilog_op_param, - x, - y, - xn, // C matrix eq vector param, which here is A norm - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)lda, // stride A - (int64_t)ldb, // stride B - (int64_t)1, // stride A norm - (int64_t)ldd // stride Output matrix - }; - - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = fusedDistanceNN::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - fusedDistanceNN fusedDistanceNN_op; - // Check the problem size is supported or not - RAFT_CUTLASS_TRY(fusedDistanceNN_op.can_implement(arguments)); - // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(fusedDistanceNN_op.initialize(arguments, workspace.data(), stream)); - // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(fusedDistanceNN_op.run(stream)); -} - -}; // namespace detail -}; // namespace distance -}; // namespace cuvs - -#pragma GCC diagnostic pop diff --git a/cpp/include/cuvs/distance/detail/fused_distance_nn/epilogue.cuh b/cpp/include/cuvs/distance/detail/fused_distance_nn/epilogue.cuh deleted file mode 100644 index 7053f2702..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/epilogue.cuh +++ 
/dev/null @@ -1,136 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) - -This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec -and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise -operation. --- A norm load is provided PredicatedTileIteratorNormVec --- B norm load is provided by EpilogueWithBroadcast --- elementwise operation is provided by OutputOp -*/ - -#pragma once - -#include -#include -#include - -#include - -#include -#include -#include -#include - -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for TensorOps. 
-template -struct FusedDistanceNNEpilogue { - /// Use defaults related to the existing epilogue - using Base = - DefaultEpilogueTensorOp; - - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using RowNormTileIterator = cutlass::epilogue::threadblock:: - PredicatedTileIteratorNormVecSmem; - - // - // Additional tensor tile iterator - stores t = Elementwise(z) - // - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec< - typename Base::OutputTileThreadMap, - ElementTensor, - LayoutT, - typename OutputOp::Params>; - - /// Define the epilogue - using Epilogue = cutlass::epilogue::threadblock::EpilogueWithBroadcastCustom< - Shape, - WarpMmaTensorOp, - PartitionsK, - RowNormTileIterator, - OutputTileIterator, - ElementVector, - typename Base::AccumulatorFragmentIterator, - typename Base::WarpTileIterator, - typename Base::SharedLoadIterator, - OutputOp, - typename Base::Padding, - Base::kFragmentsPerIteration>; -}; - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/cuvs/distance/detail/fused_distance_nn/epilogue_elementwise.cuh deleted file mode 100644 index a21f3d60e..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ /dev/null @@ -1,216 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -/*! \file - \brief Functor performing distance operations used by epilogues of pairwise distance - * kernels. -* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 -* customized for applying elementwise distance formula on accumulated GEMM value -* and applying user-defined operation which can convert distance values to key-value pair. -* . -*/ - -#pragma once - -#include -#include -#include -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template -class FusedDistanceNNEpilogueElementwise { - public: - using ElementOutput = ElementC_; - using ElementC = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - static int const kCount = kElementsPerAccess; - - using DistanceOp = DistanceOp_; - using CGReduceOp = CGReduceOp_; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using OutValT = typename CGReduceOp::AccTypeT; - using FragmentT = Array; - - using FragmentOutput = FragmentZ; - - static bool const kIsHeavy = true; // ElementwiseOp::kIsHeavy; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = false; // We don't store anything in Z, - - /// If true, the 'T' tensor is stored - static bool const kStoreT = true; // this is our final output storage. 
- - /// Host-constructable parameters structure - struct Params { - CGReduceOp_ cg_reduce_op; - DistanceOp_ dist_op_; - KVPReduceOpT_ pair_redop_; - ReduceOpT_ red_op_; - int* mutexes_; - using CGReduceT = CGReduceOp_; - // - // Methods - // - CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, - CGReduceOp cg_reduce_op, - ReduceOpT_ red_op, - KVPReduceOpT_ pair_redop, - int* mutexes) - : cg_reduce_op(cg_reduce_op), - dist_op_(dist_op), - pair_redop_(pair_redop), - red_op_(red_op), - mutexes_(mutexes) - { - } - - CUTLASS_HOST_DEVICE - Params() {} - }; - - private: - // - // Data members - // - DistanceOp_ elementwise_op; - KVPReduceOpT_ pair_redop; - - public: - ReduceOpT_ red_op; - - // - // Methods - // - - /// Constructor from Params - CUTLASS_HOST_DEVICE - FusedDistanceNNEpilogueElementwise(Params const& params) - : elementwise_op(params.dist_op_), pair_redop(params.pair_redop_), red_op(params.red_op_) - { - } - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const - { - // we use for making sure C matrix is used for A mat norm. 
- return true; - } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) {} - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()(FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentC const& frag_C, - FragmentCompute const& V) const - { - FragmentCompute tmp_Accum = - NumericArrayConverter()(AB); - FragmentCompute tmp_C = - NumericArrayConverter()(frag_C); - FragmentCompute result_Z; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerAccess; ++i) { - ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - frag_T[i] = res_Z; - } - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentCompute const& V) const - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/detail/fused_distance_nn/gemm.h b/cpp/include/cuvs/distance/detail/fused_distance_nn/gemm.h deleted file mode 100644 index fd5956a57..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/gemm.h +++ /dev/null @@ -1,410 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include -#include -#include - -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/* - * This configuration is used for float inputs with veclen(kAlignmentA/B) = 2 or 4, - * ideal threadblock tile shape is 32x256x16 for such cases as there is no - * registers spills for it. - * - */ -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct FusedDistanceNNGemm { - // This struct is specialized for fp32/3xTF32 - - /// Threadblock-level tile size (concept: GemmShape) - // <- threadblock tile M = 32, N = 256, K = 16 - // this is more performant but note that for veclen = 1 - // this shape has register spills - using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; - - // <- threadblock tile M = 32, N = 128, K = 16 - // this shape has high occupancy but less perf - // this is less performant but this shape has *no* register spills - // for any veclens(1, 2, 4) - // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes 
tile size a warp will compute - // <- warp tile M = 64, N = 64, K = 16 - // this is more performant for veclen 2,4. - using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; - - // this shape has high occupancy but less perf used for 32x128x16 - // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - // <- MMA Op tile M = 16, N = 8, K = 4 - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; - // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs - - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. 
- // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementAccumulator, - typename EpilogueOutputOp::ElementT, - ElementAccumulator, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; -}; - -/* - * This configuration is used for float inputs with veclen(kAlignmentA/B) = 1, - * ideal threadblock tile shape is 32x128x16 for such cases as there is no - * registers spills for it. - * - */ -template < - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct FusedDistanceNNGemm { - // This struct is specialized for fp32/3xTF32 - using ElementA_ = float; - using ElementB_ = float; - - /// Threadblock-level tile size (concept: GemmShape) - // <- threadblock tile M = 32, N = 128, K = 16 - // this shape has high occupancy and no register spills for veclen = 1. 
- using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - // <- warp tile M = 32, N = 32, K = 16 - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - // <- MMA Op tile M = 16, N = 8, K = 4 - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; - // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs - - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. 
- // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementAccumulator, - typename EpilogueOutputOp::ElementT, - ElementAccumulator, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; -}; - -template < - /// Layout type for A matrix operand - int kAlignmentA, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct FusedDistanceNNGemm { - // Threadblock-level tile size (concept: GemmShape) - // <- threadblock tile M = 64, N = 64, K = 16 - using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; - // using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - // <- warp tile M = 32, N = 32, K = 16 - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; - // using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using 
InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - - // Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAdd; - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementC_, - typename EpilogueOutputOp::ElementT, - ElementC_, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cpp/include/cuvs/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/cuvs/distance/detail/fused_distance_nn/persistent_gemm.h deleted file mode 100644 index 3a8d6c865..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/persistent_gemm.h +++ /dev/null @@ -1,515 +0,0 @@ 
-/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Problem visitor for grouped GEMMs -This file contains heavily customized version of GemmGrouped from CUTLASS 2.10.0 -(https://github.com/NVIDIA/cutlass/blob/v2.10.0/include/cutlass/gemm/kernel/gemm_grouped.h) - -Changes: -- adds support for only single problem size to be launched persistently - where each threablock processes more than one tile of the same problem. -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FusedDistanceNNPersistent { - public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = Transposed; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. 
- using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor = GemmGroupedProblemVisitor; - - // - // Structures - // - - struct temp_problem_visitor { - int problem_count; - - CUTLASS_HOST_DEVICE temp_problem_visitor() : problem_count(0){}; - CUTLASS_HOST_DEVICE temp_problem_visitor(int problem_count_) : problem_count(problem_count_){}; - }; - - /// Argument structure - struct Arguments { - // - // Data members - // - GemmCoord problem_sizes; - temp_problem_visitor problem_visitor; - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - void const* ptr_A; - void const* ptr_B; - void const* ptr_C; - void* ptr_Vector; - void* ptr_Tensor; - - typename LayoutA::Stride::Index lda; - typename LayoutB::Stride::Index ldb; - typename LayoutC::Stride::Index ldc; - 
typename LayoutC::Stride::Index ldt; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : // problem_count(0), - threadblock_count(0), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_Vector(nullptr), - ptr_Tensor(nullptr), - lda(0), - ldb(0), - ldc(0), - ldt(0), - host_problem_sizes(nullptr) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(GemmCoord problem_sizes, - int problem_count, - int threadblock_count, - typename EpilogueOutputOp::Params output_op, - void const* ptr_A, - void const* ptr_B, - void const* ptr_C, - void* ptr_Vector, - void* ptr_Tensor, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldt, - GemmCoord* host_problem_sizes = nullptr) - : problem_sizes(problem_sizes), - threadblock_count(threadblock_count), - output_op(output_op), - ptr_A(ptr_A), - ptr_B(ptr_B), - ptr_C(ptr_C), - ptr_Vector(ptr_Vector), - ptr_Tensor(ptr_Tensor), - lda(lda), - ldb(ldb), - ldc(ldc), - ldt(ldt), - host_problem_sizes(host_problem_sizes) - { - problem_visitor.problem_count = problem_count; - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - // typename ProblemVisitor::Params problem_visitor; - temp_problem_visitor problem_visitor; - int threadblock_count; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::TensorTileIterator::Params params_Tensor; - - typename EpilogueOutputOp::Params output_op; - - void* ptr_A; - void* ptr_B; - void* ptr_C; - void* ptr_Vector; - void* ptr_Tensor; - - GemmCoord problem_size; - typename LayoutA::Stride::Index lda; - typename LayoutB::Stride::Index ldb; - typename LayoutC::Stride::Index ldc; - typename 
LayoutC::Stride::Index ldt; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : params_A(0), - params_B(0), - params_C(0), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_Vector(nullptr), - ptr_Tensor(nullptr), - lda(0), - ldb(0), - ldc(0), - ldt(0) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_size(args.problem_sizes), - threadblock_count(args.threadblock_count), - output_op(args.output_op), - params_A(args.lda), - params_B(args.ldb), - params_C(args.ldc), - // Here we pass additional user args via args.output_op - // to the reduction output tile iterator - params_Tensor(args.ldt, args.output_op), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), - ptr_Vector(args.ptr_Vector), - ptr_Tensor(args.ptr_Tensor), - lda(args.lda), - ldb(args.ldb), - ldc(args.ldc), - ldt(args.ldt) - { - problem_visitor.problem_count = args.problem_visitor.problem_count; - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); - ptr_Vector = args.ptr_Vector; - ptr_Tensor = args.ptr_Tensor; - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldt = args.ldt; - - problem_size = args.problem_sizes; - } - }; - - /// Shared memory storage structure - struct SharedStorage { - union { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - } kernel; - - typename Epilogue::TensorTileIterator::SharedStorage reduced_store; - typename Epilogue::OutputTileIterator::SharedStorage rownorm_store; - }; - - public: - // - // Methods - // - - CUTLASS_DEVICE - FusedDistanceNNPersistent() {} - - /// Determines whether kernel satisfies alignment - static Status 
can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) { return Status::kSuccess; } - - static size_t get_extra_workspace_size(Arguments const& args, - cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - return 0; - } - - CUTLASS_DEVICE - static uint32_t tile_count(const cutlass::MatrixCoord& grid) - { - return grid.row() * grid.column(); - } - - /// Get the grid shape - CUTLASS_DEVICE - static cutlass::MatrixCoord grid_shape(const cutlass::gemm::GemmCoord& problem) - { - return cutlass::MatrixCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if __CUDA_ARCH__ >= 800 - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. 
- // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - const GemmCoord& problem_size = params.problem_size; - const auto grid_shape_ = grid_shape(problem_size); - const uint32_t problem_chunk = (tile_count(grid_shape_) - 1 + gridDim.x) / gridDim.x; - const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; - typename LayoutB::Index column = - ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; - - typename LayoutB::Index row = - ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; - if (column) { - shared_storage.reduced_store.initSmem(params.output_op); - shared_storage.rownorm_store.initSmem(params.ptr_C, problem_size.m(), row, sizeof(ElementC)); - } - - // Outer 'persistent' loop to iterate over tiles - for (uint32_t tile_idx = blockIdx.x * problem_chunk; tile_idx < problem_chunk_end; tile_idx++) { - const auto grid_shape_ = grid_shape(problem_size); - cutlass::MatrixCoord threadblock_offset( - int(tile_idx / grid_shape_.column()) * Mma::Shape::kM, - int(tile_idx % grid_shape_.column()) * Mma::Shape::kN); - - const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); - const bool doesRowChange = - ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); - const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? 
true : false; - - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{threadblock_offset.row(), 0}; - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.column()}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size.k(), problem_size.n()}, thread_idx, tb_offset_B); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Wait for all threads to finish their epilogue phases from the previous tile. - //__syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = static_cast(params.ptr_C); - typename Epilogue::ElementTensor* ptr_Tensor = - static_cast(params.ptr_Tensor); - - // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector* ptr_Vector = - static_cast(params.ptr_Vector); - - // Tile iterator loading from source tensor. 
- typename Epilogue::OutputTileIterator iterator_rownorm(shared_storage.rownorm_store, - params.params_C, - ptr_C, - problem_size.mn(), - thread_idx, - threadblock_offset); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator(shared_storage.reduced_store, - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, - problem_size.mn(), - thread_idx, - do_gmem_reduce, - threadblock_offset); - - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - // Move to appropriate location for this output tile - if (ptr_Vector) { ptr_Vector += threadblock_offset.column(); } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - ptr_Vector, - // iterator_D, - accumulators, - iterator_rownorm, - tensor_iterator, - problem_size.mn(), - threadblock_offset); - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/cuvs/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h deleted file mode 100644 index 14c09f6ae..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ /dev/null @@ -1,448 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) - -Changes: -- added `Layout_` template param -- Only the row index is used to load the data in load_with_byte_offset(). - This way the same normalization data is used across all columns in a row. - -*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. 
-/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template -class PredicatedTileIteratorNormVecSmem { - public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - using Layout = Layout_; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - ThreadMap::Count::kTile * ThreadMap::Delta::kRow; - - static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); - static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); - static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); - static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); - - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()) - { - } - - CUTLASS_HOST_DEVICE - Params(Base const& base) : Base(base) {} - }; - - /// Mask object - struct Mask { - static int const kCount = ThreadMap::Iterations::kColumn; - - /// Predicate state - bool predicates[kCount]; - 
- // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { enable(); } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - - /// Shared storage allocation needed by the predicated tile - // iterator for storing rowNorm chunk. - struct SharedStorage { - // - // Type definitions - // - using Shape = MatrixShape; - - /// Shape of the shared memory allocation - using StorageShape = MatrixShape; - - // - // Data members - // - // Methods - // - AlignedBuffer storage; - - CUTLASS_DEVICE - Element* data() { return storage.data(); } - - SharedStorage() {} - - CUTLASS_DEVICE - void initSmem(void* pointer, - const Index& num_rows, - const Index& tb_row_offset, - const LongIndex& stride) - { - Element* shared_elem_arr = data(); - uint8_t* first_tile_byte_pointer_ = - reinterpret_cast(pointer) + LongIndex(tb_row_offset) * LongIndex(stride); - const auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - bool guard = (tb_row_offset + row) < num_rows; - cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); - cutlass::arch::cp_async_wait<0>(); - } - } - }; - - private: - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. 
- PredicatedTileIteratorParams params_; - - /// Byte-level pointer - uint8_t* byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in rows - Index extent_column_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - - /// A thread's starting column - Index thread_start_column_; - - /// Internal state counter - int state_[3]; - - /// Scatter indices - int const* indices_; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); - - private: - // - // Methods - // - - protected: - SharedStorage& shared_storage_; - - public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorNormVecSmem(SharedStorage& shared_storage, - PredicatedTileIteratorParams const& params, - Element* pointer, - TensorCoord extent, - int thread_idx, - TensorCoord& threadblock_offset, - int const* indices = nullptr) - : params_(params), indices_(indices), shared_storage_(shared_storage) - { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - extent_column_ = extent.column(); - - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - mask_.predicates[c] = - ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); - } - - // Null pointer performs no accesses - if (!pointer) { - mask_.clear(); - return; - } - - if (ScatterD && !indices) { mask_.clear(); } - - // Initialize pointer - 
byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride); - - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - - if (threadblock_offset.column() == 0) { - shared_storage_.initSmem(pointer, extent_row_, threadblock_offset.row(), params_.stride); - } - - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const - { - AccessType* frag_ptr = reinterpret_cast(&frag); - - Element* shared_elem_arr = shared_storage_.data(); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - int iter_row = ((row_offset + thread_start_row_) % total_rows); - Element val = shared_elem_arr[iter_row]; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerAccess; ++i) { - (*frag_ptr)[frag_row_idx + i] = val; - } - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - - CUTLASS_DEVICE - MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } - - /// Need to get the thread start row from the tile iterator - 
CUTLASS_DEVICE - int32_t thread_start_row() const { return thread_start_row_; } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_column() const { return thread_start_column_; } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { return extent_row_; } - - /// Extent of the matrix in columns - CUTLASS_DEVICE - Index extent_column() const { return extent_column_; } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorNormVecSmem& operator++() - { - ++state_[0]; - - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - state_[0] = 0; - ++state_[1]; - byte_pointer_ += params_.advance_group; - - thread_start_row_ += - (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - state_[1] = 0; - ++state_[2]; - byte_pointer_ += params_.advance_cluster; - - thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * - ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { mask_.clear(); } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { mask_.enable(); } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff 
--git a/cpp/include/cuvs/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/cuvs/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h deleted file mode 100644 index dc224c5c9..000000000 --- a/cpp/include/cuvs/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ /dev/null @@ -1,626 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) - -Changes: -- added `Layout_` template param -- PredicatedTileIteratorParams() is customized to not stride by layout.stride(0). -- makes use of `SharedStorage` to store reduced values across warps to gmem in coalesced manner. -- customized the store_with_byte_offset() to perform reduction per row and write final value to -gmem. -- customized the Params() struct to take user inputs from epilogueOp params. 
- -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cg = cooperative_groups; - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. -/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template -class PredicatedTileIteratorReducedVec { - public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - using Layout = Layout_; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - using EpilogueOpParams = EpilogueOpParams_; - using OutIdxT = typename EpilogueOpParams::CGReduceT::IndexT; - using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); - static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); - static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); - static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); - static_assert(!UseCUDAStore, "UseCUDAStore path is not supported"); - - static int const total_rows = ThreadMap::kWarpCount * 
ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - ThreadMap::Count::kTile * ThreadMap::Delta::kRow; - /// Fragment object - using Fragment = - Array; - - // Memory access size - using AccessType = AlignedArray; - using AccessTypeValT = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - EpilogueOpParams user_param; - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()) - { - } - - CUTLASS_HOST_DEVICE - Params(Layout const& layout, EpilogueOpParams const& user_param_) - : PredicatedTileIteratorParams(int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()), - user_param(user_param_) - { - } - - CUTLASS_HOST_DEVICE - Params(Base const& base) : Base(base) {} - }; - - /// Mask object - struct Mask { - // static int const kCount = ThreadMap::Iterations::kColumn; - static int const kCount = ThreadMap::Iterations::kColumn * kElementsPerAccess; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { enable(); } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - - /// Shared storage allocation needed by the predicated tile - // iterator for reduction. 
- struct SharedStorage { - // - // Type definitions - // - using Shape = MatrixShape; - - /// Shape of the shared memory allocation for the reduced values store - using StorageShape = MatrixShape; - - // - // Data members - - // - // Methods - // - AlignedBuffer storage; - - CUTLASS_DEVICE - Element* data() { return storage.data(); } - - SharedStorage() {} - - CUTLASS_DEVICE - void initSmem(EpilogueOpParams const& user_params) - { - Element* shared_elem_arr = data(); - constexpr auto maxVal = std::numeric_limits::max(); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - user_params.red_op_.init(&shared_elem_arr[row], maxVal); - } - } - }; - - template - struct select_reduce { - /// Performs warp level reduction and stores a reduced output to memory - CUTLASS_DEVICE - select_reduce(OutT value, - ValT prev_red_val, - cg_reduce_op_t reduce_op, - cg_group_t cg_warp_group, - OutT& shmem_ptr) - { - if (cg_warp_group.any(reduce_op.isAmin(value, prev_red_val))) { - OutT reduced_val = cg::reduce(cg_warp_group, value, reduce_op); - if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; } - } - } - }; - - template - struct select_reduce> { - using ValT = float; - using Ty = raft::KeyValuePair; - /// Performs warp level reduction of key value pair and stores a reduced output to memory - CUTLASS_DEVICE - select_reduce(Ty val_to_red, - float prev_red_val, - cg_reduce_op_t cg_reduce_op, - cg_group_t cg_warp_group, - Ty& shmem_ptr) - { - ValT val = val_to_red.value; - - if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { - ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); - if (pred) { - if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } - } - } - } - }; - - template - struct select_reduce> { - using ValT = double; - using Ty = raft::KeyValuePair; - /// Performs warp level reduction of key value pair and stores a 
reduced output to memory - CUTLASS_DEVICE - select_reduce(Ty val_to_red, - double prev_red_val, - cg_reduce_op_t cg_reduce_op, - cg_group_t cg_warp_group, - Ty& shmem_ptr) - { - ValT val = val_to_red.value; - - if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { - ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); - if (pred) { - if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } - } - } - } - }; - - private: - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. - Params params_; - - /// Byte-level pointer - uint8_t* byte_pointer_; - /// Byte-level pointer first tile offset of this threadblock. - uint8_t* first_tile_byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in rows - Index extent_column_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - Index block_start_row_first_tile_; - - /// A thread's starting column - Index thread_start_column_; - - /// Internal state counter - int state_[3]; - // mutable int shared_tile_id; - - /// Scatter indices - int const* indices_; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(Params::stride) == 8, "Expected 64b strides"); - - protected: - SharedStorage& shared_storage_; - const bool& do_gmem_reduction_; - - private: - // - // Methods - // - public: - // - // Methods - // - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, - Params const& params, - Element* pointer, - TensorCoord extent, - int thread_idx, - const bool& 
do_gmem_reduction, - TensorCoord threadblock_offset = TensorCoord(), - int const* indices = nullptr) - : params_(params), - indices_(indices), - shared_storage_(shared_storage), - do_gmem_reduction_(do_gmem_reduction) - { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - extent_column_ = extent.column(); - - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - - TensorCoord block_offset = ThreadMap::initial_offset(0) + threadblock_offset; - block_start_row_first_tile_ = block_offset.row(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++c) { - int columnPerAccess = (c / kElementsPerAccess); - int columnWithinPerAccess = c % kElementsPerAccess; - mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * columnPerAccess + - columnWithinPerAccess) < extent.column()); - } - - if (threadblock_offset.column() == 0) { - EpilogueOpParams const& user_params = params_.user_param; - shared_storage_.initSmem(user_params); - } - - // Null pointer performs no accesses - if (!pointer) { mask_.clear(); } - - if (ScatterD && !indices) { mask_.clear(); } - - // Initialize pointer - first_tile_byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(block_offset.row()) * LongIndex(params_.stride); - - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - /// Destructor - CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() - { - if (do_gmem_reduction_) { - EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); - const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); 
- bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); - // If this is not optimal grid size perform mutex based gmem reduce. - if (useGmemMutex) { - // single lock per block for multiple rows - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. - unsigned int ns = 8; - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { - __nanosleep(ns); - if (ns < 256) { ns *= 2; } - } - } - } - - __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_( - block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); - } - } - - if (useGmemMutex) { - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // release mutex lock. - atomicExch(user_params.mutexes_ + mutex_id, 0); - } - } - } - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Performs reduction and Stores a reduced output to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const - { - AccessTypeValT* frag_ptr = reinterpret_cast(&frag); - - cg::thread_block cta = cg::this_thread_block(); - // tile_width 16 is required if kElementPerAccess > 1 - constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 
32 : 16; - cg::thread_block_tile tile32 = cg::tiled_partition(cta); - EpilogueOpParams const& user_params = params_.user_param; - - using cg_reduce_t = decltype(user_params.cg_reduce_op); - using tile32_t = decltype(tile32); - - Element* shared_elem_arr = shared_storage_.data(); - constexpr auto maxVal = std::numeric_limits::max(); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - const OutIdxT row_id = row_offset + thread_start_row_; - bool row_guard = (row_id < extent_row_); - - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; - Element red_val; - user_params.red_op_.init(&red_val, maxVal); - - if (row_guard) { - const int iter_row = (row_id % total_rows); - const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; - ++column) { - int columnPerAccess = column / kElementsPerAccess; - int columnWithPerAccess = column % kElementsPerAccess; - bool guard = mask_.predicates[column]; - if (guard) { - const OutIdxT key_id = thread_start_column_ + - ThreadMap::Delta::kColumn * columnPerAccess + - columnWithPerAccess; - const int frag_col_idx = frag_idx + column; - - Element this_val; - user_params.red_op_.init(&this_val, (*frag_ptr)[frag_col_idx]); - user_params.red_op_.init_key(this_val, key_id); - user_params.red_op_(row_id, &red_val, this_val); - } - } - // select_reduce doesn't need to use `red_op_` as at the warp level 
we use cg_reduce_op, - // this satisfies the requirement of mst/single linkage of checking colors buffer. - select_reduce red_obj( - red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); - } - } - } - } - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment& frag) const { store_with_byte_offset(frag, 0); } - - CUTLASS_DEVICE - MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_row() const { return thread_start_row_; } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_column() const { return thread_start_column_; } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { return extent_row_; } - - /// Extent of the matrix in columns - CUTLASS_DEVICE - Index extent_column() const { return extent_column_; } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorReducedVec& operator++() - { - ++state_[0]; - - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - state_[0] = 0; - ++state_[1]; - byte_pointer_ += params_.advance_group; - - thread_start_row_ += - (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - state_[1] = 0; - ++state_[2]; - byte_pointer_ += params_.advance_cluster; - - thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * - ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { mask_.clear(); } - - 
///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { mask_.enable(); } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/detail/fused_l2_nn.cuh b/cpp/include/cuvs/distance/detail/fused_l2_nn.cuh deleted file mode 100644 index 0c2548863..000000000 --- a/cpp/include/cuvs/distance/detail/fused_l2_nn.cuh +++ /dev/null @@ -1,385 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // size_t -#include // ops::l2_exp_distance_op -#include -#include // PairwiseDistances -#include // std::numeric_limits -#include // raft::KeyValuePair -#include // raft::identity_op -#include // Policy -#include // raft::util::arch::SM_* -#include // raft::ceildiv, raft::shfl - -namespace cuvs { -namespace distance { - -namespace detail { - -template -struct KVPMinReduceImpl { - typedef raft::KeyValuePair KVP; - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? 
b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - -}; // KVPMinReduce - -template -struct MinAndDistanceReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, KVP* out, const KVP& other) const - { - if (other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - - DI void operator()(LabelT rid, DataT* out, const KVP& other) const - { - if (other.value < *out) { *out = other.value; } - } - - DI void operator()(LabelT rid, DataT* out, const DataT& other) const - { - if (other < *out) { *out = other; } - } - - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; } - - DI void init_key(DataT& out, LabelT idx) const { return; } - DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } - - DI DataT get_value(KVP& out) const - { - return out.value; - ; - } - DI DataT get_value(DataT& out) const { return out; } -}; - -template -struct MinReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, DataT* out, const KVP& other) - { - if (other.value < *out) { *out = other.value; } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } -}; - -template -RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { redOp.init(min + tid, maxVal); } -} - -template -void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) -{ - auto blks = raft::ceildiv(m, 256); - initKernel<<>>(min, m, maxVal, redOp); -} - -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and raft::shfls -template -DI void updateReducedVal( - int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) -{ - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid 
= threadIdx.x / P::AccThCols; - - // Update each output row in order within a warp. This will resolve hang - // issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == j * P::AccThCols) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - } -} - -template -__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - OpT distance_op, - FinalLambda fin_op) -{ -// compile only if below non-ampere arch. -#if __CUDA_ARCH__ < 800 - extern __shared__ char smem[]; - - typedef KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT 
gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, - // but the raft::shfl op applies the modulo internally. - auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, gridStrideY); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - constexpr bool row_major = true; - constexpr bool write_out = false; - PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - nullptr, // Output pointer - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -#endif -} - -// cg::reduce functor for FusedDistanceNN used in its cutlass version -// to output the min distance value & key(loc id). -// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h -// store_with_byte_offset() passed to cg::reduce() & select_reduce. -template -struct kvp_cg_min_reduce_op { - typedef typename raft::KeyValuePair KVP; - - __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; - - using AccTypeT = AccType; - using IndexT = Index; - // functor signature. - __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } - - __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } - - __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? 
true : false; } -}; - -template -void fusedL2NNImpl(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - int* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // The kernel policy is determined by fusedL2NN. - typedef Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef raft::KeyValuePair KVPair; - - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - namespace arch = raft::util::arch; - using AccT = DataT; - ops::l2_exp_distance_op distance_op{sqrt}; - - raft::identity_op fin_op{}; - - auto kernel = fusedL2NNkernel; - - // Get pointer to fp32 SIMT kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - void* kernel_ptr = reinterpret_cast(kernel); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using L2Op = cuvs::distance::detail::ops::l2_exp_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; - kvp_cg_min_reduce_op_ cg_reduce_op; - L2Op L2_dist_op(sqrt); - - IdxT lda, ldb, ldd; - lda = k, ldb = k, ldd = n; - - cutlassFusedDistanceNN(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - min, - workspace, - cg_reduce_op, - L2_dist_op, - redOp, - pairRedOp, - stream); - } else { - // If device less than SM_80, use fp32 SIMT kernel. 
- constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); - - kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); - RAFT_CUDA_TRY(cudaGetLastError()); - } -} - -} // namespace detail -} // namespace distance -} // namespace cuvs diff --git a/cpp/include/cuvs/distance/detail/kernels/gram_matrix.cuh b/cpp/include/cuvs/distance/detail/kernels/gram_matrix.cuh deleted file mode 100644 index 1f4424ea9..000000000 --- a/cpp/include/cuvs/distance/detail/kernels/gram_matrix.cuh +++ /dev/null @@ -1,489 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -// #include -#include -#include - -#include -#include - -namespace cuvs::distance::kernels::detail { - -template -using dense_input_matrix_view_t = raft::device_matrix_view; -template -using dense_output_matrix_view_t = raft::device_matrix_view; -template -using csr_input_matrix_view_t = raft::device_csr_matrix_view; - -/** - * Base class for general Gram matrices - * A Gram matrix is the Hermitian matrix of inner probucts G_ik = - * Here, the inner product is evaluated for all elements from vectors sets X1, - * and X2. 
- * - * To be more precise, on exit the output buffer will store: - * - if is_row_major == true: out[j+k*n1] = , - * - if is_row_major == false: out[j*n2 + k] = , - * where x1_j is the j-th vector from the x1 set and x2_k is the k-th vector - * from the x2 set. - */ -template -class GramMatrixBase { - protected: - cublasHandle_t cublas_handle; - bool legacy_interface; - - public: - GramMatrixBase() : legacy_interface(false){}; - [[deprecated]] GramMatrixBase(cublasHandle_t cublas_handle) - : cublas_handle(cublas_handle), legacy_interface(true){}; - - virtual ~GramMatrixBase(){}; - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * Vector sets are provided in Matrix format - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void operator()(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1 = nullptr, - math_t* norm_x2 = nullptr) - { - evaluate(handle, x1, x2, out, norm_x1, norm_x2); - } - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * Vector sets are provided in Matrix format - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. 
- */ - void operator()(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1 = nullptr, - math_t* norm_x2 = nullptr) - { - evaluate(handle, x1, x2, out, norm_x1, norm_x2); - } - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * Vector sets are provided in Matrix format - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void operator()(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1 = nullptr, - math_t* norm_x2 = nullptr) - { - evaluate(handle, x1, x2, out, norm_x1, norm_x2); - } - - // unfortunately, 'evaluate' cannot be templatized as it needs to be virtual - - /** Evaluate the Gram matrix for two vector sets using simple dot product. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - virtual void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - linear(handle, x1, x2, out); - } - /** Evaluate the Gram matrix for two vector sets using simple dot product. 
- * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - virtual void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - linear(handle, x1, x2, out); - } - /** Evaluate the Gram matrix for two vector sets using simple dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - virtual void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - linear(handle, x1, x2, out); - } - - /** Evaluate the Gram matrix for two vector sets using simple dot product. 
- * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] virtual void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - } - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out - */ - [[deprecated]] void operator()(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1 = 0, - int ld2 = 0, - int ld_out = 0) - { - ASSERT(legacy_interface, "Legacy interface can only be used with legacy ctor."); - if (ld1 <= 0) { ld1 = is_row_major ? 
n_cols : n1; } - if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } - if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; } - evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - } - - protected: - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out - */ - [[deprecated]] void linear(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - math_t alpha = 1.0; - math_t beta = 0.0; - if (is_row_major) { - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - n2, - n1, - n_cols, - &alpha, - x2, - ld2, - x1, - ld1, - &beta, - out, - ld_out, - stream)); - } else { - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - n1, - n2, - n_cols, - &alpha, - x1, - ld1, - x2, - ld2, - &beta, - out, - ld_out, - stream)); - } - } - - protected: - bool get_is_row_major(dense_output_matrix_view_t matrix) - { - return (matrix.stride(1) == 1); - } - - bool get_is_row_major(dense_input_matrix_view_t matrix) - { - return (matrix.stride(1) == 1); - } - - bool 
get_is_col_major(dense_output_matrix_view_t matrix) - { - return (matrix.stride(0) == 1); - } - - bool get_is_col_major(dense_input_matrix_view_t matrix) - { - return (matrix.stride(0) == 1); - } - - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - */ - void linear(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out) - { - // check is_row_major consistency - bool is_row_major = get_is_row_major(x1) && get_is_row_major(x2) && get_is_row_major(out); - bool is_col_major = get_is_col_major(x1) && get_is_col_major(x2) && get_is_col_major(out); - ASSERT(is_row_major || is_col_major, - "GramMatrix leading dimensions for x1, x2 and out do not match"); - - // check dimensions - int n1 = out.extent(0); - int n2 = out.extent(1); - int n_cols = x1.extent(1); - ASSERT(x1.extent(0) == n1, "GramMatrix input matrix dimensions for x1 and out do not match"); - ASSERT(x2.extent(0) == n2, "GramMatrix input matrix dimensions for x2 and out do not match"); - ASSERT(x2.extent(1) == n_cols, "GramMatrix input matrix dimensions for x1 and x2 do not match"); - - // extract major stride - int ld1 = is_row_major ? x1.stride(0) : x1.stride(1); - int ld2 = is_row_major ? x2.stride(0) : x2.stride(1); - int ld_out = is_row_major ? 
out.stride(0) : out.stride(1); - - math_t alpha = 1.0; - math_t beta = 0.0; - if (is_row_major) { - // #TODO: Use mdspan-based API when stride-capable - // https://github.com/rapidsai/raft/issues/875 - raft::linalg::gemm(handle, - true, - false, - n2, - n1, - n_cols, - &alpha, - x2.data_handle(), - ld2, - x1.data_handle(), - ld1, - &beta, - out.data_handle(), - ld_out, - resource::get_cuda_stream(handle)); - } else { - // #TODO: Use mdspan-based API when stride-capable - // https://github.com/rapidsai/raft/issues/875 - raft::linalg::gemm(handle, - false, - true, - n1, - n2, - n_cols, - &alpha, - x1.data_handle(), - ld1, - x2.data_handle(), - ld2, - &beta, - out.data_handle(), - ld_out, - resource::get_cuda_stream(handle)); - } - } - - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - */ - void linear(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out) - { - // check is_row_major consistency - bool is_row_major = get_is_row_major(x2) && get_is_row_major(out); - bool is_col_major = get_is_col_major(x2) && get_is_col_major(out); - ASSERT(is_row_major || is_col_major, - "GramMatrix leading dimensions for x2 and out do not match"); - - // check dimensions - auto x1_structure = x1.structure_view(); - ASSERT(x1_structure.get_n_rows() == out.extent(0), - "GramMatrix input matrix dimensions for x1 and out do not match"); - ASSERT(x2.extent(0) == out.extent(1), - "GramMatrix input matrix dimensions for x2 and out do not match"); - ASSERT(x2.extent(1) == x1_structure.get_n_cols(), - "GramMatrix input matrix dimensions for x1 and x2 do not 
match"); - - math_t alpha = 1.0; - math_t beta = 0.0; - - raft::sparse::linalg::spmm(handle, false, true, &alpha, x1, x2, &beta, out); - } - - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - */ - void linear(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out) - { - // check layout consistency (w.r.t. strides a matrix might be both row & col major) - bool is_row_major_nopad = get_is_row_major(out) && out.stride(0) == out.extent(1); - bool is_col_major_nopad = get_is_col_major(out) && out.stride(1) == out.extent(0); - - ASSERT(is_row_major_nopad || is_col_major_nopad, - "Sparse linear Kernel distance does not support ld_out parameter"); - - // switch a,b based on is_row_major - if (is_col_major_nopad) { - auto out_row_major = raft::make_device_matrix_view( - out.data_handle(), out.extent(1), out.extent(0)); - raft::sparse::distance::pairwise_distance( - handle, x2, x1, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0); - } else { - auto out_row_major = raft::make_device_matrix_view( - out.data_handle(), out.extent(0), out.extent(1)); - raft::sparse::distance::pairwise_distance( - handle, x1, x2, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0); - } - } -}; - -}; // end namespace cuvs::distance::kernels::detail diff --git a/cpp/include/cuvs/distance/detail/kernels/kernel_factory.cuh b/cpp/include/cuvs/distance/detail/kernels/kernel_factory.cuh deleted file mode 100644 index d0f1f5569..000000000 --- a/cpp/include/cuvs/distance/detail/kernels/kernel_factory.cuh +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 
(c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "gram_matrix.cuh" -#include "kernel_matrices.cuh" -#include -#include - -namespace cuvs::distance::kernels::detail { - -template -class KernelFactory { - public: - static GramMatrixBase* create(KernelParams params) - { - GramMatrixBase* res; - // KernelParams is not templated, we convert the parameters to math_t here: - math_t coef0 = params.coef0; - math_t gamma = params.gamma; - switch (params.kernel) { - case LINEAR: res = new GramMatrixBase(); break; - case POLYNOMIAL: res = new PolynomialKernel(params.degree, gamma, coef0); break; - case TANH: res = new TanhKernel(gamma, coef0); break; - case RBF: res = new RBFKernel(gamma); break; - default: throw raft::exception("Kernel not implemented"); - } - return res; - } - - [[deprecated]] static GramMatrixBase* create(KernelParams params, cublasHandle_t handle) - { - GramMatrixBase* res; - // KernelParams is not templated, we convert the parameters to math_t here: - math_t coef0 = params.coef0; - math_t gamma = params.gamma; - switch (params.kernel) { - case LINEAR: res = new GramMatrixBase(handle); break; - case POLYNOMIAL: - res = new PolynomialKernel(params.degree, gamma, coef0, handle); - break; - case TANH: res = new TanhKernel(gamma, coef0, handle); break; - case RBF: res = new RBFKernel(gamma, handle); break; - default: throw raft::exception("Kernel not implemented"); - } - return res; - } 
-}; - -}; // end namespace cuvs::distance::kernels::detail diff --git a/cpp/include/cuvs/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/cuvs/distance/detail/kernels/kernel_matrices.cuh deleted file mode 100644 index 1f9db896e..000000000 --- a/cpp/include/cuvs/distance/detail/kernels/kernel_matrices.cuh +++ /dev/null @@ -1,777 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "gram_matrix.cuh" -#include - -#include -#include -#include -#include -#include - -namespace cuvs::distance::kernels::detail { - -/** Epiloge function for polynomial kernel without padding. - * Calculates output = (gain*in + offset)^exponent - * @param inout device vector in column major format, size [len] - * @param len array length - * @param exponent - * @param gain - * @param offset - */ -template -RAFT_KERNEL polynomial_kernel_nopad( - math_t* inout, size_t len, exp_t exponent, math_t gain, math_t offset) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; - tid += blockDim.x * gridDim.x) { - inout[tid] = pow(gain * inout[tid] + offset, exponent); - } -} - -/** Epiloge function for polynomial kernel with padding. 
- * Calculates output = (gain*input + offset)^exponent - * @param inout device vector in column major format, size [ld * cols] - * @param ld leading dimension of the inout buffer - * @param rows number of rows (rows <= ld) - * @param cols number of columns - * @param exponent - * @param gain - * @param offset - */ -template -RAFT_KERNEL polynomial_kernel( - math_t* inout, int ld, int rows, int cols, exp_t exponent, math_t gain, math_t offset) -{ - for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; - tidy += blockDim.y * gridDim.y) - for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; - tidx += blockDim.x * gridDim.x) { - inout[tidx + tidy * ld] = pow(gain * inout[tidx + tidy * ld] + offset, exponent); - } -} - -/** Epiloge function for tanh kernel without padding. - * Calculates output = tanh(gain*input + offset) - * @param inout device vector, size [len] - * @param len length of the input vector - * @param gain - * @param offset - */ -template -RAFT_KERNEL tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; - tid += blockDim.x * gridDim.x) { - inout[tid] = tanh(gain * inout[tid] + offset); - } -} - -/** Epiloge function for tanh kernel without padding. 
- * Calculates output = tanh(gain*input + offset) - * @param inout device vector in column major format, size [ld * cols] - * @param ld leading dimension of the inout buffer - * @param rows number of rows (rows <= ld) - * @param cols number of columns - * @param gain - * @param offset - */ -template -RAFT_KERNEL tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset) -{ - for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; - tidy += blockDim.y * gridDim.y) - for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; - tidx += blockDim.x * gridDim.x) { - inout[tidx + tidy * ld] = tanh(gain * inout[tidx + tidy * ld] + offset); - } -} - -/** Epiloge function for rbf kernel using expansion. - * - * Calculates output_ij = exp(-gain * (norm_x_i + norm_y_j - 2*input_ij)); - * - * Intended usage - * - input is the product of two matrices X and Y input_ij = sum_k X_ik * Y_jk - * - norm_x_i = l2_norm(x_i), where x_i is the i-th row of matrix X - * - norm_y_j = l2_norm(y_j), where y_j is the j-th row of matrix Y - * - * @param inout device vector in column major format, size [ld * cols] - * @param ld leading dimension of the inout buffer - * @param rows number of rows (rows <= ld) - * @param cols number of columns - * @param norm_x l2-norm of X's rows - * @param norm_y l2-norm of Y's rows - * @param gain - */ -template -RAFT_KERNEL rbf_kernel_expanded( - math_t* inout, int ld, int rows, int cols, math_t* norm_x, math_t* norm_y, math_t gain) -{ - for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; - tidy += blockDim.y * gridDim.y) { - math_t norm_y_val = norm_y[tidy]; - for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; - tidx += blockDim.x * gridDim.x) { - inout[tidx + tidy * ld] = - exp(-1.0 * gain * (norm_x[tidx] + norm_y_val - inout[tidx + tidy * ld] * 2)); - } - } -} - -namespace { -std::tuple generateLaunchConfig2dElementwiseOp(int n1, int n2) -{ - dim3 block_shape = dim3(32, 
4); - const int num_blocks_x = raft::ceildiv(n1, 32); - const int num_blocks_y = std::min(raft::ceildiv(n2, 32), (1 << 16) - 1); - dim3 grid_shape = dim3(num_blocks_x, num_blocks_y); - return std::make_tuple(grid_shape, block_shape); -} -} // namespace - -/** - * Create a kernel matrix using polynomial kernel function. - */ -template -class PolynomialKernel : public GramMatrixBase { - exp_t exponent; - math_t gain; - math_t offset; - - void applyKernel( - math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) - { - const int n_minor = is_row_major ? cols : rows; - if (ld == n_minor) { - polynomial_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( - inout, rows * cols, exponent, gain, offset); - } else { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); - polynomial_kernel<<>>( - inout, ld, n1, n2, exponent, gain, offset); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - public: - /** - * Constructs a polynomial kernel object. - * It evaluates the kernel matrix using the following formula: - * K_ij = (gain* + offset)^exponent - * - * @tparam math_t floating point type - * @tparam exp_t type of exponent - * @param exponent - * @param gain - * @param offset - */ - PolynomialKernel(exp_t exponent, math_t gain, math_t offset) - : GramMatrixBase(), exponent(exponent), gain(gain), offset(offset) - { - } - - [[deprecated]] PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t handle) - : GramMatrixBase(handle), exponent(exponent), gain(gain), offset(offset) - { - } - - /** Evaluate kernel matrix using polynomial kernel. - * - * output[i,k] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. 
- * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using polynomial kernel. - * - * output[i,k] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using polynomial kernel. 
- * - * output[i,k] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate the Gram matrix using the legacy interface. 
- * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - ASSERT(GramMatrixBase::legacy_interface, - "Legacy interface can only be used with legacy ctor."); - GramMatrixBase::linear( - x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - applyKernel(out, ld_out, n1, n2, is_row_major, stream); - } -}; - -/** - * Create a kernel matrix using tanh kernel function. - */ -template -class TanhKernel : public GramMatrixBase { - math_t gain, offset; - - void applyKernel( - math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) - { - const int n_minor = is_row_major ? cols : rows; - if (ld == n_minor) { - tanh_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( - inout, rows * cols, gain, offset); - } else { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); - tanh_kernel<<>>(inout, ld, n1, n2, gain, offset); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - public: - /** - * Constructs a tanh kernel object. 
- * It evaluates the kernel matrix using the following formula: - * K_ij = tanh(gain* + offset) - * - * @tparam math_t floating point type - * @param gain - * @param offset - */ - TanhKernel(math_t gain, math_t offset) : GramMatrixBase(), gain(gain), offset(offset) {} - - [[deprecated]] TanhKernel(math_t gain, math_t offset, cublasHandle_t handle) - : GramMatrixBase(handle), gain(gain), offset(offset) - { - } - - /** Evaluate kernel matrix using tanh kernel. - * - * output_[i + k*n1] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using tanh kernel. - * - * output_[i + k*n1] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. 
- */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using tanh kernel. - * - * output_[i + k*n1] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate the Gram matrix using the legacy interface. 
- * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - ASSERT(GramMatrixBase::legacy_interface, - "Legacy interface can only be used with legacy ctor."); - GramMatrixBase::linear( - x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - applyKernel(out, ld_out, n1, n2, is_row_major, stream); - } -}; - -/** - * Create a kernel matrix using RBF kernel function. - */ -template -class RBFKernel : public GramMatrixBase { - math_t gain; - - void applyKernel(math_t* inout, - int ld, - int rows, - int cols, - math_t* norm_x1, - math_t* norm_x2, - bool is_row_major, - cudaStream_t stream) - { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - math_t* norm_n1 = is_row_major ? norm_x2 : norm_x1; - math_t* norm_n2 = is_row_major ? norm_x1 : norm_x2; - auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); - rbf_kernel_expanded<<>>( - inout, ld, n1, n2, norm_n1, norm_n2, gain); - } - - public: - /** - * Constructs a RBF kernel object. 
- * It evaluates the kernel matrix using the following formula: - * K_ij = exp(-gain*|x1_i- x2_k|^2) - * - * @tparam math_t floating point type - * @param gain - */ - RBFKernel(math_t gain) : GramMatrixBase(), gain(gain) {} - - [[deprecated]] RBFKernel(math_t gain, cublasHandle_t handle) - : GramMatrixBase(handle), gain(gain) - { - } - - void matrixRowNormL2(raft::resources const& handle, - dense_input_matrix_view_t matrix, - math_t* target) - { - bool is_row_major = GramMatrixBase::get_is_row_major(matrix); - int minor = is_row_major ? matrix.extent(1) : matrix.extent(0); - int ld = is_row_major ? matrix.stride(0) : matrix.stride(1); - ASSERT(ld == minor, "RBF Kernel lazy rowNorm compute does not support ld parameter"); - raft::linalg::rowNorm(target, - matrix.data_handle(), - matrix.extent(1), - matrix.extent(0), - raft::linalg::NormType::L2Norm, - is_row_major, - resource::get_cuda_stream(handle)); - } - - void matrixRowNormL2(raft::resources const& handle, - csr_input_matrix_view_t matrix, - math_t* target) - { - auto matrix_structure = matrix.structure_view(); - raft::sparse::linalg::rowNormCsr(handle, - matrix_structure.get_indptr().data(), - matrix.get_elements().data(), - matrix_structure.get_nnz(), - matrix_structure.get_n_rows(), - target, - raft::linalg::NormType::L2Norm); - } - - /** Evaluate kernel matrix using RBF kernel. - * - * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and | | euclidean distance. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. 
- */ - void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - cudaStream_t stream = resource::get_cuda_stream(handle); - // lazy compute norms if not given - rmm::device_uvector tmp_norm_x1(0, stream); - rmm::device_uvector tmp_norm_x2(0, stream); - if (norm_x1 == nullptr) { - tmp_norm_x1.reserve(x1.extent(0), stream); - norm_x1 = tmp_norm_x1.data(); - matrixRowNormL2(handle, x1, norm_x1); - } - if (norm_x2 == nullptr) { - tmp_norm_x2.reserve(x2.extent(0), stream); - norm_x2 = tmp_norm_x2.data(); - matrixRowNormL2(handle, x2, norm_x2); - } - - // compute L2expanded - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - norm_x1, - norm_x2, - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using RBF kernel. - * - * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and | | euclidean distance. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. 
- */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - cudaStream_t stream = resource::get_cuda_stream(handle); - - // lazy compute norms if not given - rmm::device_uvector tmp_norm_x1(0, stream); - rmm::device_uvector tmp_norm_x2(0, stream); - if (norm_x1 == nullptr) { - tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); - norm_x1 = tmp_norm_x1.data(); - matrixRowNormL2(handle, x1, norm_x1); - } - if (norm_x2 == nullptr) { - tmp_norm_x2.reserve(x2.extent(0), stream); - norm_x2 = tmp_norm_x2.data(); - matrixRowNormL2(handle, x2, norm_x2); - } - - // compute L2expanded - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - norm_x1, - norm_x2, - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using RBF kernel. - * - * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and | | euclidean distance. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. 
- */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - cudaStream_t stream = resource::get_cuda_stream(handle); - - // lazy compute norms if not given - rmm::device_uvector tmp_norm_x1(0, stream); - rmm::device_uvector tmp_norm_x2(0, stream); - if (norm_x1 == nullptr) { - tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); - norm_x1 = tmp_norm_x1.data(); - matrixRowNormL2(handle, x1, norm_x1); - } - if (norm_x2 == nullptr) { - tmp_norm_x2.reserve(x2.structure_view().get_n_rows(), stream); - norm_x2 = tmp_norm_x2.data(); - matrixRowNormL2(handle, x2, norm_x2); - } - - // compute L2expanded - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - norm_x1, - norm_x2, - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate the Gram matrix using the legacy interface. 
- * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - ASSERT(GramMatrixBase::legacy_interface, - "Legacy interface can only be used with legacy ctor."); - int minor1 = is_row_major ? n_cols : n1; - int minor2 = is_row_major ? n_cols : n2; - int minor_out = is_row_major ? 
n2 : n1; - ASSERT(ld1 == minor1, "RBF Kernel distance does not support ld1 parameter"); - ASSERT(ld2 == minor2, "RBF Kernel distance does not support ld2 parameter"); - ASSERT(ld_out == minor_out, "RBF Kernel distance does not support ld_out parameter"); - - math_t gain = this->gain; - using index_t = int64_t; - - rbf_fin_op fin_op{gain}; - - raft::resources handle; - resource::set_cuda_stream(handle, stream); - - cuvs::distance::distance(handle, - const_cast(x1), - const_cast(x2), - out, - n1, - n2, - n_cols, - NULL, - 0, - fin_op, - is_row_major); - } -}; - -}; // end namespace cuvs::distance::kernels::detail diff --git a/cpp/include/cuvs/distance/detail/kernels/rbf_fin_op.cuh b/cpp/include/cuvs/distance/detail/kernels/rbf_fin_op.cuh deleted file mode 100644 index 73588baea..000000000 --- a/cpp/include/cuvs/distance/detail/kernels/rbf_fin_op.cuh +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -/* - * This file defines rbf_fin_op, which is used in GramMatrixBase. - * - * This struct has been moved to a separate file, so that it is cheap to include - * in distance/distance-ext.cuh, where an instance of cuvs::distance::distance - * with the rbf_fin_op is instantiated. - * - */ - -#include // raft::exp -#include // HD - -namespace cuvs::distance::kernels::detail { - -/** @brief: Final op for Gram matrix with RBF kernel. 
- * - * Calculates output = e^(-gain * in) - * - */ -template -struct rbf_fin_op { - OutT gain; - - explicit HD rbf_fin_op(OutT gain_) noexcept : gain(gain_) {} - - template - HDI OutT operator()(OutT d_val, Args... unused_args) - { - return raft::exp(-gain * d_val); - } -}; // struct rbf_fin_op - -} // namespace cuvs::distance::kernels::detail diff --git a/cpp/include/cuvs/distance/detail/masked_distance_base.cuh b/cpp/include/cuvs/distance/detail/masked_distance_base.cuh deleted file mode 100644 index 0c8db755b..000000000 --- a/cpp/include/cuvs/distance/detail/masked_distance_base.cuh +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include - -#include - -namespace cuvs { -namespace distance { -namespace detail { - -/** - * @brief Device class for masked nearest neighbor computations. - * - * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for x and y matrices) - * @tparam AccT accumulation data-type - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) - * @tparam EpilogueLambda applies an elementwise function to compute final - values. 
Its signature is: - template void epilogue_lambda - (AccT acc[][], DataT* regxn, DataT* regyn); - * @tparam FinalLambda the final lambda called on final distance value - * @tparam rowEpilogueLambda epilog lambda that executes when a full row has - * been processed. - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of x - * @param[in] n number of columns of y - * @param[in] k number of cols of x and y - * @param[in] lda leading dimension of x - * @param[in] ldb leading dimension of y - * @param[in] ldd parameter to keep Contractions_NT happy.. - * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine - * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `(m / 64) x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups The number of groups in group_idxs. - * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. - * @param core_op the core accumulation operation lambda - * @param epilog_op the epilog operation lambda - * @param fin_op the final gemm epilogue lambda - * @param rowEpilog_op epilog lambda that executes when a full row has been processed. 
- */ -template > -struct MaskedDistances : public BaseClass { - private: - typedef Policy P; - const DataT* xn; - const DataT* yn; - const DataT* const yBase; - const uint64_t* adj; - const IdxT* group_idxs; - IdxT num_groups; - char* smem; - CoreLambda core_op; - EpilogueLambda epilog_op; - FinalLambda fin_op; - rowEpilogueLambda rowEpilog_op; - - AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - public: - // Constructor - DI MaskedDistances(const DataT* _x, - const DataT* _y, - IdxT _m, - IdxT _n, - IdxT _k, - IdxT _lda, - IdxT _ldb, - IdxT _ldd, - const DataT* _xn, - const DataT* _yn, - const uint64_t* _adj, - const IdxT* _group_idxs, - IdxT _num_groups, - char* _smem, - CoreLambda _core_op, - EpilogueLambda _epilog_op, - FinalLambda _fin_op, - rowEpilogueLambda _rowEpilog_op) - : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - xn(_xn), - yn(_yn), - yBase(_y), - adj(_adj), - group_idxs(_group_idxs), - num_groups(_num_groups), - smem(_smem), - core_op(_core_op), - epilog_op(_epilog_op), - fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) - { - } - - DI void run() - { - const auto grid_stride_m = (P::Mblk * gridDim.y); - const auto grid_offset_m = (P::Mblk * blockIdx.y); - - const auto grid_stride_g = gridDim.x; - const auto grid_offset_g = blockIdx.x; - - for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { - // Start loop over groups - for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { - const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); - // block_adj is a bitfield that contains a 1 if a row is adjacent to the - // current group. All zero means we can skip this group. - if (block_adj == 0) { continue; } - - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). 
That is, - // for i = 0,.., AccRowsPerTh and j = 0,.., AccColsPerTh: - // - // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. - // - // We precompute this information because it is used in various - // locations to skip thread-local computations, specifically: - // - // 1. To skip computations if thread_adj == 0, i.e., none of the values - // of `acc` have to be computed. - // - // 2. In epilog_op, to consider only values of `acc` to be reduced that - // are not masked of. - // - // Note 1: Even when the computation can be skipped for a specific thread, - // the thread still participates in synchronization operations. - // - // Note 2: In theory, it should be possible to skip computations for - // specific rows of `acc`. In practice, however, this does not improve - // performance. - int thread_adj = compute_thread_adjacency(block_adj); - - auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; - const auto group_end_n = group_idxs[idx_g]; - for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { - // We provide group_end_n to limit the number of unnecessary data - // points that are loaded from y. - this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); - - reset_accumulator(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - if (thread_adj != 0) { accumulate(); } - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); - } - if (thread_adj != 0) { - accumulate(); // last iteration - } - // The pre-condition for the loop over tile_idx_n is that write_buffer - // and read_buffer point to the same buffer. This flips read_buffer - // back so that it satisfies the pre-condition of this loop. 
- this->switch_read_buffer(); - - if (useNorms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); - if (thread_adj != 0) { - epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); - } - } else { - if (thread_adj != 0) { - epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); - } - } - } // tile_idx_n - } // idx_g - rowEpilog_op(tile_idx_m); - } // tile_idx_m - } - - private: - DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) - { - // A single element of `adj` contains exactly enough bits to indicate which - // rows in the current tile to skip and which to compute. - static_assert(P::Mblk == 8 * sizeof(adj[0]), - "masked_l2_nn only supports a policy with 64 rows per block."); - IdxT block_flag_idx = tile_idx_m / P::Mblk; - // Index into adj at row tile_idx_m / 64 and column idx_group. - return adj[block_flag_idx * this->num_groups + idx_group]; - } - - DI uint32_t compute_thread_adjacency(const uint64_t block_adj) - { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the run() method. - uint32_t thread_adj = 0; -#pragma unroll - for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { - // Index `thread_row_idx` refers to a row of the current threads' register - // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the - // corresponding row of the current block tile in shared memory. - const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; - - // block_row_is_adjacent is true if the current block_row_idx is adjacent - // to the current group. 
- const uint64_t block_mask = 1ull << block_row_idx; - const bool block_row_is_adjacent = (block_adj & block_mask) != 0; - if (block_row_is_adjacent) { - // If block row is adjacent, write a 1 bit to thread_adj at location - // `thread_row_idx`. - const uint32_t thread_mask = 1 << thread_row_idx; - thread_adj |= thread_mask; - } - } - return thread_adj; - } - - DI void reset_accumulator() - { - // Reset accumulator registers to zero. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - } - - DI void accumulate() - { -#pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); - } - } - } - } - } - - DI void load_norms(IdxT tile_idx_m, - IdxT tile_idx_n, - IdxT end_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) - { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < end_n ? 
yn[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - } -}; // struct MaskedDistances - -}; // namespace detail -}; // namespace distance -}; // namespace cuvs diff --git a/cpp/include/cuvs/distance/detail/masked_nn.cuh b/cpp/include/cuvs/distance/detail/masked_nn.cuh deleted file mode 100644 index 8b30d8eec..000000000 --- a/cpp/include/cuvs/distance/detail/masked_nn.cuh +++ /dev/null @@ -1,327 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace cuvs { -namespace distance { -namespace detail { - -template -__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL masked_l2_nn_kernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const uint64_t* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - bool sqrt, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - CoreLambda core_op, - FinalLambda fin_op) -{ - extern __shared__ char smem[]; - - typedef raft::KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - int thread_adj, - DataT* regxn, - DataT* regyn, - IdxT tile_idx_n, - IdxT tile_idx_m, - IdxT tile_end_n) { - KVPReduceOpT pairRed_op(pairRedOp); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the maskedDistances.run() method. 
- const bool ignore = (thread_adj & (1 << i)) == 0; - if (ignore) { continue; } -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; - if (tile_end_n <= tmpkey) { - // Do not process beyond end of tile. - continue; - } - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < tile_end_n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - MaskedDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - adj, - group_idxs, - num_groups, - smem, - core_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -} - -/** - * @brief Wrapper for masked_l2_nn_kernel - * - * Responsibilities: - * - Allocate (and initialize) workspace memory for: - * - mutexes used in nearest neighbor update step - * - adjacency matrix bitfield - * - Compress adjacency matrix to bitfield - * - Initialize output buffer (conditional on `initOutBuffer`) - * - Specify core and final operations for the L2 norm - * - Determine optimal launch configuration for kernel. - * - Launch kernel and check for errors. 
- * - * @tparam DataT Input data-type (for x and y matrices). - * @tparam OutT Output data-type (for key-value pairs). - * @tparam IdxT Index data-type. - * @tparam ReduceOpT A struct to perform the final needed reduction - * operation and also to initialize the output array - * elements with the appropriate initial value needed for - * reduction. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * @param handle RAFT handle for managing expensive resources - * @param[out] out Will contain reduced output (nn key-value pairs) - * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) - * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) - * @param[in] xn L2 squared norm of `x`. Length = `m`. - * @param[in] yn L2 squared norm of `y`. Length = `n`. - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups Length of `group_idxs`. - * @param m Rows of `x`. - * @param n Rows of `y`. - * @param k Cols of `x` and `y`. - * @param redOp Reduction operator in the epilogue - * @param pairRedOp Reduction operation on key value pairs - * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm. 
- * @param initOutBuffer Whether to initialize the output buffer - * - * - */ -template -void masked_l2_nn_impl(raft::resources const& handle, - OutT* out, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const bool* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer) -{ - typedef typename linalg::Policy4x4::Policy P; - - static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block."); - - // Get stream and workspace memory resource - rmm::mr::device_memory_resource* ws_mr = - dynamic_cast(raft::resource::get_workspace_resource(handle)); - auto stream = resource::get_cuda_stream(handle); - - // Acquire temporary buffers and initialize to zero: - // 1) Adjacency matrix bitfield - // 2) Workspace for fused nearest neighbor operation - size_t m_div_64 = raft::ceildiv(m, IdxT(64)); - rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; - rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; - RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); - - // Compress boolean adjacency matrix to bitfield. - auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); - auto adj64_view = - raft::make_device_matrix_view(ws_adj64.data(), m_div_64, num_groups); - compress_to_bits(handle, adj_view, adj64_view); - - // Initialize output buffer with keyvalue pairs as determined by the reduction - // operator (it will be called with maxVal). 
- constexpr auto maxVal = std::numeric_limits::max(); - if (initOutBuffer) { - dim3 grid(raft::ceildiv(m, P::Nthreads)); - dim3 block(P::Nthreads); - - initKernel<<>>(out, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - auto fin_op = raft::identity_op{}; - - auto kernel = masked_l2_nn_kernel; - constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 block(P::Nthreads); - dim3 grid = launchConfigGenerator

(m, n, smemSize, kernel); - - kernel<<>>(out, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - sqrt, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -} // namespace detail -} // namespace distance -} // namespace cuvs diff --git a/cpp/include/cuvs/distance/detail/pairwise_distance_base.cuh b/cpp/include/cuvs/distance/detail/pairwise_distance_base.cuh deleted file mode 100644 index 57366dec9..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_distance_base.cuh +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include // raft::linalg::Contractions_NT -#include // ceildiv -#include // RAFT_CUDA_TRY - -#include // size_t - -namespace cuvs { -namespace distance { -namespace detail { - -/** - * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam OpT A distance operation, e.g., cosine_distance_op. - * @tparam EpilogueLambda applies an elementwise function to compute final - values. 
Its signature is: - template void epilogue_lambda - (AccT acc[][], DataT* regxn, DataT* regyn); - * @tparam FinalLambda the final lambda called on final distance value - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine - * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[output] pD output matrix - * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param distance_op the distance operation, e.g. cosine_distance_op - * @param epilog_op the epilog operation lambda - * @param fin_op the final gemm epilogue lambda - * @param rowEpilog_op epilog lambda that executes when a full row has been processed - */ - -template > -struct PairwiseDistances : public BaseClass { - // Get accumulation type from distance_op - using AccT = typename OpT::AccT; - - private: - typedef Policy P; - const DataT* xn; - const DataT* yn; - const DataT* const yBase; - OutT* dOutput; - char* smem; - OpT distance_op; - EpilogueLambda epilog_op; - FinalLambda fin_op; - rowEpilogueLambda rowEpilog_op; - - const IdxT grid_stride_m; - const IdxT grid_stride_n; - const IdxT grid_offset_m; - const IdxT grid_offset_n; - - AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - public: - // Constructor - DI PairwiseDistances(const DataT* _x, - const DataT* _y, - IdxT _m, - IdxT _n, - IdxT _k, - IdxT _lda, - IdxT _ldb, - IdxT _ldd, - const DataT* _xn, - const DataT* _yn, - OutT* _dOutput, - char* _smem, - OpT _distance_op, - EpilogueLambda _epilog_op, - FinalLambda _fin_op, - rowEpilogueLambda _rowEpilog_op) - : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - xn(_xn), - yn(_yn), - 
yBase(_y), - dOutput(_dOutput), - smem(_smem), - distance_op(_distance_op), - epilog_op(_epilog_op), - fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op), - grid_stride_m(P::Mblk * gridDim.y), - grid_stride_n(P::Nblk * gridDim.x), - grid_offset_m(P::Mblk * blockIdx.y), - grid_offset_n(P::Nblk * blockIdx.x) - { - } - - DI void run() - { - for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { - this->ldgXY(tile_idx_m, grid_offset_n, 0); - for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { - // Prolog: - reset_accumulator(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - - // Main loop: - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - accumulate(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); - } - accumulate(); // last iteration - // The pre-condition for the loop over tile_idx_n is that write_buffer - // and read_buffer point to the same buffer. This flips read_buffer back - // so that it satisfies the pre-condition of this loop. - this->switch_read_buffer(); - - // Epilog: - if (distance_op.use_norms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, regxn, regyn); - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_m, tile_idx_n); - // Calculate distance_op epilog. - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); - // And any possible additional epilogs - epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_m, tile_idx_n); - // Calculate distance_op epilog. 
- // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); - // And any possible additional epilogs - epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); - } - if (writeOut) { store_output(tile_idx_m, tile_idx_n); } - } - rowEpilog_op(tile_idx_m); - } - } - - private: - DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) - { - // Fetch next grid stride ldg if within range - const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; - const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; - if ((next_tile_tile_idx_n) < this->n) { - this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); - } else if ((next_tile_tile_idx_m) < this->m) { - this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); - } - } - - DI void reset_accumulator() - { - // Reset accumulator registers to zero. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - } - - DI void accumulate_reg_tile(DataT (®_x)[P::AccRowsPerTh][P::Veclen], - DataT (®_y)[P::AccColsPerTh][P::Veclen]) - { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); - } - } - } - } - - DI void accumulate() - { - // We have a separate raft::ldsXY and accumulate_reg_tile outside the loop body, - // so that these separated calls can be interspersed with preceding and - // following instructions, thereby hiding latency. - this->ldsXY(0); - - // If expensive inner loop, do not unroll loop. - constexpr int num_iterations = P::Kblk / P::Veclen - 1; - constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 
1 : num_iterations; -#pragma unroll unroll_count - for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) { - accumulate_reg_tile(this->regx, this->regy); - this->ldsXY(ki); - } - - // Accumulate last loaded tile. - accumulate_reg_tile(this->regx, this->regy); - } - - DI void load_norms(IdxT tile_idx_m, - IdxT tile_idx_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) - { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (tile_idx_n == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - } - - DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) - { - IdxT starty = tile_idx_m + this->accrowid; - IdxT startx = tile_idx_n + this->acccolid; - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } - } - } - } -}; // struct PairwiseDistances - -template -dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) -{ - int devId; - RAFT_CUDA_TRY(cudaGetDevice(&devId)); - int numSMs; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, 
cudaDevAttrMultiProcessorCount, devId)); - - int numBlocksPerSm = 0; - dim3 grid; - - RAFT_CUDA_TRY( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, func, P::Nthreads, sMemSize)); - std::size_t minGridSize = numSMs * numBlocksPerSm; - std::size_t yChunks = raft::ceildiv(m, P::Mblk); - std::size_t xChunks = raft::ceildiv(n, P::Nblk); - grid.y = yChunks > minGridSize ? minGridSize : yChunks; - grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; - if (grid.x != 1) { - std::size_t i = 1; - while (grid.y * i < minGridSize) { - i++; - } - grid.x = i >= xChunks ? xChunks : i; - } - - return grid; -} - -}; // namespace detail -}; // namespace distance -}; // namespace cuvs diff --git a/cpp/include/cuvs/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/cuvs/distance/detail/pairwise_distance_cutlass_base.cuh deleted file mode 100644 index b9dd49977..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_distance_cutlass_base.cuh +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#pragma GCC diagnostic ignored "-Wtautological-compare" - -// We define CUTLASS_NAMESPACE in case -// RAFT cmake is not used -#ifndef CUTLASS_NAMESPACE -#define cutlass raft_cutlass -#endif - -#include -#include - -#include -#include -#include - -#include -#include -#include -#include - -#include -#include - -#include "./pairwise_distance_epilogue_elementwise.h" -#include "./pairwise_distance_gemm.h" - -namespace cuvs { -namespace distance { -namespace detail { - -template -std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - OpT distance_op, - cudaStream_t stream) -{ - static_assert(!(std::is_same::value), - "OutType bool is not supported use uint8_t instead"); - - auto dist_op = distance_op.get_cutlass_op(); - using DistanceFn = decltype(dist_op); - using EpilogueOutputOp = - cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise; - constexpr int batch_count = 1; - - constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; - - typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); - - const DataT *a, *b; - - IdxT gemm_lda, gemm_ldb; - - // Number of pipelines you want to use - constexpr int NumStages = 3; - // Alignment - constexpr int Alignment = VecLen; - - // default initialize problem size with row major inputs - auto problem_size = cutlass::gemm::GemmCoord(n, m, k); - - using cutlassDistKernel = - typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; - - using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; - - if constexpr (isRowMajor) { - a = y; - b = x; - gemm_lda = ldb; - gemm_ldb = lda; - } else { - problem_size = cutlass::gemm::GemmCoord(m, n, k); - a = x; - b = y; - gemm_lda = lda; - gemm_ldb = ldb; - } - - typename 
cutlassDist::Arguments arguments{ - mode, problem_size, batch_count, epilog_op_param, a, b, - xn, // C matrix eq vector param, which here is A norm - nullptr, // tensor_Z, - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t)0, // batch stride B - (int64_t)0, // batch stride Norm A - (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - gemm_lda, // stride A - gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - ldd // stride Output matrix - }; - - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; - // Check the problem size is supported or not - RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); - - // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); - - // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(cutlassDist_op(stream)); -} - -}; // namespace detail -}; // namespace distance -}; // namespace cuvs - -#pragma GCC diagnostic pop diff --git a/cpp/include/cuvs/distance/detail/pairwise_distance_epilogue.h b/cpp/include/cuvs/distance/detail/pairwise_distance_epilogue.h deleted file mode 100644 index 06b83ace9..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_distance_epilogue.h +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) - -This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec -and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise -operation. --- A norm load is provided PredicatedTileIteratorNormVec --- B norm load is provided by EpilogueWithBroadcast --- elementwise operation is provided by OutputOp -*/ - -#pragma once - -#include -#include -#include - -#include - -#include "./predicated_tile_iterator_normvec.h" -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for TensorOps. 
-template -struct PairwiseDistanceEpilogue { - /// Use defaults related to the existing epilogue - using Base = - DefaultEpilogueTensorOp; - - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using OutputTileIterator = cutlass::epilogue::threadblock:: - PredicatedTileIteratorNormVec; - - // - // Additional tensor tile iterator - stores t = Elementwise(z) - // - using TensorTileIterator = - cutlass::epilogue::threadblock::PredicatedTileIterator; - - /// Define the epilogue - using Epilogue = EpilogueWithBroadcast; -}; - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/cuvs/distance/detail/pairwise_distance_epilogue_elementwise.h deleted file mode 100644 index 9004bd2c7..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_distance_epilogue_elementwise.h +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -/*! \file - \brief Functor performing distance operations used by epilogues of pairwise distance - * kernels. 
-* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 -* customized for applying elementwise distance formula on accumulated GEMM value -* and applying user-defined final custom operation on the distance value. -*/ - -#pragma once - -#include -#include -#include -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template -class PairwiseDistanceEpilogueElementwise { - public: - using ElementOutput = ElementC_; - using ElementC = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - static int const kCount = kElementsPerAccess; - - using DistanceOp = DistanceOp_; - using FinalOp = FinalOp_; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using FragmentT = Array; - - using FragmentOutput = FragmentZ; - - static bool const kIsHeavy = false; // ElementwiseOp::kIsHeavy; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = false; // We don't store anything in Z, - - /// If true, the 'T' tensor is stored - static bool const kStoreT = true; // this is our final output storage. 
- - /// Host-constructable parameters structure - struct Params { - FinalOp_ final_op_; - DistanceOp_ dist_op_; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, FinalOp final_op) : final_op_(final_op), dist_op_(dist_op) {} - - CUTLASS_HOST_DEVICE - Params() {} - }; - - private: - // - // Data members - // - FinalOp_ final_op; - DistanceOp_ elementwise_op; - - public: - // - // Methods - // - - /// Constructor from Params - CUTLASS_HOST_DEVICE - PairwiseDistanceEpilogueElementwise(Params const& params) - : final_op(params.final_op_), elementwise_op(params.dist_op_) - { - } - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const - { - // we use for making sure C matrix path is used for A mat norm. - return true; - } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) {} - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentC const& frag_C, - FragmentCompute const& V) const - { - FragmentCompute tmp_Accum = - NumericArrayConverter()(AB); - FragmentCompute tmp_C = - NumericArrayConverter()(frag_C); - FragmentCompute result_Z; - FragmentCompute result_T; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerAccess; ++i) { - result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - result_T[i] = final_op(result_Z[i], 0); - } - - NumericArrayConverter convert_t; - frag_T = convert_t(result_T); - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentCompute const& V) const - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace epilogue -} // 
namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/detail/pairwise_distance_gemm.h b/cpp/include/cuvs/distance/detail/pairwise_distance_gemm.h deleted file mode 100644 index 2c88d8b70..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_distance_gemm.h +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -#include -#include -#include -#include - -#include "./pairwise_distance_epilogue.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Element type for final output - // typename ElementOutT, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct PairwiseDistanceGemm { - // This struct is specialized for fp32/3xTF32 - - /// Threadblock-level tile size (concept: GemmShape) - using ThreadblockShape = - cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using InstructionShape = - cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 4 - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; - - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = 
cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementAccumulator, - typename EpilogueOutputOp::ElementT, - ElementAccumulator, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue; -}; - -template < - /// Layout type for A matrix operand - int kAlignmentA, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct PairwiseDistanceGemm { - // using Transform = cutlass::ComplexTransform::kNone; - // Threadblock-level tile size (concept: GemmShape) - using ThreadblockShape = - cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 64, N = 64, K = 16 - /// 
Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 32, N = 32, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - - // Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAdd; - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementC_, - typename EpilogueOutputOp::ElementT, - ElementC_, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git 
a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch-ext.cuh deleted file mode 100644 index efaebb379..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // ops::* -#include // ops::has_cutlass_op -#include // rbf_fin_op -#include // pairwise_matrix_params -#include // raft::identity_op -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::distance::detail { - -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) RAFT_EXPLICIT; - -}; // namespace cuvs::distance::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ - OpT, DataT, AccT, OutT, FinOpT, IdxT) \ - extern template void cuvs::distance::detail:: \ - pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ - OpT distance_op, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - const DataT* x, \ - const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ - OutT* out, \ - FinOpT fin_op, \ - cudaStream_t stream, \ - bool is_row_major) - -/* - * Hierarchy 
of instantiations: - * - * This file defines extern template instantiations of the distance kernels. The - * instantiation of the public API is handled in cuvs/distance/distance-ext.cuh. - * - * After adding an instance here, make sure to also add the instance there. - */ - -// The following two instances are used in the RBF kernel object. Note the use of int64_t for the -// index type. -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_unexp_distance_op, - float, - float, - float, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_unexp_distance_op, - double, - double, - double, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); - -// Rest of instances -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::canberra_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::correlation_distance_op, - float, - float, - float, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::correlation_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); 
-instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::hellinger_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::jensen_shannon_distance_op, - float, - float, - float, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::jensen_shannon_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); 
-instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l2_unexp_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::lp_unexp_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - cuvs::distance::detail::ops::russel_rao_distance_op, - double, - double, - double, - raft::identity_op, - int); - -#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch-inl.cuh deleted file mode 100644 index ca011731e..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -/* This file has two responsibilities: - * - * 1. Dispatch to the correct implementation of a kernel based on the - * architecture of the device on which the kernel will be launched. For - * instance, the cosine distance has a CUTLASS-based implementation that can - * be used on SM80+ and the normal implementation that is used on older - * architectures. - * - * 2. Provide concise function templates that can be instantiated in - * src/distance/detail/pairwise_matrix/. Previously, - * cuvs::distance::detail::distance was instantiated. The function - * necessarily required a large set of include files, which slowed down the - * build. The cuvs::distance::detail::pairwise_matrix_arch_dispatch functions - * do not require as large an include files set, which speeds up the build. - */ - -#include // ops::has_cutlass_op -#include // dispatch_sm60 -#include // pairwise_matrix_params -#include // raft::util::arch::SM_* - -// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. -// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). -// Therefore, it is the including file's responsibility to include the correct -// dispatch_smXX.cuh headers, as is done in cuvs/distance/detail/distance.cuh -// and src/distance/detail/pairwise_matrix/dispatch_*.cu. - -namespace cuvs::distance::detail { - -// This forward-declaration ensures that we do not need to include -// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling -// all the non-CUTLASS based distance instantiations faster. 
For CUTLASS-based -// distances, dispatch_sm80.cuh has to be included by the file including this -// file. -template -void pairwise_matrix_sm80_dispatch(OpT, - pairwise_matrix_params, - SM_compat_t, - cudaStream_t); - -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } - - // Dispatch rule: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel below SM_80 - namespace arch = raft::util::arch; - - constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - - if constexpr (cutlass_op_unavailable) { - // Always execute legacy kernels when no cutlass op is available - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else { - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - - // Get pointer to SM60 kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. 
See: - // https://github.com/NVIDIA/cub/issues/545 - auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); - void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); - } else { - // Reuse kernel wrapper that we obtained above. This avoids performing the - // dispatch twice. - sm60_wrapper.launch(distance_op, params, stream); - } - } -} - -}; // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch.cuh deleted file mode 100644 index 4a52b7ebe..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "dispatch-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "dispatch-ext.cuh" -#endif diff --git a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_layout.cuh deleted file mode 100644 index 2e9004b56..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // std::min -#include // size_t -#include // pairwise_matrix_params -#include // RAFT_EXPECTS -#include // std::integral_constant -namespace cuvs::distance::detail { - -/** - * @brief: Computes minimal common alignment of the rows in a 2D array in bytes - * - * The 2D matrix `x` is assumed to be row-major. This function computes the - * minimal alignment in bytes of the first elements of each row. - * Output can be 16, 8, 4, 2, 1. - * - * @param x Base pointer of row-major input matrix - * @param stride Stride in number of element between consecutive rows. 
- */ -template -size_t alignment_of_2d_array(const DataT* x, size_t stride) -{ - auto base = reinterpret_cast(x); - size_t stride_bytes = sizeof(DataT) * stride; - - for (int align = 16; align >= 0; align /= 2) { - bool base_aligned = base % align == 0; - bool stride_aligned = stride_bytes % align == 0; - if (base_aligned && stride_aligned) { return align; } - } - return 1; -} - -/** - * @brief: Computes the vec_len parameter kernel policy parameter - * - * @param params Kernel parameters - */ -template -int determine_vec_len(pairwise_matrix_params params) -{ - size_t align_x = alignment_of_2d_array(params.x, params.ldx); - size_t align_y = alignment_of_2d_array(params.y, params.ldy); - size_t byte_alignment = min(align_x, align_y); - - // Since alignment is in bytes, it could be smaller than sizeof(DataT). - // Handle this (unlikely) case here. - RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, - "Input matrix must be aligned to size of elements."); - - // Compute number of elements that can be loaded in one instruction - // without causing misalignent errors. - int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; - - // In the future, pairwise_matrix might support `int8_t` input. In that case, - // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to - // prevent adding more cases in dispatch_layout below (which are expensive to - // compile). - vec_len_aligned = std::min(vec_len_aligned, 4); - - return vec_len_aligned; -} - -template -using vec_len_constant = std::integral_constant; - -/** - * @brief: Converts run-time arguments to compile-time arguments - * - * Converts run-time arguments row_major and vec_len to compile-time arguments - * and dispatches a lambda f with these compile-time arguments. - * - * This is equivalent to copying and pasting the lambda function `f` in each of - * the switch case statements. - * - * @tparam F Type of lambda f. 
- * @param row_major Boolean indicating whether input arrays have row-major layout. - * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of - * the KernelPolicy. - * @param f Lambda that takes two std::integral_constant parameters representing - * row_major and vec_len. - */ -template -auto dispatch_layout(bool row_major, int vec_len, F&& f) -{ - if (row_major) { - switch (vec_len) { - case 4: return f(std::true_type(), vec_len_constant<4>()); - case 2: return f(std::true_type(), vec_len_constant<2>()); - default: return f(std::true_type(), vec_len_constant<1>()); - } - } else { - switch (vec_len) { - case 4: return f(std::false_type(), vec_len_constant<4>()); - case 2: return f(std::false_type(), vec_len_constant<2>()); - default: return f(std::false_type(), vec_len_constant<1>()); - } - } -} - -}; // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_sm60.cuh deleted file mode 100644 index 9f9ed1cad..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include // std::min -#include // dispatch_layout -#include // pairwise_matrix_sm60_wrapper -#include // raft::linalg::Policy4x4 - -namespace cuvs::distance::detail { - -template -pairwise_matrix_sm60_wrapper pairwise_matrix_sm60_get_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - int vec_len = determine_vec_len(params); - - // f takes compile-time constants row_major and vec_len aligned and returns - // the corresponding kernel wrapper. The wrapper contains the launch - // parameters of the kernel: a pointer to the kernel function, grid size, - // block size, and shared memory size. - auto f = [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // To keep compile times in check, we only specialize on veclen > 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - - using RowPolicy = typename raft::linalg::Policy4x4::Policy; - using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; - using Policy = typename std::conditional::type; - - auto wrapper = - make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); - - return wrapper; - }; - - // Dispatch_layout calls f with appropriate compile time constants based on - // the runtime values of params.is_row_major and vec_len. 
- return dispatch_layout(params.is_row_major, vec_len, f); -} - -template -void pairwise_matrix_sm60_dispatch(OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range, - cudaStream_t stream) -{ - auto wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, sm_compat_range); - - wrapper.launch(distance_op, params, stream); -} - -} // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_sm80.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_sm80.cuh deleted file mode 100644 index ccff73658..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/dispatch_sm80.cuh +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // std::min -#include // cutlassDistanceKernel -#include // dispatch_layout - -namespace cuvs::distance::detail { - -template -void pairwise_matrix_sm80_dispatch(OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range, - cudaStream_t stream) -{ - int vec_len = determine_vec_len(params); - - // f takes compile-time constants row_major and vec_len aligned and runs the - // corresponding cutlass launch code. - auto f = [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. 
- - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); - - using AccT = typename OpT::AccT; - cutlassDistanceKernel(params.x, - params.y, - params.x_norm, - params.y_norm, - params.m, - params.n, - params.k, - params.ldx, - params.ldy, - params.ld_out, - params.out, - params.fin_op, - distance_op, - stream); - }; - - // Dispatch_layout calls f with appropriate compile time constants based on - // the runtime values of params.is_row_major and vec_len. - dispatch_layout(params.is_row_major, vec_len, f); -} - -}; // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/kernel_sm60.cuh deleted file mode 100644 index baea4830e..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include // assert -#include // PairwiseDistances -#include // pairwise_matrix_params -#include // raft::void_op -#include // raft::util::arch::SM_compute_arch - -namespace cuvs::distance::detail { - -template -__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL - pairwise_matrix_kernel(OpT distance_op, pairwise_matrix_params params) -{ - // Early exit to minimize the size of the kernel when it is not supposed to be compiled. - constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { - assert(false); - return; - } - - extern __shared__ char smem[]; - - // The epilog is already provided by distance_op. Do not provide additional - // epilogs. - auto epilog_op = raft::void_op(); - // No support for row_epilog_op. - auto row_epilog_op = raft::void_op(); - - // Always write output - constexpr bool write_out = true; - constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances - obj(params.x, - params.y, - params.m, - params.n, - params.k, - params.ldx, - params.ldy, - params.ld_out, - params.x_norm, - params.y_norm, - params.out, - smem, - distance_op, - epilog_op, - params.fin_op, - row_epilog_op); - obj.run(); -} - -// The type of a pointer to the pairwise matrix kernel. The following template -// arguments are type-erased: -// -// - The kernel policy -// - row_major -// - SM_compat_t -template -using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); - -// A wrapper for the pairwise matrix kernel launch. Includes kernel launch -// parameters. 
-template -struct pairwise_matrix_sm60_wrapper { - dim3 grid; - dim3 block; - int smem_size; - pairwise_matrix_kernel_t kernel_ptr; - - void launch(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) - { - kernel_ptr<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -/** @brief: Create kernel launch wrapper for pairwise matrix kernel - * - * This can be used to type-erase the kernel execution policy, row_major, and SM - * compatibility range. - * - * @tparam Policy: Kernel execution policy - * @tparam row_major: Indicates whether input matrices are row major - * @tparam OpT: Type of distance operation - * @tparam IdxT: Index type - * @tparam DataT: Data type - * @tparam OutT: Output data type - * @tparam FinOpT: Final operation type - * @tparam SM_compat_t: Type of the SM architecture compatibility - * - * @param distance_op: Distance operation - * @param params: Parameters - * @param sm_compat_range: Which SM architectures to compile for. - */ -template -pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - dim3 block(Policy::Nthreads); - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = OpT::template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - return pairwise_matrix_sm60_wrapper{ - grid, block, smem_size, kernel}; -} - -}; // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/distance/detail/pairwise_matrix/params.cuh b/cpp/include/cuvs/distance/detail/pairwise_matrix/params.cuh deleted file mode 100644 index aa419aca0..000000000 --- a/cpp/include/cuvs/distance/detail/pairwise_matrix/params.cuh +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -namespace cuvs::distance::detail { - -template -struct pairwise_matrix_params { - IdxT m; - IdxT n; - IdxT k; - IdxT ldx; - IdxT ldy; - IdxT ld_out; - const DataT* x; - const DataT* y; - const DataT* x_norm; - const DataT* y_norm; - OutT* out; - FinOpT fin_op; - bool is_row_major; - - /// @brief: Flips the x and y input and corresponding sizes - void flip_x_and_y() - { - // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. - std::swap(m, n); - std::swap(ldx, ldy); - std::swap(x, y); - std::swap(x_norm, y_norm); - } -}; - -} // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/cuvs/distance/detail/predicated_tile_iterator_normvec.h deleted file mode 100644 index 951f8a013..000000000 --- a/cpp/include/cuvs/distance/detail/predicated_tile_iterator_normvec.h +++ /dev/null @@ -1,585 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) - -Changes: -- added `Layout_` template param -- Only the row index is used to load the data in load_with_byte_offset(). - This way the same normalization data is used across all columns in a row. - -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. 
-/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template -class PredicatedTileIteratorNormVec { - public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - using Layout = Layout_; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); - static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); - static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); - static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()) - { - } - - CUTLASS_HOST_DEVICE - Params(Base const& base) : Base(base) {} - }; - - /// Mask object - struct Mask { - static int const kCount = ThreadMap::Iterations::kColumn; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { enable(); } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() - { - CUTLASS_PRAGMA_UNROLL - 
for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - - private: - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. - PredicatedTileIteratorParams params_; - - /// Byte-level pointer - uint8_t* byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in rows - Index extent_column_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - - /// A thread's starting column - Index thread_start_column_; - - /// Internal state counter - int state_[3]; - - /// Scatter indices - int const* indices_; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); - - private: - // - // Methods - // - - public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorNormVec(PredicatedTileIteratorParams const& params, - Element* pointer, - TensorCoord extent, - int thread_idx, - TensorCoord threadblock_offset = TensorCoord(), - int const* indices = nullptr) - : params_(params), indices_(indices) - { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - extent_column_ = extent.column(); - - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - 
mask_.predicates[c] = - ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); - } - - // Null pointer performs no accesses - if (!pointer) { mask_.clear(); } - - if (ScatterD && !indices) { mask_.clear(); } - - // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride); - - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast( - byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; 
++column) { - bool guard = row_guard && mask_.predicates[column]; - if (column == 0) { - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[0], - guard); - } else { - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; - } - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType const* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast( - byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; 
column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - if (UseCUDAStore) { - if (guard) { - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; - } - } else { - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void downsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * 
convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - - int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + - (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void upsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = 
((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - int row_add_P = add_P; - int row_add_Q = add_Q; - if (output_P > convolution_P - 2) row_add_P = 0; - if (output_Q > convolution_Q - 2) row_add_Q = 0; - - int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + - ((output_P + row_add_P) / 2) * (convolution_Q / 2) + - (output_Q + row_add_Q) / 2; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - CUTLASS_DEVICE - MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_row() const { return thread_start_row_; } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_column() const { return thread_start_column_; } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { return extent_row_; } - - /// Extent of the 
matrix in columns - CUTLASS_DEVICE - Index extent_column() const { return extent_column_; } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorNormVec& operator++() - { - ++state_[0]; - - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - state_[0] = 0; - ++state_[1]; - byte_pointer_ += params_.advance_group; - - thread_start_row_ += - (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - state_[1] = 0; - ++state_[2]; - byte_pointer_ += params_.advance_cluster; - - thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * - ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { mask_.clear(); } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { mask_.enable(); } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/cuvs/distance/distance-ext.cuh b/cpp/include/cuvs/distance/distance-ext.cuh deleted file mode 100644 index fdbe6a971..000000000 --- a/cpp/include/cuvs/distance/distance-ext.cuh +++ /dev/null @@ -1,1065 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // rbf_fin_op -#include // cuvs::distance::DistanceType -#include // raft::device_matrix_view -#include // raft::identity_op -#include // raft::resources -#include // RAFT_EXPLICIT -#include // rmm::device_uvector - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs { -namespace distance { - -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) RAFT_EXPLICIT; - -template -size_t getWorkspaceSize(raft::device_matrix_view const& x, - raft::device_matrix_view const& y) RAFT_EXPLICIT; - -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - bool isRowMajor = true, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - 
rmm::device_uvector& workspace, - cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) RAFT_EXPLICIT; - -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) RAFT_EXPLICIT; - -template -void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -void pairwise_distance(raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, - cuvs::distance::DistanceType metric, - Type metric_arg = 2.0f) RAFT_EXPLICIT; - -}; // namespace distance -}; // namespace cuvs - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -/* - * Hierarchy of instantiations: - * - * This file defines the extern template instantiations for the public API of - * cuvs::distance. To improve compile times, the extern template instantiation - * of the distance kernels is handled in - * distance/detail/pairwise_matrix/dispatch-ext.cuh. - * - * After adding an instance here, make sure to also add the instance to - * dispatch-ext.cuh and the corresponding .cu files. - */ - -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - FinalLambda fin_op, \ - bool isRowMajor, \ - DataT metric_arg) - -// The following two instances are used in test/distance/gram.cu. Note the use -// of int64_t for the index type. 
-instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - float, - float, - float, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - cuvs::distance::kernels::detail::rbf_fin_op, - int64_t); - -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - 
cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, 
raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); - -#undef instantiate_raft_distance_distance - -// Same, but without raft::identity_op -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, 
float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); 
-instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -// Same, but without workspace -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_distance( - 
cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); 
-instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ - extern template size_t cuvs::distance::getWorkspaceSize( \ - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) - -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - 
cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); 
-instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ - extern template size_t cuvs::distance::getWorkspaceSize( \ - raft::device_matrix_view const& x, \ - raft::device_matrix_view const& y) - -// We could consider not taking template parameters for this function. The -// number of instantiations seems a bit excessive.. 
-instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); 
-instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - 
raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Expanded, 
double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ - extern template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ - const DataT* x, \ - const 
DataT* y, \ - DataT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - rmm::device_uvector& workspace, \ - cuvs::distance::DistanceType metric, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); - -#undef instantiate_raft_distance_pairwise_distance - -// Same, but without workspace -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ - extern template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - DataT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - cuvs::distance::DistanceType metric, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); - -#undef instantiate_raft_distance_pairwise_distance - -// Version with mdspan -#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ - extern template void cuvs::distance::distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - DataT metric_arg) - -// Again, we might want to consider reigning in the number of instantiations... 
-instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - 
double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); 
-instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); 
-instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); 
-instantiate_raft_distance_distance( - cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ - extern template void cuvs::distance::pairwise_distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - cuvs::distance::DistanceType metric, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); -instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); -instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); 
-instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); - -#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/include/cuvs/distance/distance-inl.cuh b/cpp/include/cuvs/distance/distance-inl.cuh deleted file mode 100644 index 0abdeacff..000000000 --- a/cpp/include/cuvs/distance/distance-inl.cuh +++ /dev/null @@ -1,477 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include - -namespace cuvs { -namespace distance { - -/** - * @defgroup pairwise_distance pointer-based pairwise distance prims - * @{ - */ - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param isRowMajor whether 
the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccT and returns the output in OutT. It's signature is - * as follows:

OutT fin_op(AccT in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - DataT metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - DataT metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the 
specified DistT doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) -{ - return detail::getWorkspaceSize(x, y, m, n, k); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param x first set of points (size m*k) - * @param y second set of points (size n*k) - * @return number of bytes needed in workspace - * - * @note If the specified DistT doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(raft::device_matrix_view const& x, - raft::device_matrix_view const& y) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - - return getWorkspaceSize( - x.data_handle(), y.data_handle(), x.extent(0), y.extent(0), x.extent(1)); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - bool isRowMajor = true, - DataT metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - auto worksize = getWorkspaceSize(x, 
y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam IdxT indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - rmm::device_uvector& workspace, - cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto dispatch = [&](auto distance_type) { - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); - }; - - switch (metric) { - case DistanceType::Canberra: - dispatch(std::integral_constant{}); - break; - case DistanceType::CorrelationExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::CosineExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HammingUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HellingerExpanded: - dispatch(std::integral_constant{}); - break; - case 
cuvs::distance::DistanceType::InnerProduct: - dispatch(std::integral_constant{}); - break; - case DistanceType::JensenShannon: - dispatch(std::integral_constant{}); - break; - case DistanceType::KLDivergence: - dispatch(std::integral_constant{}); - break; - case DistanceType::L1: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Expanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Unexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::Linf: - dispatch(std::integral_constant{}); - break; - case DistanceType::LpUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::RusselRaoExpanded: - dispatch(std::integral_constant{}); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam IdxT indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - cuvs::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - 
pairwise_distance( - handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); -} - -/** @} */ - -/** - * \defgroup distance_mdspan Pairwise distance functions - * @{ - */ - -/** - * @brief Evaluate pairwise distances for the simple use case. - * - * Note: Only contiguous row- or column-major layouts supported currently. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * #include - * - * raft::raft::resources handle; - * int n_samples = 5000; - * int n_features = 50; - * - * auto input = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto output = raft::make_device_matrix(handle, n_samples, n_samples); - * - * raft::random::make_blobs(handle, input.view(), labels.view()); - * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; - * cuvs::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); - * @endcode - * - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points (size n*k) - * @param y second set of points (size m*k) - * @param dist output distance matrix (size n*m) - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - DataT metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x 
must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - - constexpr auto is_rowmajor = std::is_same_v; - - distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - is_rowmajor, - metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam IdxT indexing type - * @param handle raft handle for managing expensive resources - * @param x first matrix of points (size mxk) - * @param y second matrix of points (size nxk) - * @param dist output distance matrix (size mxn) - * @param metric distance metric - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - cuvs::distance::DistanceType metric, - Type metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); - - constexpr auto rowmajor = std::is_same_v; - - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - - pairwise_distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - metric, - rowmajor, - metric_arg); -} - -/** @} */ - -}; // namespace distance -}; // namespace cuvs diff --git 
a/cpp/include/cuvs/distance/distance.cuh b/cpp/include/cuvs/distance/distance.cuh deleted file mode 100644 index de70cd469..000000000 --- a/cpp/include/cuvs/distance/distance.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "distance-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "distance-ext.cuh" -#endif diff --git a/cpp/include/cuvs/distance/fused_l2_nn-ext.cuh b/cpp/include/cuvs/distance/fused_l2_nn-ext.cuh deleted file mode 100644 index eb993b681..000000000 --- a/cpp/include/cuvs/distance/fused_l2_nn-ext.cuh +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // int64_t -#include // include initialize and reduce operations -#include // raft::KeyValuePair -#include // raft::resources -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs { -namespace distance { - -template -void fusedL2NNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) RAFT_EXPLICIT; - -} // namespace distance -} // namespace cuvs - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ - extern template void cuvs::distance::fusedL2NNMinReduce(OutT * min, \ - const DataT* x, \ - const DataT* y, \ - const DataT* xn, \ - const DataT* yn, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - bool sqrt, \ - bool initOutBuffer, \ - cudaStream_t stream) - -instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); -instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); -instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); -instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); - -// We can't have comma's in the macro expansion, so we use the COMMA macro: -#define COMMA , - -instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); -instantiate_raft_distance_fusedL2NNMinReduce(double, - raft::KeyValuePair, - int64_t); -instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); -instantiate_raft_distance_fusedL2NNMinReduce(float, - raft::KeyValuePair, - int64_t); - -#undef COMMA - -#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/include/cuvs/distance/fused_l2_nn-inl.cuh b/cpp/include/cuvs/distance/fused_l2_nn-inl.cuh deleted file mode 100644 index c6e7acb51..000000000 --- a/cpp/include/cuvs/distance/fused_l2_nn-inl.cuh +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright 
(c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __FUSED_L2_NN_H -#define __FUSED_L2_NN_H - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs { -namespace distance { - -/** - * \ingroup fused_l2_nn - * @{ - */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. 
Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. - bool is_skinny = k < 32; - - size_t bytes = sizeof(DataT) * k; - auto px = reinterpret_cast(x); - auto py = reinterpret_cast(y); - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, 
redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } -} - -/** - * @brief Wrapper around fusedL2NN with minimum reduction operators. - * - * fusedL2NN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * This should be preferred to the more generic API when possible, in order to - * reduce compilation times for users of the shared library. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. 
(on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; - - fusedL2NN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); -} - -/** @} */ - -} // namespace distance -} // namespace cuvs - -#endif diff --git a/cpp/include/cuvs/distance/fused_l2_nn.cuh b/cpp/include/cuvs/distance/fused_l2_nn.cuh deleted file mode 100644 index b1a355132..000000000 --- a/cpp/include/cuvs/distance/fused_l2_nn.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "fused_l2_nn-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "fused_l2_nn-ext.cuh" -#endif diff --git a/cpp/include/cuvs/distance/fused_l2_nn_helpers.cuh b/cpp/include/cuvs/distance/fused_l2_nn_helpers.cuh deleted file mode 100644 index 29a4ae523..000000000 --- a/cpp/include/cuvs/distance/fused_l2_nn_helpers.cuh +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace cuvs::distance { - -/** - * \defgroup fused_l2_nn Fused 1-nearest neighbors - * @{ - */ - -template -using KVPMinReduce = detail::KVPMinReduceImpl; - -template -using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; - -template -using MinReduceOp = detail::MinReduceOpImpl; - -/** @} */ - -/** - * Initialize array using init value from reduction op - */ -template -void initialize(raft::resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - detail::initialize( - min, m, maxVal, redOp, resource::get_cuda_stream(handle)); -} - -} // namespace cuvs::distance diff --git a/cpp/include/cuvs/distance/kernels.cuh b/cpp/include/cuvs/distance/kernels.cuh deleted file mode 100644 index 0133892a6..000000000 --- a/cpp/include/cuvs/distance/kernels.cuh +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include - -namespace cuvs::distance::kernels { - -// TODO: Need to expose formal APIs for this that are more consistent w/ other APIs in RAFT -using cuvs::distance::kernels::detail::GramMatrixBase; -using cuvs::distance::kernels::detail::KernelFactory; - -}; // end namespace cuvs::distance::kernels diff --git a/cpp/include/cuvs/distance/masked_nn.cuh b/cpp/include/cuvs/distance/masked_nn.cuh deleted file mode 100644 index 6f3bde891..000000000 --- a/cpp/include/cuvs/distance/masked_nn.cuh +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __MASKED_L2_NN_H -#define __MASKED_L2_NN_H - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace cuvs { -namespace distance { -/** - * \defgroup masked_nn Masked 1-nearest neighbors - * @{ - */ - -/** - * @brief Parameter struct for masked_l2_nn function - * - * @tparam ReduceOpT Type of reduction operator in the epilogue. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * Usage example: - * @code{.cpp} - * #include - * - * using IdxT = int; - * using DataT = float; - * using RedOpT = cuvs::distance::MinAndDistanceReduceOp; - * using PairRedOpT = cuvs::distance::KVPMinReduce; - * using ParamT = cuvs::distance::masked_l2_nn_params; - * - * bool init_out = true; - * bool sqrt = false; - * - * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - * @endcode - * - * Prescribes how to reduce a distance to an intermediate type (`redOp`), and - * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is - * mapped to an (index, value) pair and (index, value) pair with the lowest - * value (distance) is selected. - * - * In addition, prescribes whether to compute the square root of the distance - * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). - */ -template -struct masked_l2_nn_params { - /** Reduction operator in the epilogue */ - ReduceOpT redOp; - /** Reduction operation on key value pairs */ - KVPReduceOpT pairRedOp; - /** Whether the output `minDist` should contain L2-sqrt */ - bool sqrt; - /** Whether to initialize the output buffer before the main kernel launch */ - bool initOutBuffer; -}; - -/** - * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. - * - * This function enables faster computation of nearest neighbors if the - * computation of distances between certain point pairs can be skipped. - * - * We use an adjacency matrix that describes which distances to calculate. 
The - * points in `y` are divided into groups, and the adjacency matrix indicates - * whether to compute distances between points in `x` and groups in `y`. In other - * words, if `adj[i,k]` is true then distance between point `x_i`, and points in - * `group_k` will be calculated. - * - * **Performance considerations** - * - * The points in `x` are processed in tiles of `M` points (`M` is currently 64, - * but may change in the future). As a result, the largest compute time - * reduction occurs if all `M` points can skip a group. If only part of the `M` - * points can skip a group, then at most a minor compute time reduction and a - * modest energy use reduction can be expected. - * - * The points in `y` are also grouped into tiles of `N` points (`N` is currently - * 64, but may change in the future). As a result, group sizes should be larger - * than `N` to avoid wasting computational resources. If the group sizes are - * evenly divisible by `N`, then the computation is most efficient, although for - * larger group sizes this effect is minor. - * - * - * **Comparison to SDDM** - * - * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense - * matrix multiplication) is a matrix-matrix multiplication where only part of - * the output is computed. Compared to masked_l2_nn, there are a few differences: - * - * - The output of masked_l2_nn is a single vector (of nearest neighbors) and not - * a sparse matrix. - * - * - The sampling in masked_l2_nn is expressed through intermediate "groups" - rather than a CSR format. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. 
Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param handle RAFT handle for managing expensive resources - * @param params Parameter struct specifying the reduction operations. - * @param[in] x First matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y Second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. 
- * @param[out] out will contain the reduced output (Length = `m`) - * (on device) - */ -template -void masked_l2_nn(raft::resources const& handle, - cuvs::distance::masked_l2_nn_params params, - raft::device_matrix_view x, - raft::device_matrix_view y, - raft::device_vector_view x_norm, - raft::device_vector_view y_norm, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs, - raft::device_vector_view out) -{ - IdxT m = x.extent(0); - IdxT n = y.extent(0); - IdxT k = x.extent(1); - IdxT num_groups = group_idxs.extent(0); - - // Match k dimension of x, y - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); - // Match x, x_norm and y, y_norm - RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); - RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` "); - // Match adj to x and group_idxs - RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); - RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); - // NOTE: We do not check if all indices in group_idxs actually points *inside* y. - - // If there is no work to be done, return immediately. - if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } - - detail::masked_l2_nn_impl(handle, - out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - adj.data_handle(), - group_idxs.data_handle(), - num_groups, - m, - n, - k, - params.redOp, - params.pairRedOp, - params.sqrt, - params.initOutBuffer); -} - -/** @} */ - -} // namespace distance -} // namespace cuvs - -#endif diff --git a/cpp/include/cuvs/neighbors/ball_cover-ext.cuh b/cpp/include/cuvs/neighbors/ball_cover-ext.cuh deleted file mode 100644 index b1cd2b4ed..000000000 --- a/cpp/include/cuvs/neighbors/ball_cover-ext.cuh +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // uint32_t -#include // cuvs::distance::DistanceType -#include // BallCoverIndex -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::neighbors::ball_cover { - -template -void build_index(raft::resources const& handle, - BallCoverIndex& index) RAFT_EXPLICIT; - -template -void all_knn_query(raft::resources const& handle, - BallCoverIndex& index, - int_t k, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) RAFT_EXPLICIT; - -template -void all_knn_query(raft::resources const& handle, - BallCoverIndex& index, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) RAFT_EXPLICIT; - -template -void knn_query(raft::resources const& handle, - const BallCoverIndex& index, - int_t k, - const value_t* query, - int_t n_query_pts, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) RAFT_EXPLICIT; - -template -void knn_query(raft::resources const& handle, - const BallCoverIndex& index, - raft::device_matrix_view query, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) RAFT_EXPLICIT; - -} // namespace cuvs::neighbors::ball_cover - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define 
instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ - extern template void \ - cuvs::neighbors::ball_cover::build_index( \ - raft::resources const& handle, \ - cuvs::neighbors::ball_cover::BallCoverIndex& index); \ - \ - extern template void \ - cuvs::neighbors::ball_cover::all_knn_query( \ - raft::resources const& handle, \ - cuvs::neighbors::ball_cover::BallCoverIndex& index, \ - int_t k, \ - idx_t* inds, \ - value_t* dists, \ - bool perform_post_filtering, \ - float weight); \ - \ - extern template void \ - cuvs::neighbors::ball_cover::all_knn_query( \ - raft::resources const& handle, \ - cuvs::neighbors::ball_cover::BallCoverIndex& index, \ - raft::device_matrix_view inds, \ - raft::device_matrix_view dists, \ - int_t k, \ - bool perform_post_filtering, \ - float weight); \ - \ - extern template void cuvs::neighbors::ball_cover::knn_query( \ - raft::resources const& handle, \ - const cuvs::neighbors::ball_cover::BallCoverIndex& index, \ - int_t k, \ - const value_t* query, \ - int_t n_query_pts, \ - idx_t* inds, \ - value_t* dists, \ - bool perform_post_filtering, \ - float weight); \ - \ - extern template void \ - cuvs::neighbors::ball_cover::knn_query( \ - raft::resources const& handle, \ - const cuvs::neighbors::ball_cover::BallCoverIndex& index, \ - raft::device_matrix_view query, \ - raft::device_matrix_view inds, \ - raft::device_matrix_view dists, \ - int_t k, \ - bool perform_post_filtering, \ - float weight); - -instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); - -#undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/include/cuvs/neighbors/ball_cover-inl.cuh b/cpp/include/cuvs/neighbors/ball_cover-inl.cuh deleted file mode 100644 index 4d0f170df..000000000 --- a/cpp/include/cuvs/neighbors/ball_cover-inl.cuh +++ /dev/null @@ -1,395 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef __BALL_COVER_H -#define __BALL_COVER_H - -#pragma once - -#include - -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::ball_cover { - -/** - * @defgroup random_ball_cover Random Ball Cover algorithm - * @{ - */ - -/** - * Builds and populates a previously unbuilt BallCoverIndex - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace cuvs::neighbors; - * - * raft::resources handle; - * ... 
- * auto metric = cuvs::distance::DistanceType::L2Expanded; - * BallCoverIndex index(handle, X, metric); - * - * ball_cover::build_index(handle, index); - * @endcode - * - * @tparam idx_t knn index type - * @tparam value_t knn value type - * @tparam int_t integral type for knn params - * @tparam matrix_idx_t matrix indexing type - * @param[in] handle library resource management handle - * @param[inout] index an empty (and not previous built) instance of BallCoverIndex - */ -template -void build_index(raft::resources const& handle, - BallCoverIndex& index) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == cuvs::distance::DistanceType::Haversine) { - cuvs::spatial::knn::detail::rbc_build_index( - handle, index, spatial::knn::detail::HaversineFunc()); - } else if (index.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - index.metric == cuvs::distance::DistanceType::L2SqrtUnexpanded) { - cuvs::spatial::knn::detail::rbc_build_index( - handle, index, spatial::knn::detail::EuclideanFunc()); - } else { - RAFT_FAIL("Metric not support"); - } - - index.set_index_trained(); -} - -/** @} */ // end group random_ball_cover - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * performs an all neighbors knn, which can reuse memory when - * the index and query are the same array. This function will - * build the index and assumes rbc_build_index() has not already - * been called. 
- * @tparam idx_t knn index type - * @tparam value_t knn distance type - * @tparam int_t type for integers, such as number of rows/cols - * @param[in] handle raft handle for resource management - * @param[inout] index ball cover index which has not yet been built - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. 
- */ -template -void all_knn_query(raft::resources const& handle, - BallCoverIndex& index, - int_t k, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == cuvs::distance::DistanceType::Haversine) { - cuvs::spatial::knn::detail::rbc_all_knn_query( - handle, - index, - k, - inds, - dists, - spatial::knn::detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - index.metric == cuvs::distance::DistanceType::L2SqrtUnexpanded) { - cuvs::spatial::knn::detail::rbc_all_knn_query( - handle, - index, - k, - inds, - dists, - spatial::knn::detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } - - index.set_index_trained(); -} - -/** - * @ingroup random_ball_cover - * @{ - */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * performs an all neighbors knn, which can reuse memory when - * the index and query are the same array. This function will - * build the index and assumes rbc_build_index() has not already - * been called. - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace cuvs::neighbors; - * - * raft::resources handle; - * ... 
- * auto metric = cuvs::distance::DistanceType::L2Expanded; - * - * // Construct a ball cover index - * BallCoverIndex index(handle, X, metric); - * - * // Perform all neighbors knn query - * ball_cover::all_knn_query(handle, index, inds, dists, k); - * @endcode - * - * @tparam idx_t knn index type - * @tparam value_t knn distance type - * @tparam int_t type for integers, such as number of rows/cols - * @tparam matrix_idx_t matrix indexing type - * - * @param[in] handle raft handle for resource management - * @param[in] index ball cover index which has not yet been built - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. 
- */ -template -void all_knn_query(raft::resources const& handle, - BallCoverIndex& index, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) -{ - RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - RAFT_EXPECTS(k <= index.m, - "k must be less than or equal to the number of data points in the index"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in index matrix."); - - all_knn_query( - handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); -} - -/** @} */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * function does not build the index and assumes rbc_build_index() has - * already been called. Use this function when the index and - * query arrays are different, otherwise use rbc_all_knn_query(). - * @tparam idx_t index type - * @tparam value_t distances type - * @tparam int_t integer type for size info - * @param[in] handle raft handle for resource management - * @param[inout] index ball cover index which has not yet been built - * @param[in] k number of nearest neighbors to find - * @param[in] query the - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). 
- * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - * @param[in] n_query_pts number of query points - */ -template -void knn_query(raft::resources const& handle, - const BallCoverIndex& index, - int_t k, - const value_t* query, - int_t n_query_pts, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == cuvs::distance::DistanceType::Haversine) { - cuvs::spatial::knn::detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - spatial::knn::detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - index.metric == cuvs::distance::DistanceType::L2SqrtUnexpanded) { - cuvs::spatial::knn::detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - spatial::knn::detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } -} - -/** - * @ingroup random_ball_cover - * @{ - */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * function does not build the index and assumes rbc_build_index() has - * already been called. 
Use this function when the index and - * query arrays are different, otherwise use rbc_all_knn_query(). - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace cuvs::neighbors; - * - * raft::resources handle; - * ... - * auto metric = cuvs::distance::DistanceType::L2Expanded; - * - * // Build a ball cover index - * BallCoverIndex index(handle, X, metric); - * ball_cover::build_index(handle, index); - * - * // Perform all neighbors knn query - * ball_cover::knn_query(handle, index, inds, dists, k); - * @endcode - - * - * @tparam idx_t index type - * @tparam value_t distances type - * @tparam int_t integer type for size info - * @tparam matrix_idx_t - * @param[in] handle raft handle for resource management - * @param[in] index ball cover index which has not yet been built - * @param[in] query device matrix containing query data points - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. 
- */ -template -void knn_query(raft::resources const& handle, - const BallCoverIndex& index, - raft::device_matrix_view query, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) -{ - RAFT_EXPECTS(k <= index.m, - "k must be less than or equal to the number of data points in the index"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - - RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1), - "Number of columns in query and index matrices must match."); - - knn_query(handle, - index, - k, - query.data_handle(), - query.extent(0), - inds.data_handle(), - dists.data_handle(), - perform_post_filtering, - weight); -} - -/** @} */ - -// TODO: implement functions for: -// 4. rbc_eps_neigh() - given a populated index, perform query against different query array -// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data - -} // namespace cuvs::neighbors::ball_cover - -#endif diff --git a/cpp/include/cuvs/neighbors/ball_cover.cuh b/cpp/include/cuvs/neighbors/ball_cover.cuh deleted file mode 100644 index 41c5d0310..000000000 --- a/cpp/include/cuvs/neighbors/ball_cover.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "ball_cover-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ball_cover-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/ball_cover_types.hpp b/cpp/include/cuvs/neighbors/ball_cover_types.hpp deleted file mode 100644 index c6e9fab2c..000000000 --- a/cpp/include/cuvs/neighbors/ball_cover_types.hpp +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::ball_cover { - -/** - * @ingroup random_ball_cover - * @{ - */ - -/** - * Stores raw index data points, sampled landmarks, the 1-nns of index points - * to their closest landmarks, and the ball radii of each landmark. This - * class is intended to be constructed once and reused across subsequent - * queries. 
- * @tparam value_idx - * @tparam value_t - * @tparam value_int - */ -template -class BallCoverIndex { - public: - explicit BallCoverIndex(raft::resources const& handle_, - const value_t* X_, - value_int m_, - value_int n_, - cuvs::distance::DistanceType metric_) - : handle(handle_), - X(raft::make_device_matrix_view(X_, m_, n_)), - m(m_), - n(n_), - metric(metric_), - /** - * the sqrt() here makes the sqrt(m)^2 a linear-time lower bound - * - * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) - */ - n_landmarks(sqrt(m_)), - R_indptr(raft::make_device_vector(handle, sqrt(m_) + 1)), - R_1nn_cols(raft::make_device_vector(handle, m_)), - R_1nn_dists(raft::make_device_vector(handle, m_)), - R_closest_landmark_dists(raft::make_device_vector(handle, m_)), - R(raft::make_device_matrix(handle, sqrt(m_), n_)), - R_radius(raft::make_device_vector(handle, sqrt(m_))), - index_trained(false) - { - } - - explicit BallCoverIndex(raft::resources const& handle_, - raft::device_matrix_view X_, - cuvs::distance::DistanceType metric_) - : handle(handle_), - X(X_), - m(X_.extent(0)), - n(X_.extent(1)), - metric(metric_), - /** - * the sqrt() here makes the sqrt(m)^2 a linear-time lower bound - * - * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) - */ - n_landmarks(sqrt(X_.extent(0))), - R_indptr(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1)), - R_1nn_cols(raft::make_device_vector(handle, X_.extent(0))), - R_1nn_dists(raft::make_device_vector(handle, X_.extent(0))), - R_closest_landmark_dists(raft::make_device_vector(handle, X_.extent(0))), - R(raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1))), - R_radius(raft::make_device_vector(handle, sqrt(X_.extent(0)))), - index_trained(false) - { - } - - auto get_R_indptr() const -> raft::device_vector_view - { - return R_indptr.view(); - } - auto get_R_1nn_cols() const -> raft::device_vector_view - { - return R_1nn_cols.view(); - } - auto get_R_1nn_dists() const -> 
raft::device_vector_view - { - return R_1nn_dists.view(); - } - auto get_R_radius() const -> raft::device_vector_view - { - return R_radius.view(); - } - auto get_R() const -> raft::device_matrix_view - { - return R.view(); - } - auto get_R_closest_landmark_dists() const -> raft::device_vector_view - { - return R_closest_landmark_dists.view(); - } - - raft::device_vector_view get_R_indptr() { return R_indptr.view(); } - raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } - raft::device_vector_view get_R_1nn_dists() { return R_1nn_dists.view(); } - raft::device_vector_view get_R_radius() { return R_radius.view(); } - raft::device_matrix_view get_R() { return R.view(); } - raft::device_vector_view get_R_closest_landmark_dists() - { - return R_closest_landmark_dists.view(); - } - raft::device_matrix_view get_X() const { return X; } - - cuvs::distance::DistanceType get_metric() const { return metric; } - - value_int get_n_landmarks() const { return n_landmarks; } - bool is_index_trained() const { return index_trained; }; - - // This should only be set by internal functions - void set_index_trained() { index_trained = true; } - - raft::resources const& handle; - - value_int m; - value_int n; - value_int n_landmarks; - - raft::device_matrix_view X; - - cuvs::distance::DistanceType metric; - - private: - // CSR storing the neighborhoods for each data point - raft::device_vector R_indptr; - raft::device_vector R_1nn_cols; - raft::device_vector R_1nn_dists; - raft::device_vector R_closest_landmark_dists; - - raft::device_vector R_radius; - - raft::device_matrix R; - - protected: - bool index_trained; -}; - -/** @} */ - -} // namespace cuvs::neighbors::ball_cover diff --git a/cpp/include/cuvs/neighbors/brute_force-ext.cuh b/cpp/include/cuvs/neighbors/brute_force-ext.cuh deleted file mode 100644 index bc4773513..000000000 --- a/cpp/include/cuvs/neighbors/brute_force-ext.cuh +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA 
CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include // cuvs::distance::DistanceType -#include -#include // raft::device_matrix_view -#include // raft::identity_op -#include // raft::resources -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::neighbors::brute_force { - -template -inline void knn_merge_parts( - raft::resources const& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - size_t n_samples, - std::optional> translations = std::nullopt) RAFT_EXPLICIT; - -template -index build( - raft::resources const& res, - raft::mdspan, raft::row_major, Accessor> dataset, - cuvs::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - T metric_arg = 0.0) RAFT_EXPLICIT; - -template -void search(raft::resources const& res, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; - -template -void knn(raft::resources const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt, - epilogue_op distance_epilogue = 
raft::identity_op()) RAFT_EXPLICIT; - -template -void fused_l2_knn(raft::resources const& handle, - raft::device_matrix_view index, - raft::device_matrix_view query, - raft::device_matrix_view out_inds, - raft::device_matrix_view out_dists, - cuvs::distance::DistanceType metric) RAFT_EXPLICIT; - -} // namespace cuvs::neighbors::brute_force - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -// No extern template for cuvs::neighbors::brute_force::knn_merge_parts - -#define instantiate_raft_neighbors_brute_force_knn( \ - idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ - extern template void cuvs::neighbors::brute_force:: \ - knn( \ - raft::resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - cuvs::distance::DistanceType metric, \ - std::optional metric_arg, \ - std::optional global_id_offset, \ - epilogue_op distance_epilogue); - -instantiate_raft_neighbors_brute_force_knn( - int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); -instantiate_raft_neighbors_brute_force_knn( - int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); -instantiate_raft_neighbors_brute_force_knn( - int, float, int, raft::row_major, raft::row_major, raft::identity_op); -instantiate_raft_neighbors_brute_force_knn( - uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); - -#undef instantiate_raft_neighbors_brute_force_knn - -namespace cuvs::neighbors::brute_force { - -extern template void search( - raft::resources const& res, - const cuvs::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - -extern template void search( - raft::resources const& res, - const cuvs::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - 
raft::device_matrix_view distances); - -extern template cuvs::neighbors::brute_force::index build( - raft::resources const& res, - raft::device_matrix_view dataset, - cuvs::distance::DistanceType metric, - float metric_arg); -} // namespace cuvs::neighbors::brute_force - -#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ - value_t, idx_t, idx_layout, query_layout) \ - extern template void cuvs::neighbors::brute_force::fused_l2_knn( \ - raft::resources const& handle, \ - raft::device_matrix_view index, \ - raft::device_matrix_view query, \ - raft::device_matrix_view out_inds, \ - raft::device_matrix_view out_dists, \ - cuvs::distance::DistanceType metric); - -instantiate_raft_neighbors_brute_force_fused_l2_knn(float, - int64_t, - raft::row_major, - raft::row_major) - -#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/include/cuvs/neighbors/brute_force-inl.cuh b/cpp/include/cuvs/neighbors/brute_force-inl.cuh deleted file mode 100644 index 3d5c449a9..000000000 --- a/cpp/include/cuvs/neighbors/brute_force-inl.cuh +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::brute_force { - -/** - * @defgroup brute_force_knn Brute-force K-Nearest Neighbors - * @{ - */ - -/** - * @brief Performs a k-select across several (contiguous) row-partitioned index/distance - * matrices formatted like the following: - * - * part1row1: k0, k1, k2, k3 - * part1row2: k0, k1, k2, k3 - * part1row3: k0, k1, k2, k3 - * part2row1: k0, k1, k2, k3 - * part2row2: k0, k1, k2, k3 - * part2row3: k0, k1, k2, k3 - * etc... - * - * The example above shows what an aggregated index/distance matrix - * would look like with two partitions when n_samples=3 and k=4. - * - * When working with extremely large data sets that have been broken - * over multiple indexes, such as when computing over multiple GPUs, - * the ids will often start at 0 for each local knn index but the - * global ids need to be used when merging them together. An optional - * translations vector can be supplied to map the starting id of - * each partition to its global id so that the final merged knn - * is based on the global ids. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * using namespace cuvs::neighbors; - * - * raft::resources handle; - * ... - * compute multiple knn graphs and aggregate row-wise - * (see detailed description above) - * ... 
- * brute_force::knn_merge_parts(handle, in_keys, in_values, out_keys, out_values, n_samples); - * @endcode - * - * @tparam idx_t - * @tparam value_t - * - * @param[in] handle - * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) - * @param[in] in_values matrix of input values (size n_samples * n_parts * k) - * @param[out] out_keys matrix of output keys (size n_samples * k) - * @param[out] out_values matrix of output values (size n_samples * k) - * @param[in] n_samples number of rows in each partition - * @param[in] translations optional vector of starting global id mappings for each local partition - */ -template -inline void knn_merge_parts( - raft::resources const& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - size_t n_samples, - std::optional> translations = std::nullopt) -{ - RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), - "in_keys and in_values must have the same shape."); - RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples, - "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS( - out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1), - "Number of columns in output indices and distances matrices must be equal to k"); - - idx_t* translations_ptr = nullptr; - if (translations.has_value()) { translations_ptr = translations.value().data_handle(); } - - auto n_parts = in_keys.extent(0) / n_samples; - detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - out_values.data_handle(), - n_samples, - n_parts, - in_keys.extent(1), - resource::get_cuda_stream(handle), - translations_ptr); -} - -/** - * @brief Flat C++ API function to perform a brute force knn on - * a series of input arrays and combine the 
results into a single - * output array for indexes and distances. Inputs can be either - * row- or column-major but the output matrices will always be in - * row-major format. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::neighbors; - * - * raft::resources handle; - * ... - * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, metric); - * @endcode - * - * @param[in] handle: the cuml handle to use - * @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index - * @param[in] search: matrix (size n*d) to be used for searching the index - * @param[out] indices: matrix (size n*k) to store output knn indices - * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] metric: distance metric to use. Euclidean (L2) is used by default - * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This - * is ignored if the metric_type is not Minkowski. - * @param[in] global_id_offset: optional starting global id mapping for the local partition - * (assumes the index contains contiguous ids in the global id space) - * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This - function takes a triple of the (value, rowid, colid) for each - element in the pairwise distances and returns a transformed value - back. 
- */ -template -void knn(raft::resources const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt, - epilogue_op distance_epilogue = raft::identity_op()) -{ - RAFT_EXPECTS(index[0].extent(1) == search.extent(1), - "Number of dimensions for both index and search matrices must be equal"); - - RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), - "Number of columns in output indices and distances matrices must the same"); - - bool rowMajorIndex = std::is_same_v; - bool rowMajorQuery = std::is_same_v; - - std::vector inputs; - std::vector sizes; - for (std::size_t i = 0; i < index.size(); ++i) { - inputs.push_back(const_cast(index[i].data_handle())); - sizes.push_back(index[i].extent(0)); - } - - std::vector trans; - if (global_id_offset.has_value()) { trans.push_back(global_id_offset.value()); } - - std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; - - cuvs::neighbors::detail::brute_force_knn_impl(handle, - inputs, - sizes, - index[0].extent(1), - // TODO: This is unfortunate. Need to fix. - const_cast(search.data_handle()), - search.extent(0), - indices.data_handle(), - distances.data_handle(), - indices.extent(1), - rowMajorIndex, - rowMajorQuery, - trans_arg, - metric, - metric_arg.value_or(2.0f), - distance_epilogue); -} - -/** - * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. - * - * This is a specialized function for fusing the k-selection with the distance - * computation when k < 64. 
The value of k will be inferred from the number - * of columns in the output matrices. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::neighbors; - * - * raft::resources handle; - * ... - * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; - * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); - * @endcode - - * @tparam value_t type of values - * @tparam idx_t type of indices - * @tparam idx_layout layout type of index matrix - * @tparam query_layout layout type of query matrix - * @param[in] handle raft handle for sharing expensive resources - * @param[in] index input index array on device (size m * d) - * @param[in] query input query array on device (size n * d) - * @param[out] out_inds output indices array on device (size n * k) - * @param[out] out_dists output dists array on device (size n * k) - * @param[in] metric type of distance computation to perform (must be a variant of L2) - */ -template -void fused_l2_knn(raft::resources const& handle, - raft::device_matrix_view index, - raft::device_matrix_view query, - raft::device_matrix_view out_inds, - raft::device_matrix_view out_dists, - cuvs::distance::DistanceType metric) -{ - int k = static_cast(out_inds.extent(1)); - - RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); - RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(index.extent(1) == query.extent(1), - "Number of columns in input matrices must be the same."); - - RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || - metric == distance::DistanceType::L2Unexpanded || - metric == distance::DistanceType::L2SqrtUnexpanded || - metric == distance::DistanceType::L2SqrtExpanded, - "Distance metric must be L2"); - - size_t n_index_rows = index.extent(0); - size_t n_query_rows = query.extent(0); - size_t D = index.extent(1); - - RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be 
row or column major layout"); - RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); - - const bool rowMajorIndex = raft::is_row_major(index); - const bool rowMajorQuery = raft::is_row_major(query); - - cuvs::spatial::knn::detail::fusedL2Knn(D, - out_inds.data_handle(), - out_dists.data_handle(), - index.data_handle(), - query.data_handle(), - n_index_rows, - n_query_rows, - k, - rowMajorIndex, - rowMajorQuery, - raft::resource::get_cuda_stream(handle), - metric); -} - -/** - * @brief Build the index from the dataset for efficient search. - * - * @tparam T data element type - * - * @param[in] res - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * @param[in] metric: distance metric to use. Euclidean (L2) is used by default - * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This - * is ignored if the metric_type is not Minkowski. - * - * @return the constructed brute force index - */ -template -index build( - raft::resources const& res, - raft::mdspan, raft::row_major, Accessor> dataset, - cuvs::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - T metric_arg = 0.0) -{ - // certain distance metrics can benefit by pre-calculating the norms for the index dataset - // which lets us avoid calculating these at query time - std::optional> norms; - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded) { - norms = raft::make_device_vector(res, dataset.extent(0)); - // cosine needs the l2norm, where as l2 distances needs the squared norm - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm(res, - dataset, - norms->view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS, - raft::sqrt_op{}); - } else { - raft::linalg::norm(res, - dataset, - norms->view(), - 
raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - } - } - - return index(res, dataset, std::move(norms), metric, metric_arg); -} - -/** - * @brief Brute Force search using the constructed index. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] idx brute force index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - cuvs::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); -} -/** @} */ // end group brute_force_knn -} // namespace cuvs::neighbors::brute_force diff --git a/cpp/include/cuvs/neighbors/brute_force.cuh b/cpp/include/cuvs/neighbors/brute_force.cuh deleted file mode 100644 index 91065d35f..000000000 --- a/cpp/include/cuvs/neighbors/brute_force.cuh +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once -#include - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "brute_force-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "brute_force-ext.cuh" -#endif - -#include - -namespace cuvs::neighbors::brute_force { -/** - * @brief Make a brute force query over batches of k - * - * This lets you query for batches of k. For example, you can get - * the first 100 neighbors, then the next 100 neighbors etc. - * - * Example usage: - * @code{.cpp} - * #include - * #include - * #include - - * // create a random dataset - * int n_rows = 10000; - * int n_cols = 10000; - - * raft::device_resources res; - * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); - * auto labels = raft::make_device_vector(res, n_rows); - - * raft::make_blobs(res, dataset.view(), labels.view()); - * - * // create a brute_force knn index from the dataset - * auto index = cuvs::neighbors::brute_force::build(res, - * raft::make_const_mdspan(dataset.view())); - * - * // search the index in batches of 128 nearest neighbors - * auto search = raft::make_const_mdspan(dataset.view()); - * auto query = make_batch_k_query(res, index, search, 128); - * for (auto & batch: *query) { - * // batch.indices() and batch.distances() contain the information on the current batch - * } - * - * // we can also support variable sized batches - loaded up a different number - * // of neighbors at each iteration through the ::advance method - * int64_t batch_size = 128; - * query = make_batch_k_query(res, index, search, batch_size); - * for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { - * // batch.indices() and batch.distances() contain the information on the current batch - * - * batch_size += 16; // load up an extra 16 items in the next batch - * } - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * @param[in] res - * @param[in] index The index to query - * @param[in] query A device matrix view to query for 
[n_queries, index->dim()] - * @param[in] batch_size The size of each batch - */ - -template -std::shared_ptr> make_batch_k_query( - const raft::resources& res, - const cuvs::neighbors::brute_force::index& index, - raft::device_matrix_view query, - int64_t batch_size) -{ - return std::shared_ptr>( - new detail::gpu_batch_k_query(res, index, query, batch_size)); -} -} // namespace cuvs::neighbors::brute_force diff --git a/cpp/include/cuvs/neighbors/brute_force_types.hpp b/cpp/include/cuvs/neighbors/brute_force_types.hpp deleted file mode 100644 index 0d3252d71..000000000 --- a/cpp/include/cuvs/neighbors/brute_force_types.hpp +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "ann_types.hpp" -#include - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace cuvs::neighbors::brute_force { -/** - * @addtogroup brute_force_knn - * @{ - */ - -/** - * @brief Brute Force index. - * - * The index stores the dataset and norms for the dataset in device memory. - * - * @tparam T data element type - */ -template -struct index : ann::index { - public: - /** Distance metric used for retrieval */ - [[nodiscard]] constexpr inline cuvs::distance::DistanceType metric() const noexcept - { - return metric_; - } - - /** Total length of the index (number of vectors). 
*/ - [[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); } - - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); } - - /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept - -> raft::device_matrix_view - { - return dataset_view_; - } - - /** Dataset norms */ - [[nodiscard]] inline auto norms() const - -> raft::device_vector_view - { - return norms_view_.value(); - } - - /** Whether or not this index has dataset norms */ - [[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); } - - [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } - - // Don't allow copying the index for performance reasons (try avoiding copying data) - index(const index&) = delete; - index(index&&) = default; - auto operator=(const index&) -> index& = delete; - auto operator=(index&&) -> index& = default; - ~index() = default; - - /** Construct a brute force index from dataset - * - * Constructs a brute force index from a dataset. This lets us precompute norms for - * the dataset, providing a speed benefit over doing this at query time. - - * If the dataset is already in GPU memory, then this class stores a non-owning reference to - * the dataset. If the dataset is in host memory, it will be copied to the device and the - * index will own the device memory. 
- */ - template - index(raft::resources const& res, - raft::mdspan, raft::row_major, data_accessor> dataset, - std::optional>&& norms, - cuvs::distance::DistanceType metric, - T metric_arg = 0.0) - : ann::index(), - metric_(metric), - dataset_(raft::make_device_matrix(res, 0, 0)), - norms_(std::move(norms)), - metric_arg_(metric_arg) - { - if (norms_) { norms_view_ = raft::make_const_mdspan(norms_.value().view()); } - update_dataset(res, dataset); - raft::resource::sync_stream(res); - } - - /** Construct a brute force index from dataset - * - * This class stores a non-owning reference to the dataset and norms here. - * Having precomputed norms gives us a performance advantage at query time. - */ - index(raft::resources const& res, - raft::device_matrix_view dataset_view, - std::optional> norms_view, - cuvs::distance::DistanceType metric, - T metric_arg = 0.0) - : ann::index(), - metric_(metric), - dataset_(raft::make_device_matrix(res, 0, 0)), - dataset_view_(dataset_view), - norms_view_(norms_view), - metric_arg_(metric_arg) - { - } - - private: - /** - * Replace the dataset with a new dataset. - */ - void update_dataset(raft::resources const& res, - raft::device_matrix_view dataset) - { - dataset_view_ = dataset; - } - - /** - * Replace the dataset with a new dataset. - * - * We create a copy of the dataset on the device. The index manages the lifetime of this copy. 
- */ - void update_dataset(raft::resources const& res, - raft::host_matrix_view dataset) - { - dataset_ = raft::make_device_matrix(dataset.extents(0), dataset.extents(1)); - raft::copy(dataset_.data_handle(), - dataset.data_handle(), - dataset.size(), - resource::get_cuda_stream(res)); - dataset_view_ = raft::make_const_mdspan(dataset_.view()); - } - - cuvs::distance::DistanceType metric_; - raft::device_matrix dataset_; - std::optional> norms_; - std::optional> norms_view_; - raft::device_matrix_view dataset_view_; - T metric_arg_; -}; - -/** - * @brief Interface for performing queries over values of k - * - * This interface lets you iterate over batches of k from a brute_force::index. - * This lets you do things like retrieve the first 100 neighbors for a query, - * apply post processing to remove any unwanted items and then if needed get the - * next 100 closest neighbors for the query. - * - * This query interface exposes C++ iterators through the ::begin and ::end, and - * is compatible with range based for loops. - * - * Note that this class is an abstract class without any cuda dependencies, meaning - * that it doesn't require a cuda compiler to use - but also means it can't be directly - * instantiated. See the cuvs::neighbors::brute_force::make_batch_k_query - * function for usage examples. 
- * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - */ -template -class batch_k_query { - public: - batch_k_query(const raft::resources& res, - int64_t index_size, - int64_t query_size, - int64_t batch_size) - : res(res), index_size(index_size), query_size(query_size), batch_size(batch_size) - { - } - virtual ~batch_k_query() {} - - using value_type = cuvs::neighbors::batch; - - class iterator { - public: - using value_type = cuvs::neighbors::batch; - using reference = const value_type&; - using pointer = const value_type*; - - iterator(const batch_k_query* query, int64_t offset = 0) - : current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset) - { - query->load_batch(offset, query->batch_size, &batches); - query->slice_batch(batches, offset, query->batch_size, ¤t); - } - - reference operator*() const { return current; } - - pointer operator->() const { return ¤t; } - - iterator& operator++() - { - advance(query->batch_size); - return *this; - } - - iterator operator++(int) - { - iterator previous(*this); - operator++(); - return previous; - } - - /** - * @brief Advance the iterator, using a custom size for the next batch - * - * Using operator++ means that we will load up the same batch_size for each - * batch. This method allows us to get around this restriction, and load up - * arbitrary batch sizes on each iteration. - * See cuvs::neighbors::brute_force::make_batch_k_query for a usage example. 
- * - * @param[in] next_batch_size: size of the next batch to load up - */ - void advance(int64_t next_batch_size) - { - offset = std::min(offset + current.batch_size(), query->index_size); - if (offset + next_batch_size > batches.batch_size()) { - query->load_batch(offset, next_batch_size, &batches); - } - query->slice_batch(batches, offset, next_batch_size, ¤t); - } - - friend bool operator==(const iterator& lhs, const iterator& rhs) - { - return (lhs.query == rhs.query) && (lhs.offset == rhs.offset); - }; - friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); }; - - protected: - // the current batch of data - value_type current; - - // the currently loaded group of data (containing multiple batches of data that we can iterate - // through) - value_type batches; - - const batch_k_query* query; - int64_t offset, current_batch_size; - }; - - iterator begin() const { return iterator(this); } - iterator end() const { return iterator(this, index_size); } - - protected: - // these two methods need cuda code, and are implemented in the subclass - virtual void load_batch(int64_t offset, - int64_t next_batch_size, - batch* output) const = 0; - virtual void slice_batch(const value_type& input, - int64_t offset, - int64_t batch_size, - value_type* output) const = 0; - - const raft::resources& res; - int64_t index_size, query_size, batch_size; -}; -/** @} */ - -} // namespace cuvs::neighbors::brute_force diff --git a/cpp/include/cuvs/neighbors/cagra.cuh b/cpp/include/cuvs/neighbors/cagra.cuh deleted file mode 100644 index a8e42d18a..000000000 --- a/cpp/include/cuvs/neighbors/cagra.cuh +++ /dev/null @@ -1,425 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/cagra/cagra_build.cuh" -#include "detail/cagra/cagra_search.cuh" -#include "detail/cagra/graph_core.cuh" - -#include -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra { - -/** - * @defgroup cagra CUDA ANN Graph-based nearest neighbor search - * @{ - */ - -/** - * @brief Build a kNN graph using IVF-PQ. - * - * The kNN graph is the first building block for CAGRA index. - * - * The output is a dense matrix that stores the neighbor indices for each point in the dataset. - * Each point has the same number of neighbors. - * - * See [cagra::build](#cagra::build) for an alternative method. 
- * - * The following distance metrics are supported: - * - L2Expanded - * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * // use default index parameters - * ivf_pq::index_params build_params; - * ivf_pq::search_params search_params - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); - * // create knn graph - * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); - * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); - * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); - * // Construct an index from dataset and optimized knn_graph - * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); - * @endcode - * - * @tparam DataT data element type - * @tparam IdxT type of the dataset vector indices - * - * @param[in] res raft resources - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] - * @param[in] refine_rate (optional) refinement rate for ivf-pq search - * @param[in] build_params (optional) ivf_pq index building parameters for knn graph - * @param[in] search_params (optional) ivf_pq search parameters - */ -template -void build_knn_graph( - raft::resources const& res, - raft::mdspan, raft::row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - std::optional refine_rate = std::nullopt, - std::optional build_params = std::nullopt, - std::optional search_params = std::nullopt) -{ - using internal_IdxT = typename std::make_unsigned::type; - - auto knn_graph_internal = raft::make_host_matrix_view( - reinterpret_cast(knn_graph.data_handle()), - knn_graph.extent(0), - knn_graph.extent(1)); - auto dataset_internal = - raft::mdspan, raft::row_major, accessor>( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - - cagra::detail::build_knn_graph( - res, 
dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); -} - -/** - * @brief Build a kNN graph using NN-descent. - * - * The kNN graph is the first building block for CAGRA index. - * - * The output is a dense matrix that stores the neighbor indices for each point in the dataset. - * Each point has the same number of neighbors. - * - * See [cagra::build](#cagra::build) for an alternative method. - * - * The following distance metrics are supported: - * - L2Expanded - * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * using namespace cuvs::neighbors::experimental; - * // use default index parameters - * nn_descent::index_params build_params; - * build_params.graph_degree = 128; - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); - * // create knn graph - * cagra::build_knn_graph(res, dataset, knn_graph.view(), build_params); - * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); - * cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view()); - * // Construct an index from dataset and optimized knn_graph - * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); - * @endcode - * - * @tparam DataT data element type - * @tparam IdxT type of the dataset vector indices - * @tparam accessor host or device accessor_type for the dataset - * @param[in] res raft::resources is an object mangaging resources - * @param[in] dataset input raft::host/device_matrix_view that can be located in - * in host or device memory - * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] - * @param[in] build_params an instance of experimental::nn_descent::index_params that are parameters - * to run the nn-descent algorithm - */ -template , memory_type::device>> -void build_knn_graph( - raft::resources const& res, - raft::mdspan, raft::row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - 
experimental::nn_descent::index_params build_params) -{ - detail::build_knn_graph(res, dataset, knn_graph, build_params); -} - -/** - * @brief Sort a KNN graph index. - * Preprocessing step for `cagra::optimize`: If a KNN graph is not built using - * `cagra::build_knn_graph`, then it is necessary to call this function before calling - * `cagra::optimize`. If the graph is built by `cagra::build_knn_graph`, it is already sorted and - * you do not need to call this function. - * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * cagra::index_params build_params; - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); - * // build KNN graph not using `cagra::build_knn_graph` - * // build(knn_graph, dataset, ...); - * // sort graph index - * sort_knn_graph(res, dataset.view(), knn_graph.view()); - * // optimize graph - * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); - * // Construct an index from dataset and optimized knn_graph - * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); - * @endcode - * - * @tparam DataT type of the data in the source dataset - * @tparam IdxT type of the dataset vector indices - * - * @param[in] res raft resources - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * @param[in,out] knn_graph a matrix view (host or device) of the input knn graph [n_rows, - * knn_graph_degree] - */ -template , memory_type::device>, - typename g_accessor = - host_device_accessor, memory_type::host>> -void sort_knn_graph( - raft::resources const& res, - raft::mdspan, raft::row_major, d_accessor> dataset, - raft::mdspan, raft::row_major, g_accessor> knn_graph) -{ - using internal_IdxT = typename std::make_unsigned::type; - - using g_accessor_internal = - host_device_accessor, g_accessor::mem_type>; - auto knn_graph_internal = - raft::mdspan, raft::row_major, g_accessor_internal>( - 
reinterpret_cast(knn_graph.data_handle()), - knn_graph.extent(0), - knn_graph.extent(1)); - - auto dataset_internal = - raft::mdspan, raft::row_major, d_accessor>( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - - cagra::detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal); -} - -/** - * @brief Prune a KNN graph. - * - * Decrease the number of neighbors for each node. - * - * See [cagra::build_knn_graph](#cagra::build_knn_graph) for usage example - * - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res raft resources - * @param[in] knn_graph a matrix view (host or device) of the input knn graph [n_rows, - * knn_graph_degree] - * @param[out] new_graph a host matrix view of the optimized knn graph [n_rows, graph_degree] - */ -template , memory_type::host>> -void optimize( - raft::resources const& res, - raft::mdspan, raft::row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph) -{ - detail::optimize(res, knn_graph, new_graph); -} - -/** - * @brief Build the index from the dataset for efficient search. - * - * The build consist of two steps: build an intermediate knn-graph, and optimize it to - * create the final graph. The index_params struct controls the node degree of these - * graphs. - * - * It is required that dataset and the optimized graph fit the GPU memory. - * - * To customize the parameters for knn-graph building and pruning, and to reuse the - * intermediate results, you could build the index in two steps using - * [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::optimize](#cagra::optimize). 
- * - * The following distance metrics are supported: - * - L2 - * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * // use default index parameters - * cagra::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = cagra::build(res, index_params, dataset); - * // use default search parameters - * cagra::search_params search_params; - * // search K nearest neighbours - * auto neighbors = raft::make_device_matrix(res, n_queries, k); - * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res - * @param[in] params parameters for building the index - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * - * @return the constructed cagra index - */ -template , memory_type::host>> -index build( - raft::resources const& res, - const index_params& params, - raft::mdspan, raft::row_major, Accessor> dataset) -{ - return detail::build(res, params, dataset); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [cagra::build](#cagra::build) documentation for a usage example. 
- * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - cuvs::neighbors::filtering::none_cagra_sample_filter()); -} - -/** - * @brief Search ANN using the constructed index with the given sample filter. 
- * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * // use default index parameters - * cagra::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = cagra::build(res, index_params, dataset); - * // use default search parameters - * cagra::search_params search_params; - * // create a bitset to filter the search - * auto removed_indices = raft::make_device_vector(res, n_removed_indices); - * raft::core::bitset removed_indices_bitset( - * res, removed_indices.view(), dataset.extent(0)); - * // search K nearest neighbours according to a bitset - * auto neighbors = raft::make_device_matrix(res, n_queries, k); - * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances, - * filtering::bitset_filter(removed_indices_bitset.view())); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * @tparam CagraSampleFilterT Device filter function, with the signature - * `(uint32_t query ix, uint32_t sample_ix) -> bool` - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device filter function that greenlights samples for a given query - */ -template -void search_with_filtering(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) -{ - RAFT_EXPECTS( - queries.extent(0) == 
neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); -} - -/** @} */ // end group cagra - -} // namespace cuvs::neighbors::cagra - -// TODO: Remove deprecated experimental namespace in 23.12 release -namespace cuvs::neighbors::experimental::cagra { -using cuvs::neighbors::cagra::build; -using cuvs::neighbors::cagra::build_knn_graph; -using cuvs::neighbors::cagra::optimize; -using cuvs::neighbors::cagra::search; -using cuvs::neighbors::cagra::sort_knn_graph; -} // namespace cuvs::neighbors::experimental::cagra diff --git a/cpp/include/cuvs/neighbors/cagra_types.hpp b/cpp/include/cuvs/neighbors/cagra.hpp similarity index 56% rename from cpp/include/cuvs/neighbors/cagra_types.hpp rename to cpp/include/cuvs/neighbors/cagra.hpp index 0299b78df..3a0c60b78 100644 --- a/cpp/include/cuvs/neighbors/cagra_types.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,24 +17,14 @@ #pragma once #include "ann_types.hpp" -#include - #include -#include -#include -#include -#include -#include +#include +#include +#include #include -#include - -#include -#include -#include -#include -#include +#include +#include -#include namespace cuvs::neighbors::cagra { /** * @addtogroup cagra @@ -61,6 +51,21 @@ struct index_params : ann::index_params { graph_build_algo build_algo = graph_build_algo::IVF_PQ; /** Number of Iterations to run if building with NN_DESCENT */ size_t nn_descent_niter = 20; + + /** Build a raft CAGRA index params from an existing cuvs CAGRA index params. */ + operator raft::neighbors::cagra::index_params() const + { + return raft::neighbors::cagra::index_params{ + { + .metric = static_cast((int)this->metric), + .metric_arg = this->metric_arg, + .add_data_on_build = this->add_data_on_build, + }, + .intermediate_graph_degree = intermediate_graph_degree, + .graph_degree = graph_degree, + .build_algo = static_cast((int)build_algo), + .nn_descent_niter = nn_descent_niter}; + } }; enum class search_algo { @@ -116,6 +121,27 @@ struct search_params : ann::search_params { uint32_t num_random_samplings = 1; /** Bit mask used for initial random seed node selection. */ uint64_t rand_xor_mask = 0x128394; + + /** Build a raft CAGRA search params from an existing cuvs CAGRA search params. 
*/ + operator raft::neighbors::cagra::search_params() const + { + raft::neighbors::cagra::search_params result = { + {}, + max_queries, + itopk_size, + max_iterations, + static_cast((int)algo), + team_size, + search_width, + min_iterations, + thread_block_size, + static_cast((int)hashmap_mode), + hashmap_min_bitlen, + hashmap_max_fill_rate, + num_random_samplings, + rand_xor_mask}; + return result; + } }; static_assert(std::is_aggregate_v); @@ -132,6 +158,12 @@ static_assert(std::is_aggregate_v); */ template struct index : ann::index { + /** Build a cuvs CAGRA index from an existing RAFT CAGRA index. */ + index(raft::neighbors::cagra::index&& raft_idx) + : ann::index(), + raft_index_{std::make_unique>(std::move(raft_idx))} + { + } static_assert(!raft::is_narrowing_v, "IdxT must be able to represent all values of uint32_t"); @@ -139,38 +171,35 @@ struct index : ann::index { /** Distance metric used for clustering. */ [[nodiscard]] constexpr inline auto metric() const noexcept -> cuvs::distance::DistanceType { - return metric_; + return static_cast((int)raft_index_->metric()); } /** Total length of the index (number of vectors). */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT - { - return dataset_view_.extent(0); - } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return raft_index_->size(); } /** Dimensionality of the data. 
*/ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { - return dataset_view_.extent(1); + return raft_index_->dim(); } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t { - return graph_view_.extent(1); + return raft_index_->graph_degree(); } /** Dataset [size, dim] */ [[nodiscard]] inline auto dataset() const noexcept -> raft::device_matrix_view { - return dataset_view_; + return raft_index_->dataset(); } /** neighborhood graph [size, graph-degree] */ [[nodiscard]] inline auto graph() const noexcept -> raft::device_matrix_view { - return graph_view_; + return raft_index_->graph(); } // Don't allow copying the index for performance reasons (try avoiding copying data) @@ -184,12 +213,10 @@ struct index : ann::index { index(raft::resources const& res, cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded) : ann::index(), - metric_(metric), - dataset_(raft::make_device_matrix(res, 0, 0)), - graph_(raft::make_device_matrix(res, 0, 0)) + raft_index_(std::make_unique>( + res, static_cast((int)metric))) { } - /** Construct an index from dataset and knn_graph arrays * * If the dataset and graph is already in GPU memory, then the index is just a thin wrapper around @@ -251,9 +278,8 @@ struct index : ann::index { raft::mdspan, raft::row_major, graph_accessor> knn_graph) : ann::index(), - metric_(metric), - dataset_(raft::make_device_matrix(res, 0, 0)), - graph_(raft::make_device_matrix(res, 0, 0)) + raft_index_(std::make_unique>( + res, static_cast((int)metric), dataset, knn_graph)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); @@ -272,15 +298,8 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - if (dataset.extent(1) * sizeof(T) % 16 != 0) { - RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory"); - copy_padded(res, dataset); - } 
else { - dataset_view_ = raft::make_device_strided_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1)); - } + raft_index_->update_dataset(res, dataset); } - /** * Replace the dataset with a new dataset. * @@ -289,8 +308,7 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device"); - copy_padded(res, dataset); + raft_index_->update_dataset(res, dataset); } /** @@ -302,7 +320,7 @@ struct index : ann::index { void update_graph(raft::resources const& res, raft::device_matrix_view knn_graph) { - graph_view_ = knn_graph; + raft_index_->update_graph(res, knn_graph); } /** @@ -313,54 +331,83 @@ struct index : ann::index { void update_graph(raft::resources const& res, raft::host_matrix_view knn_graph) { - RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device"); - if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) { - // clear existing memory before allocating to prevent OOM errors on large graphs - if (graph_.size()) { graph_ = raft::make_device_matrix(res, 0, 0); } - graph_ = - raft::make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1)); - } - raft::copy(graph_.data_handle(), - knn_graph.data_handle(), - knn_graph.size(), - raft::resource::get_cuda_stream(res)); - graph_view_ = graph_.view(); + raft_index_->update_graph(res, knn_graph); } - private: - /** Create a device copy of the dataset, and pad it if necessary. 
*/ - template - void copy_padded( - raft::resources const& res, - raft::mdspan, raft::row_major, data_accessor> dataset) + auto get_raft_index() const -> const raft::neighbors::cagra::index* { - detail::copy_with_padding(res, dataset_, dataset); - - dataset_view_ = raft::make_device_strided_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1)); - RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu", - static_cast(dataset_view_.extent(0)), - static_cast(dataset_view_.extent(1)), - static_cast(dataset_view_.stride(0))); + return raft_index_.get(); } + auto get_raft_index() -> raft::neighbors::cagra::index* { return raft_index_.get(); } - cuvs::distance::DistanceType metric_; - raft::device_matrix dataset_; - raft::device_matrix graph_; - raft::device_matrix_view dataset_view_; - raft::device_matrix_view graph_view_; + private: + std::unique_ptr> raft_index_; }; +// Using device and host_matrix_view avoids needing to typedef multiple mdspans based on accessors +#define CUVS_INST_CAGRA_FUNCS(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::cagra::index; \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->cuvs::neighbors::cagra::index; \ + \ + void build_device(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::cagra::index& idx); \ + \ + void build_host(raft::resources const& handle, \ + const cuvs::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::cagra::index& idx); \ + \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view 
queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::cagra::index* index); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const cuvs::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + cuvs::neighbors::cagra::index* index); + +CUVS_INST_CAGRA_FUNCS(float, uint32_t); +CUVS_INST_CAGRA_FUNCS(int8_t, uint32_t); +CUVS_INST_CAGRA_FUNCS(uint8_t, uint32_t); + +#undef CUVS_INST_CAGRA_FUNCS + +#define CUVS_INST_CAGRA_OPTIMIZE(IdxT) \ + void optimize_device(raft::resources const& res, \ + raft::device_matrix_view knn_graph, \ + raft::host_matrix_view new_graph); \ + \ + void optimize_host(raft::resources const& res, \ + raft::host_matrix_view knn_graph, \ + raft::host_matrix_view new_graph); + +CUVS_INST_CAGRA_OPTIMIZE(uint32_t); + +#undef CUVS_INST_CAGRA_OPTIMIZE + /** @} */ } // namespace cuvs::neighbors::cagra - -// TODO: Remove deprecated experimental namespace in 23.12 release -namespace cuvs::neighbors::experimental::cagra { -using cuvs::neighbors::cagra::graph_build_algo; -using cuvs::neighbors::cagra::hash_mode; -using cuvs::neighbors::cagra::index; -using cuvs::neighbors::cagra::index_params; -using cuvs::neighbors::cagra::search_algo; -using cuvs::neighbors::cagra::search_params; -} // namespace cuvs::neighbors::experimental::cagra diff --git a/cpp/include/cuvs/neighbors/cagra_serialize.cuh b/cpp/include/cuvs/neighbors/cagra_serialize.cuh deleted file mode 100644 index ee492ea8c..000000000 --- a/cpp/include/cuvs/neighbors/cagra_serialize.cuh +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/cagra/cagra_serialize.cuh" - -namespace cuvs::neighbors::cagra { - -/** - * \defgroup cagra_serialize CAGRA Serialize - * @{ - */ - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index CAGRA index - * @param[in] include_dataset Whether or not to write out the dataset to the file. - * - */ -template -void serialize(raft::resources const& handle, - std::ostream& os, - const index& index, - bool include_dataset = true) -{ - detail::serialize(handle, os, index, include_dataset); -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index CAGRA index - * @param[in] include_dataset Whether or not to write out the dataset to the file. - * - */ -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index, - bool include_dataset = true) -{ - detail::serialize(handle, filename, index, include_dataset); -} - -/** - * Write the CAGRA built index as a base layer HNSW index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize_to_hnswlib(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index CAGRA index - * - */ -template -void serialize_to_hnswlib(raft::resources const& handle, - std::ostream& os, - const index& index) -{ - detail::serialize_to_hnswlib(handle, os, index); -} - -/** - * Write the CAGRA built index as a base layer HNSW index to file - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize_to_hnswlib(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index CAGRA index - * - */ -template -void serialize_to_hnswlib(raft::resources const& handle, - const std::string& filename, - const index& index) -{ - detail::serialize_to_hnswlib(handle, filename, index); -} - -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an input stream - * std::istream is(std::cin.rdbuf()); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, is); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - * @return cuvs::neighbors::experimental::cagra::index - */ -template -index deserialize(raft::resources const& handle, std::istream& is) -{ - return detail::deserialize(handle, is); -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, filename); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - * @return cuvs::neighbors::experimental::cagra::index - */ -template -index deserialize(raft::resources const& handle, const std::string& filename) -{ - return detail::deserialize(handle, filename); -} - -/**@}*/ - -} // namespace cuvs::neighbors::cagra - -// TODO: Remove deprecated experimental namespace in 23.12 release -namespace cuvs::neighbors::experimental::cagra { -using cuvs::neighbors::cagra::deserialize; -using cuvs::neighbors::cagra::serialize; - -} // namespace cuvs::neighbors::experimental::cagra diff --git a/cpp/include/cuvs/neighbors/detail/cagra/bitonic.hpp b/cpp/include/cuvs/neighbors/detail/cagra/bitonic.hpp deleted file mode 100644 index d1fa0b41a..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/bitonic.hpp +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include -#include - -namespace cuvs::neighbors::cagra::detail { -namespace bitonic { - -namespace detail { - -template -_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, K& k1, V& v1, const bool asc) -{ - if ((k0 != k1) && ((k0 < k1) != asc)) { - const auto tmp_k = k0; - k0 = k1; - k1 = tmp_k; - const auto tmp_v = v0; - v0 = v1; - v1 = tmp_v; - } -} - -template -_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, const unsigned lane_offset, const bool asc) -{ - auto k1 = __shfl_xor_sync(~0u, k0, lane_offset); - auto v1 = __shfl_xor_sync(~0u, v0, lane_offset); - if ((k0 != k1) && ((k0 < k1) != asc)) { - k0 = k1; - v0 = v1; - } -} - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[N], V v[N], const std::uint32_t range, const bool asc) - { - const auto lane_id = threadIdx.x % warp_size; - - if (range == 1) { - for (std::uint32_t b = 2; b <= N; b <<= 1) { - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - std::uint32_t j = i ^ c; - if (i >= j) continue; - const auto line_id = i + (N * lane_id); - const auto p = static_cast(line_id & b) == static_cast(line_id & c); - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - } - } - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - for (std::uint32_t c = N / 2; c >= 1; c >>= 1) { -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - std::uint32_t j = i ^ c; - if (i >= j) continue; - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - } - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[6], V v[6], const std::uint32_t range, const bool asc) - { - constexpr unsigned N = 6; - const auto lane_id = threadIdx.x % warp_size; 
- - if (range == 1) { - for (std::uint32_t i = 0; i < N; i += 3) { - const auto p = (i == 0); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - } - const auto p = ((lane_id & 1) == 0); - for (std::uint32_t i = 0; i < 3; i++) { - std::uint32_t j = i + 3; - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - for (std::uint32_t i = 0; i < N; i += 3) { - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - } - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - for (std::uint32_t i = 0; i < 3; i++) { - std::uint32_t j = i + 3; - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - for (std::uint32_t i = 0; i < N; i += N / 2) { - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - } - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[3], V v[3], const std::uint32_t range, const bool asc) - { - constexpr unsigned N = 3; - const auto lane_id = threadIdx.x % warp_size; - - if (range == 1) { - const auto p = ((lane_id & 1) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - swap_if_needed(k[1], v[1], k[2], v[2], p); - swap_if_needed(k[0], v[0], k[1], v[1], p); - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - 
swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - swap_if_needed(k[1], v[1], k[2], v[2], p); - swap_if_needed(k[0], v[0], k[1], v[1], p); - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[2], V v[2], const std::uint32_t range, const bool asc) - { - constexpr unsigned N = 2; - const auto lane_id = threadIdx.x % warp_size; - - if (range == 1) { - const auto p = ((lane_id & 1) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[1], V v[1], const std::uint32_t range, const bool asc) - { - const auto lane_id = threadIdx.x % warp_size; - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); - swap_if_needed(k[0], v[0], c, p); - } - } -}; - -} // namespace detail - -template -__device__ void warp_merge(K k[N], V v[N], unsigned range, const bool asc = true) -{ - detail::warp_merge_core{}(k, v, range, asc); -} - -template -__device__ void warp_sort(K k[N], V v[N], const bool asc = true) -{ - for (std::uint32_t range = 1; range <= warp_size; range <<= 1) { - warp_merge(k, v, range, asc); - } -} - -} // namespace bitonic -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/cuvs/neighbors/detail/cagra/cagra_build.cuh deleted file mode 100644 index 399d0071b..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/cagra_build.cuh +++ /dev/null @@ -1,353 
+0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../../cagra_types.hpp" -#include "graph_core.cuh" -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra::detail { - -template -void build_knn_graph( - raft::resources const& res, - raft::mdspan, raft::row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - std::optional refine_rate = std::nullopt, - std::optional build_params = std::nullopt, - std::optional search_params = std::nullopt) -{ - resource::detail::warn_non_pool_workspace(res, "cuvs::neighbors::cagra::build"); - RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded, - "Currently only L2Expanded metric is supported"); - - uint32_t node_degree = knn_graph.extent(1); - raft::common::nvtx::range fun_scope( - "cagra::build_graph(%zu, %zu, %u)", - size_t(dataset.extent(0)), - size_t(dataset.extent(1)), - node_degree); - - if (!build_params) { - build_params = ivf_pq::index_params{}; - build_params->n_lists = dataset.extent(0) < 4 * 2500 ? 4 : (uint32_t)(dataset.extent(0) / 2500); - build_params->pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2); - build_params->pq_bits = 8; - build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 
1 : 10; - build_params->kmeans_n_iters = 25; - build_params->add_data_on_build = true; - } - - // Make model name - const std::string model_name = [&]() { - char model_name[1024]; - sprintf(model_name, - "%s-%lux%lu.cluster_%u.pq_%u.%ubit.itr_%u.metric_%u.pqcenter_%u", - "IVF-PQ", - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1)), - build_params->n_lists, - build_params->pq_dim, - build_params->pq_bits, - build_params->kmeans_n_iters, - build_params->metric, - static_cast(build_params->codebook_kind)); - return std::string(model_name); - }(); - - RAFT_LOG_DEBUG("# Building IVF-PQ index %s", model_name.c_str()); - auto index = ivf_pq::build( - res, *build_params, dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - - // - // search top (k + 1) neighbors - // - if (!search_params) { - search_params = ivf_pq::search_params{}; - search_params->n_probes = std::min(dataset.extent(1) * 2, build_params->n_lists); - search_params->lut_dtype = CUDA_R_8U; - search_params->internal_distance_dtype = CUDA_R_32F; - } - const auto top_k = node_degree + 1; - uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f); - gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); - const auto num_queries = dataset.extent(0); - const auto max_batch_size = 1024; - RAFT_LOG_DEBUG( - "IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u", - node_degree, - top_k, - gpu_top_k, - max_batch_size, - search_params->n_probes); - - auto distances = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto neighbors = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto refined_distances = raft::make_device_matrix(res, max_batch_size, top_k); - auto refined_neighbors = raft::make_device_matrix(res, max_batch_size, top_k); - auto neighbors_host = raft::make_host_matrix(max_batch_size, gpu_top_k); - auto queries_host = raft::make_host_matrix(max_batch_size, dataset.extent(1)); - auto 
refined_neighbors_host = raft::make_host_matrix(max_batch_size, top_k); - auto refined_distances_host = raft::make_host_matrix(max_batch_size, top_k); - - // TODO(tfeher): batched search with multiple GPUs - std::size_t num_self_included = 0; - bool first = true; - const auto start_clock = std::chrono::system_clock::now(); - - rmm::mr::device_memory_resource* device_memory = nullptr; - auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024); - if (pool_guard) { RAFT_LOG_DEBUG("ivf_pq using pool memory resource"); } - - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches( - dataset.data_handle(), - dataset.extent(0), - dataset.extent(1), - max_batch_size, - resource::get_cuda_stream(res), - device_memory); - - size_t next_report_offset = 0; - size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps. - - for (const auto& batch : vec_batches) { - // Map int64_t to uint32_t because ivf_pq requires the latter. - // TODO(tfeher): remove this mapping once ivf_pq accepts raft::mdspan with int64_t index type - auto queries_view = raft::make_device_matrix_view( - batch.data(), batch.size(), batch.row_width()); - auto neighbors_view = raft::make_device_matrix_view( - neighbors.data_handle(), batch.size(), neighbors.extent(1)); - auto distances_view = raft::make_device_matrix_view( - distances.data_handle(), batch.size(), distances.extent(1)); - - ivf_pq::search(res, *search_params, index, queries_view, neighbors_view, distances_view); - if constexpr (is_host_mdspan_v) { - raft::copy(neighbors_host.data_handle(), - neighbors.data_handle(), - neighbors_view.size(), - resource::get_cuda_stream(res)); - raft::copy(queries_host.data_handle(), - batch.data(), - queries_view.size(), - resource::get_cuda_stream(res)); - auto queries_host_view = raft::make_host_matrix_view( - queries_host.data_handle(), batch.size(), batch.row_width()); - auto neighbors_host_view = raft::make_host_matrix_view( - neighbors_host.data_handle(), 
batch.size(), neighbors.extent(1)); - auto refined_neighbors_host_view = raft::make_host_matrix_view( - refined_neighbors_host.data_handle(), batch.size(), top_k); - auto refined_distances_host_view = raft::make_host_matrix_view( - refined_distances_host.data_handle(), batch.size(), top_k); - resource::sync_stream(res); - - cuvs::neighbors::detail::refine_host( - dataset, - queries_host_view, - neighbors_host_view, - refined_neighbors_host_view, - refined_distances_host_view, - build_params->metric); - } else { - auto neighbor_candidates_view = raft::make_device_matrix_view( - neighbors.data_handle(), batch.size(), gpu_top_k); - auto refined_neighbors_view = raft::make_device_matrix_view( - refined_neighbors.data_handle(), batch.size(), top_k); - auto refined_distances_view = raft::make_device_matrix_view( - refined_distances.data_handle(), batch.size(), top_k); - - auto dataset_view = raft::make_device_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - cuvs::neighbors::detail::refine_device( - res, - dataset_view, - queries_view, - neighbor_candidates_view, - refined_neighbors_view, - refined_distances_view, - build_params->metric); - raft::copy(refined_neighbors_host.data_handle(), - refined_neighbors_view.data_handle(), - refined_neighbors_view.size(), - resource::get_cuda_stream(res)); - resource::sync_stream(res); - } - // omit itself & write out - // TODO(tfeher): do this in parallel with GPU processing of next batch - for (std::size_t i = 0; i < batch.size(); i++) { - size_t vec_idx = i + batch.offset(); - for (std::size_t j = 0, num_added = 0; j < top_k && num_added < node_degree; j++) { - const auto v = refined_neighbors_host(i, j); - if (static_cast(v) == vec_idx) { - num_self_included++; - continue; - } - knn_graph(vec_idx, num_added) = v; - num_added++; - } - } - - size_t num_queries_done = batch.offset() + batch.size(); - const auto end_clock = std::chrono::system_clock::now(); - if (batch.offset() > next_report_offset) { - 
next_report_offset += d_report_offset; - const auto time = - std::chrono::duration_cast(end_clock - start_clock).count() * - 1e-6; - const auto throughput = num_queries_done / time; - - RAFT_LOG_DEBUG( - "# Search %12lu / %12lu (%3.2f %%), %e queries/sec, %.2f minutes ETA, self included = " - "%3.2f %% \r", - num_queries_done, - dataset.extent(0), - num_queries_done / static_cast(dataset.extent(0)) * 100, - throughput, - (num_queries - num_queries_done) / throughput / 60, - static_cast(num_self_included) / num_queries_done * 100.); - } - first = false; - } - - if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph"); -} - -template -void build_knn_graph( - raft::resources const& res, - raft::mdspan, raft::row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - experimental::nn_descent::index_params build_params) -{ - auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph); - experimental::nn_descent::build(res, build_params, dataset, nn_descent_idx); - - using internal_IdxT = typename std::make_unsigned::type; - using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; - using g_accessor_internal = - host_device_accessor, g_accessor::mem_type>; - - auto knn_graph_internal = - raft::mdspan, raft::row_major, g_accessor_internal>( - reinterpret_cast(nn_descent_idx.graph().data_handle()), - nn_descent_idx.graph().extent(0), - nn_descent_idx.graph().extent(1)); - - graph::sort_knn_graph(res, dataset, knn_graph_internal); -} - -template , memory_type::host>> -void optimize( - raft::resources const& res, - raft::mdspan, raft::row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph) -{ - using internal_IdxT = typename std::make_unsigned::type; - - auto new_graph_internal = raft::make_host_matrix_view( - reinterpret_cast(new_graph.data_handle()), - new_graph.extent(0), - new_graph.extent(1)); - - using g_accessor_internal = - host_device_accessor, memory_type::host>; - auto knn_graph_internal = - raft::mdspan, 
raft::row_major, g_accessor_internal>( - reinterpret_cast(knn_graph.data_handle()), - knn_graph.extent(0), - knn_graph.extent(1)); - - cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal); -} - -template , memory_type::host>> -index build( - raft::resources const& res, - const index_params& params, - raft::mdspan, raft::row_major, Accessor> dataset, - std::optional nn_descent_params = std::nullopt, - std::optional refine_rate = std::nullopt, - std::optional pq_build_params = std::nullopt, - std::optional search_params = std::nullopt) -{ - size_t intermediate_degree = params.intermediate_graph_degree; - size_t graph_degree = params.graph_degree; - if (intermediate_degree >= static_cast(dataset.extent(0))) { - RAFT_LOG_WARN( - "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", - dataset.extent(0)); - intermediate_degree = dataset.extent(0) - 1; - } - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - std::optional> knn_graph( - raft::make_host_matrix(dataset.extent(0), intermediate_degree)); - - if (params.build_algo == graph_build_algo::IVF_PQ) { - build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params); - - } else { - // Use nn-descent to build CAGRA knn graph - if (!nn_descent_params) { - nn_descent_params = experimental::nn_descent::index_params(); - nn_descent_params->graph_degree = intermediate_degree; - nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree; - nn_descent_params->max_iterations = params.nn_descent_niter; - } - build_knn_graph(res, dataset, knn_graph->view(), *nn_descent_params); - } - - auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); - - optimize(res, knn_graph->view(), cagra_graph.view()); - - // free 
intermediate graph before trying to create the index - knn_graph.reset(); - - // Construct an index from dataset and optimized knn graph. - return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); -} -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/cuvs/neighbors/detail/cagra/cagra_search.cuh deleted file mode 100644 index 87d8876e3..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/cagra_search.cuh +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "factory.cuh" -#include "search_plan.cuh" -#include "search_single_cta.cuh" - -namespace cuvs::neighbors::cagra::detail { - -template -struct CagraSampleFilterWithQueryIdOffset { - const uint32_t offset; - CagraSampleFilterT filter; - - CagraSampleFilterWithQueryIdOffset(const uint32_t offset, const CagraSampleFilterT filter) - : offset(offset), filter(filter) - { - } - - _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id) - { - return filter(query_id + offset, sample_id); - } -}; - -template -struct CagraSampleFilterT_Selector { - using type = CagraSampleFilterWithQueryIdOffset; -}; -template <> -struct CagraSampleFilterT_Selector { - using type = cuvs::neighbors::filtering::none_cagra_sample_filter; -}; - -// A helper function to set a query id offset -template -inline typename CagraSampleFilterT_Selector::type set_offset( - CagraSampleFilterT filter, const uint32_t offset) -{ - typename CagraSampleFilterT_Selector::type new_filter(offset, filter); - return new_filter; -} -template <> -inline - typename CagraSampleFilterT_Selector::type - set_offset( - cuvs::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t) -{ - return filter; -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [build](#build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of database vector indices - * @tparam internal_IdxT during search we map IdxT to internal_IdxT, this way we do not need - * separate kernels for int/uint. 
- * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ - -template -void search_main(raft::resources const& res, - search_params params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) -{ - resource::detail::warn_non_pool_workspace(res, "cuvs::neighbors::cagra::search"); - RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", - static_cast(index.dataset().extent(0)), - static_cast(index.dataset().extent(1))); - RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", - static_cast(queries.extent(0)), - static_cast(queries.extent(1))); - RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match"); - const uint32_t topk = neighbors.extent(1); - - if (params.max_queries == 0) { params.max_queries = queries.extent(0); } - - raft::common::nvtx::range fun_scope( - "cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim()); - - using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; - std::unique_ptr> plan = - factory::create( - res, params, index.dim(), index.graph_degree(), topk); - - plan->check(neighbors.extent(1)); - - RAFT_LOG_DEBUG("Cagra search"); - const uint32_t max_queries = plan->max_queries; - const uint32_t query_dim = queries.extent(1); - - for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { - const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); - internal_IdxT* _topk_indices_ptr = - reinterpret_cast(neighbors.data_handle()) + (topk * 
qid); - DistanceT* _topk_distances_ptr = distances.data_handle() + (topk * qid); - // todo(tfeher): one could keep distances optional and pass nullptr - const T* _query_ptr = queries.data_handle() + (query_dim * qid); - const internal_IdxT* _seed_ptr = - plan->num_seeds > 0 - ? reinterpret_cast(plan->dev_seed.data()) + (plan->num_seeds * qid) - : nullptr; - uint32_t* _num_executed_iterations = nullptr; - - auto dataset_internal = - raft::make_device_strided_matrix_view( - index.dataset().data_handle(), - index.dataset().extent(0), - index.dataset().extent(1), - index.dataset().stride(0)); - auto graph_internal = - raft::make_device_matrix_view( - reinterpret_cast(index.graph().data_handle()), - index.graph().extent(0), - index.graph().extent(1)); - - (*plan)(res, - dataset_internal, - graph_internal, - _topk_indices_ptr, - _topk_distances_ptr, - _query_ptr, - n_queries, - _seed_ptr, - _num_executed_iterations, - topk, - set_offset(sample_filter, qid)); - } - - static_assert(std::is_same_v, - "only float distances are supported at the moment"); - float* dist_out = distances.data_handle(); - const DistanceT* dist_in = distances.data_handle(); - // We're converting the data from T to DistanceT during distance computation - // and divide the values by kDivisor. Here we restore the original scale. 
- constexpr float kScale = spatial::knn::detail::utils::config::kDivisor / - spatial::knn::detail::utils::config::kDivisor; - ivf_pq::detail::postprocess_distances(dist_out, - dist_in, - index.metric(), - distances.extent(0), - distances.extent(1), - kScale, - resource::get_cuda_stream(res)); -} -/** @} */ // end group cagra - -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/cuvs/neighbors/detail/cagra/cagra_serialize.cuh deleted file mode 100644 index 019da84f3..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/cagra_serialize.cuh +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace cuvs::neighbors::cagra::detail { - -constexpr int serialization_version = 3; - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @param[in] res the raft resource handle - * @param[in] filename the file name for saving the index - * @param[in] index_ CAGRA index - * - */ -template -void serialize(raft::resources const& res, - std::ostream& os, - const index& index_, - bool include_dataset) -{ - raft::common::nvtx::range fun_scope("cagra::serialize"); - - RAFT_LOG_DEBUG( - "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - - std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); - dtype_string.resize(4); - os << dtype_string; - - serialize_scalar(res, os, serialization_version); - serialize_scalar(res, os, index_.size()); - serialize_scalar(res, os, index_.dim()); - serialize_scalar(res, os, index_.graph_degree()); - serialize_scalar(res, os, index_.metric()); - serialize_mdspan(res, os, index_.graph()); - - serialize_scalar(res, os, include_dataset); - if (include_dataset) { - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); - } -} - -template -void serialize(raft::resources const& res, - const std::string& filename, - const index& index_, - bool include_dataset) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize(res, of, index_, include_dataset); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -template -void serialize_to_hnswlib(raft::resources const& res, - std::ostream& os, - const index& index_) -{ - 
raft::common::nvtx::range fun_scope( - "cagra::serialize_to_hnswlib"); - RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", - static_cast(index_.size()), - index_.dim()); - - // offset_level_0 - std::size_t offset_level_0 = 0; - os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); - // max_element - std::size_t max_element = index_.size(); - os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); - // curr_element_count - std::size_t curr_element_count = index_.size(); - os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); - // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, - // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + - // dim * sizeof(data_t) + sizeof(labeltype) - auto size_data_per_element = - static_cast(index_.graph_degree() * 4 + 4 + index_.dim() * 4 + 8); - os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); - // label_offset - std::size_t label_offset = size_data_per_element - 8; - os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); - // offset_data - auto offset_data = static_cast(index_.graph_degree() * 4 + 4); - os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); - // max_level - int max_level = 1; - os.write(reinterpret_cast(&max_level), sizeof(int)); - // entrypoint_node - auto entrypoint_node = static_cast(index_.size() / 2); - os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); - // max_M - auto max_M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); - // max_M0 - std::size_t max_M0 = index_.graph_degree(); - os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); - // M - auto M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&M), sizeof(std::size_t)); - // mult, can be anything - double mult = 0.42424242; - os.write(reinterpret_cast(&mult), sizeof(double)); - // 
efConstruction, can be anything - std::size_t efConstruction = 500; - os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); - - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - - auto graph = index_.graph(); - auto host_graph = - raft::make_host_matrix(graph.extent(0), graph.extent(1)); - raft::copy(host_graph.data_handle(), - graph.data_handle(), - graph.size(), - raft::resource::get_cuda_stream(res)); - resource::sync_stream(res); - - // Write one dataset and graph row at a time - for (std::size_t i = 0; i < index_.size(); i++) { - auto graph_degree = static_cast(index_.graph_degree()); - os.write(reinterpret_cast(&graph_degree), sizeof(int)); - - for (std::size_t j = 0; j < index_.graph_degree(); ++j) { - auto graph_elem = host_graph(i, j); - os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); - } - - auto data_row = host_dataset.data_handle() + (index_.dim() * i); - if constexpr (std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = host_dataset(i, j); - os.write(reinterpret_cast(&data_elem), sizeof(T)); - } - } else if constexpr (std::is_same_v or std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = static_cast(host_dataset(i, j)); - os.write(reinterpret_cast(&data_elem), sizeof(int)); - } - } - - os.write(reinterpret_cast(&i), sizeof(std::size_t)); - } - - for (std::size_t i = 0; i < index_.size(); i++) { - // zeroes - auto zero = 0; - os.write(reinterpret_cast(&zero), sizeof(int)); - } - // delete [] host_graph; -} - -template -void 
serialize_to_hnswlib(raft::resources const& res, - const std::string& filename, - const index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize_to_hnswlib(res, of, index_); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -/** Load an index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] res the raft resource handle - * @param[in] filename the name of the file that stores the index - * @param[in] index_ CAGRA index - * - */ -template -auto deserialize(raft::resources const& res, std::istream& is) -> index -{ - raft::common::nvtx::range fun_scope("cagra::deserialize"); - - char dtype_string[4]; - is.read(dtype_string, 4); - - auto ver = deserialize_scalar(res, is); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto n_rows = deserialize_scalar(res, is); - auto dim = deserialize_scalar(res, is); - auto graph_degree = deserialize_scalar(res, is); - auto metric = deserialize_scalar(res, is); - - auto graph = raft::make_host_matrix(n_rows, graph_degree); - deserialize_mdspan(res, is, graph.view()); - - bool has_dataset = deserialize_scalar(res, is); - if (has_dataset) { - auto dataset = raft::make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, dataset.view()); - return index( - res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); - } else { - // create a new index with no dataset - the user must supply via update_dataset themselves - // later (this avoids allocating GPU memory in the meantime) - index idx(res, metric); - idx.update_graph(res, raft::make_const_mdspan(graph.view())); - return idx; - } -} - -template -auto deserialize(raft::resources const& res, const std::string& filename) -> index -{ - 
std::ifstream is(filename, std::ios::in | std::ios::binary); - - if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto index = detail::deserialize(res, is); - - is.close(); - - return index; -} -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/cuvs/neighbors/detail/cagra/compute_distance.hpp deleted file mode 100644 index d77d10f3c..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/compute_distance.hpp +++ /dev/null @@ -1,260 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include - -#include "device_common.hpp" -#include "hashmap.hpp" -#include "utils.hpp" -#include - -namespace cuvs::neighbors::cagra::detail { -namespace device { - -// using LOAD_256BIT_T = ulonglong4; -using LOAD_128BIT_T = uint4; -using LOAD_64BIT_T = uint64_t; - -template -_RAFT_DEVICE constexpr unsigned get_vlen() -{ - return utils::size_of() / utils::size_of(); -} - -template -struct data_load_t { - union { - LOAD_T load; - DATA_T data[VLEN]; - }; -}; - -template -_RAFT_DEVICE void compute_distance_to_random_nodes( - INDEX_T* const result_indices_ptr, // [num_pickup] - DISTANCE_T* const result_distances_ptr, // [num_pickup] - const float* const query_buffer, - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* const seed_ptr, // [num_seeds] - const uint32_t num_seeds, - INDEX_T* const visited_hash_ptr, - const uint32_t hash_bitlen, - const uint32_t block_id = 0, - const uint32_t num_blocks = 1) -{ - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - constexpr unsigned vlen = get_vlen(); - constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); - struct data_load_t dl_buff[nelem]; - uint32_t max_i = num_pickup; - if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } - for (uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += blockDim.x / TEAM_SIZE) { - const bool valid_i = (i < num_pickup); - - INDEX_T best_index_team_local; - DISTANCE_T best_norm2_team_local = utils::get_max_value(); - for (uint32_t j = 0; j < num_distilation; j++) { - // Select a node randomly and compute the distance to it - INDEX_T seed_index; - DISTANCE_T norm2 = 0.0; - if (valid_i) { - // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id))); - uint32_t 
gid = block_id + (num_blocks * (i + (num_pickup * j))); - if (seed_ptr && (gid < num_seeds)) { - seed_index = seed_ptr[gid]; - } else { - seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_size; - } -#pragma unroll - for (uint32_t e = 0; e < nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; - dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * seed_index)))[0]; - } -#pragma unroll - for (uint32_t e = 0; e < nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; -#pragma unroll - for (uint32_t v = 0; v < vlen; v++) { - const uint32_t kv = k + v; - // if (kv >= dataset_dim) break; - DISTANCE_T diff = query_buffer[device::swizzling(kv)]; - diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); - norm2 += diff * diff; - } - } - } - for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); - } - - if (valid_i && (norm2 < best_norm2_team_local)) { - best_norm2_team_local = norm2; - best_index_team_local = seed_index; - } - } - - if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) { - if (hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { - result_distances_ptr[i] = best_norm2_team_local; - result_indices_ptr[i] = best_index_team_local; - } else { - result_distances_ptr[i] = utils::get_max_value(); - result_indices_ptr[i] = utils::get_max_value(); - } - } - } -} - -template -_RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_indices_ptr, - DISTANCE_T* const result_child_distances_ptr, - // query - const float* const query_buffer, - // [dataset_dim, dataset_size] - const DATA_T* const dataset_ptr, - const std::size_t dataset_dim, - const std::size_t dataset_ld, - // [knn_k, dataset_size] - const INDEX_T* const knn_graph, - const std::uint32_t knn_k, - // hashmap - INDEX_T* const visited_hashmap_ptr, - const std::uint32_t hash_bitlen, - 
const INDEX_T* const parent_indices, - const INDEX_T* const internal_topk_list, - const std::uint32_t search_width) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - // Read child indices of parents from knn graph and check if the distance - // computaiton is necessary. - for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += blockDim.x) { - const INDEX_T smem_parent_id = parent_indices[i / knn_k]; - INDEX_T child_id = invalid_index; - if (smem_parent_id != invalid_index) { - const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; - child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)]; - } - if (child_id != invalid_index) { - if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { - child_id = invalid_index; - } - } - result_child_indices_ptr[i] = child_id; - } - - constexpr unsigned vlen = get_vlen(); - constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - - // [Notice] - // Loading the query vector here from shared memory into registers reduces - // shared memory trafiic. However, register usage increase. The - // MAX_N_FRAGS below is used as the threshold to enable or disable this, - // but the appropriate value should be discussed. - constexpr unsigned N_FRAGS = (MAX_DATASET_DIM + TEAM_SIZE - 1) / TEAM_SIZE; - float query_frags[N_FRAGS]; - if (N_FRAGS <= MAX_N_FRAGS) { - // Pre-load query vectors into registers when register usage is not too large. 
-#pragma unroll - for (unsigned e = 0; e < nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - // if (k >= dataset_dim) break; -#pragma unroll - for (unsigned v = 0; v < vlen; v++) { - const unsigned kv = k + v; - const unsigned ev = (vlen * e) + v; - query_frags[ev] = query_buffer[device::swizzling(kv)]; - } - } - } - __syncthreads(); - - // Compute the distance to child nodes - std::uint32_t max_i = knn_k * search_width; - if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } - for (std::uint32_t tid = threadIdx.x; tid < max_i * TEAM_SIZE; tid += blockDim.x) { - const auto i = tid / TEAM_SIZE; - const bool valid_i = (i < (knn_k * search_width)); - INDEX_T child_id = invalid_index; - if (valid_i) { child_id = result_child_indices_ptr[i]; } - - DISTANCE_T norm2 = 0.0; - struct data_load_t dl_buff[nelem]; - if (child_id != invalid_index) { -#pragma unroll - for (unsigned e = 0; e < nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; - dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * child_id)))[0]; - } -#pragma unroll - for (unsigned e = 0; e < nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; -#pragma unroll - for (unsigned v = 0; v < vlen; v++) { - DISTANCE_T diff; - if (N_FRAGS <= MAX_N_FRAGS) { - const unsigned ev = (vlen * e) + v; - diff = query_frags[ev]; - } else { - const unsigned kv = k + v; - diff = query_buffer[device::swizzling(kv)]; - } - diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); - norm2 += diff * diff; - } - } - } - for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); - } - - // Store the distance - if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) { - if (child_id != invalid_index) { - result_child_distances_ptr[i] = norm2; - } else { - result_child_distances_ptr[i] = utils::get_max_value(); - 
} - } - } -} - -} // namespace device -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/device_common.hpp b/cpp/include/cuvs/neighbors/detail/cagra/device_common.hpp deleted file mode 100644 index 82139ef59..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/device_common.hpp +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "utils.hpp" -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra::detail { -namespace device { - -// warpSize for compile time calculation -constexpr unsigned warp_size = 32; - -/** Xorshift rondem number generator. - * - * See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference. - */ -_RAFT_HOST_DEVICE inline uint64_t xorshift64(uint64_t u) -{ - u ^= u >> 12; - u ^= u << 25; - u ^= u >> 27; - return u * 0x2545F4914F6CDD1DULL; -} - -template -_RAFT_DEVICE inline T swizzling(T x) -{ - // Address swizzling reduces bank conflicts in shared memory, but increases - // the amount of operation instead. 
- // return x; - return x ^ (x >> 5); // "x" must be less than 1024 -} - -} // namespace device -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/factory.cuh b/cpp/include/cuvs/neighbors/detail/cagra/factory.cuh deleted file mode 100644 index abe8d28a5..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/factory.cuh +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "search_multi_cta.cuh" -#include "search_multi_kernel.cuh" -#include "search_plan.cuh" -#include "search_single_cta.cuh" -#include - -namespace cuvs::neighbors::cagra::detail { - -template -class factory { - public: - /** - * Create a search structure for dataset with dim features. 
- */ - static std::unique_ptr> create( - raft::resources const& res, - search_params const& params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) - { - search_plan_impl_base plan(params, dim, graph_degree, topk); - switch (plan.max_dim) { - case 128: - switch (plan.team_size) { - case 8: return dispatch_kernel<128, 8>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - case 256: - switch (plan.team_size) { - case 16: return dispatch_kernel<256, 16>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - case 512: - switch (plan.team_size) { - case 32: return dispatch_kernel<512, 32>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - case 1024: - switch (plan.team_size) { - case 32: return dispatch_kernel<1024, 32>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - default: RAFT_LOG_DEBUG("Incorrect max_dim (%lu)\n", plan.max_dim); - } - return std::unique_ptr>(); - } - - private: - template - static std::unique_ptr> dispatch_kernel( - raft::resources const& res, search_plan_impl_base& plan) - { - if (plan.algo == search_algo::SINGLE_CTA) { - return std::unique_ptr>( - new single_cta_search:: - search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); - } else if (plan.algo == search_algo::MULTI_CTA) { - return std::unique_ptr>( - new multi_cta_search:: - search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); - } else { - return std::unique_ptr>( - new multi_kernel_search:: - search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); - } - } -}; -}; // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/fragment.hpp b/cpp/include/cuvs/neighbors/detail/cagra/fragment.hpp deleted file mode 100644 index 256e46627..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/fragment.hpp +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Copyright (c) 
2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "device_common.hpp" -#include "utils.hpp" -#include -#include - -namespace cuvs::neighbors::cagra::detail { -namespace device { - -namespace detail { -template -struct load_unit_t { - using type = uint4; -}; -template <> -struct load_unit_t<8> { - using type = std::uint64_t; -}; -template <> -struct load_unit_t<4> { - using type = std::uint32_t; -}; -template <> -struct load_unit_t<2> { - using type = std::uint16_t; -}; -template <> -struct load_unit_t<1> { - using type = std::uint8_t; -}; -} // namespace detail - -// One dataset or query vector is distributed within a warp and stored as `fragment`. 
-template -struct fragment_base {}; -template -struct fragment - : fragment_base()) == 0>::type> { - static constexpr unsigned num_elements = DIM / TEAM_SIZE; - using block_t = typename detail::load_unit_t()>::type; - static constexpr unsigned num_load_blocks = - num_elements * utils::size_of() / utils::size_of(); - - union { - T x[num_elements]; - block_t load_block[num_load_blocks]; - }; -}; - -// Load a vector from device/shared memory -template -_RAFT_DEVICE void load_vector_sync(device::fragment& frag, - const INPUT_T* const input_vector_ptr, - const unsigned input_vector_length, - const bool sync = true) -{ - const auto lane_id = threadIdx.x % TEAM_SIZE; - if (DIM == input_vector_length) { - for (unsigned i = 0; i < frag.num_load_blocks; i++) { - const auto vector_index = i * TEAM_SIZE + lane_id; - frag.load_block[i] = - reinterpret_cast::block_t*>( - input_vector_ptr)[vector_index]; - } - } else { - for (unsigned i = 0; i < frag.num_elements; i++) { - const auto vector_index = i * TEAM_SIZE + lane_id; - - INPUT_T v; - if (vector_index < input_vector_length) { - v = static_cast(input_vector_ptr[vector_index]); - } else { - v = static_cast(0); - } - - frag.x[i] = v; - } - } - if (sync) { __syncwarp(); } -} - -// Compute the square of the L2 norm of two vectors -template -_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, - const device::fragment& b) -{ - COMPUTE_T sum = 0; - - // Compute the thread-local norm2 - for (unsigned i = 0; i < a.num_elements; i++) { - const auto diff = static_cast(a.x[i]) - static_cast(b.x[i]); - sum += diff * diff; - } - - // Compute the result norm2 summing up the thread-local norm2s. 
- for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, - const device::fragment& b, - const float scale) -{ - COMPUTE_T sum = 0; - - // Compute the thread-local norm2 - for (unsigned i = 0; i < a.num_elements; i++) { - const auto diff = - static_cast((static_cast(a.x[i]) - static_cast(b.x[i])) * scale); - sum += diff * diff; - } - - // Compute the result norm2 summing up the thread-local norm2s. - for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, - const T* b, // [DIM] - const float scale) -{ - COMPUTE_T sum = 0; - - // Compute the thread-local norm2 - const unsigned chunk_size = a.num_elements / a.num_load_blocks; - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - for (unsigned i = 0; i < a.num_elements; i++) { - unsigned j = (i % chunk_size) + chunk_size * (lane_id + TEAM_SIZE * (i / chunk_size)); - const auto diff = static_cast(a.x[i] * scale) - static_cast(b[j] * scale); - sum += diff * diff; - } - - // Compute the result norm2 summing up the thread-local norm2s. 
- for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE inline COMPUTE_T norm2x(const device::fragment& a, - const COMPUTE_T* b, // [dim] - const uint32_t dim, - const float scale) -{ - // Compute the thread-local norm2 - COMPUTE_T sum = 0; - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - if (dim == DIM) { - const unsigned chunk_size = a.num_elements / a.num_load_blocks; - for (unsigned i = 0; i < a.num_elements; i++) { - unsigned j = (i % chunk_size) + chunk_size * (lane_id + TEAM_SIZE * (i / chunk_size)); - const auto diff = static_cast(a.x[i] * scale) - b[j]; - sum += diff * diff; - } - } else { - for (unsigned i = 0; i < a.num_elements; i++) { - unsigned j = lane_id + (TEAM_SIZE * i); - if (j >= dim) break; - const auto diff = static_cast(a.x[i] * scale) - b[j]; - sum += diff * diff; - } - } - - // Compute the result norm2 summing up the thread-local norm2s. - for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE void print_fragment(const device::fragment& a) -{ - for (unsigned i = 0; i < TEAM_SIZE; i++) { - if ((threadIdx.x % TEAM_SIZE) == i) { - for (unsigned j = 0; j < a.num_elements; j++) { - RAFT_LOG_DEBUG("%+e ", static_cast(a.x[j])); - } - } - __syncwarp(); - } -} - -} // namespace device -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/graph_core.cuh b/cpp/include/cuvs/neighbors/detail/cagra/graph_core.cuh deleted file mode 100644 index 9734aa0e2..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/graph_core.cuh +++ /dev/null @@ -1,575 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "utils.hpp" - -namespace cuvs::neighbors::cagra::detail { -namespace graph { - -// unnamed namespace to avoid multiple definition error -namespace { -inline double cur_time(void) -{ - struct timeval tv; - gettimeofday(&tv, NULL); - return ((double)tv.tv_sec + (double)tv.tv_usec * 1e-6); -} - -template -__device__ inline void swap(T& val1, T& val2) -{ - T val0 = val1; - val1 = val2; - val2 = val0; -} - -template -__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool ascending) -{ - if (key1 == key2) { return false; } - if ((key1 > key2) == ascending) { - swap(key1, key2); - swap(val1, val2); - return true; - } - return false; -} - -template -RAFT_KERNEL kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim] - const IdxT dataset_size, - const uint32_t dataset_dim, - IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree) -{ - const IdxT srcNode = (blockDim.x * blockIdx.x + threadIdx.x) / raft::WarpSize; - if (srcNode >= graph_size) { return; } - - const uint32_t lane_id = threadIdx.x % raft::WarpSize; - - float my_keys[numElementsPerThread]; - IdxT my_vals[numElementsPerThread]; - - // Compute distance from a src node to its neighbors - for (int k = 0; k < graph_degree; k++) { - const IdxT dstNode = knn_graph[k + static_cast(graph_degree) * 
srcNode]; - float dist = 0.0; - for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) { - float diff = spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_dim) * srcNode]) - - spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_dim) * dstNode]); - dist += diff * diff; - } - dist += __shfl_xor_sync(0xffffffff, dist, 1); - dist += __shfl_xor_sync(0xffffffff, dist, 2); - dist += __shfl_xor_sync(0xffffffff, dist, 4); - dist += __shfl_xor_sync(0xffffffff, dist, 8); - dist += __shfl_xor_sync(0xffffffff, dist, 16); - if (lane_id == (k % raft::WarpSize)) { - my_keys[k / raft::WarpSize] = dist; - my_vals[k / raft::WarpSize] = dstNode; - } - } - for (int k = graph_degree; k < raft::WarpSize * numElementsPerThread; k++) { - if (lane_id == k % raft::WarpSize) { - my_keys[k / raft::WarpSize] = utils::get_max_value(); - my_vals[k / raft::WarpSize] = utils::get_max_value(); - } - } - - // Sort by RAFT bitonic sort - raft::util::bitonic(true).sort(my_keys, my_vals); - - // Update knn_graph - for (int i = 0; i < numElementsPerThread; i++) { - const int k = i * raft::WarpSize + lane_id; - if (k < graph_degree) { - knn_graph[k + (static_cast(graph_degree) * srcNode)] = my_vals[i]; - } - } -} - -template -RAFT_KERNEL kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id, - uint8_t* const detour_count, // [graph_chunk_size, graph_degree] - uint32_t* const num_no_detour_edges, // [graph_size] - uint64_t* const stats) -{ - __shared__ uint32_t smem_num_detour[MAX_DEGREE]; - uint64_t* const num_retain = stats; - uint64_t* const num_full = stats + 1; - - const uint64_t nid = blockIdx.x + (batch_size * batch_id); - if (nid >= graph_size) { return; } - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - smem_num_detour[k] = 0; - } - __syncthreads(); - - const 
uint64_t iA = nid; - if (iA >= graph_size) { return; } - - // count number of detours (A->D->B) - for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; - for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; - for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { - // if ( kDB < kAB ) - { - const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; - if (iB == iB_candidate) { - atomicAdd(smem_num_detour + kAB, 1); - break; - } - } - } - } - __syncthreads(); - } - - uint32_t num_edges_no_detour = 0; - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); - if (smem_num_detour[k] == 0) { num_edges_no_detour++; } - } - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); - num_edges_no_detour = min(num_edges_no_detour, degree); - - if (threadIdx.x == 0) { - num_no_detour_edges[iA] = num_edges_no_detour; - atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); - if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } - } -} - -template -RAFT_KERNEL kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree) -{ - const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint32_t tnum = blockDim.x * gridDim.x; - - for (uint32_t src_id = tid; src_id < 
graph_size; src_id += tnum) { - const IdxT dest_id = dest_nodes[src_id]; - if (dest_id >= graph_size) continue; - - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } - } -} - -template -uint64_t pos_in_array(T val, const T* array, uint64_t num) -{ - for (uint64_t i = 0; i < num; i++) { - if (val == array[i]) { return i; } - } - return num; -} - -template -void shift_array(T* array, uint64_t num) -{ - for (uint64_t i = num; i > 0; i--) { - array[i] = array[i - 1]; - } -} -} // namespace - -template , memory_type::device>, - typename g_accessor = - host_device_accessor, memory_type::host>> -void sort_knn_graph( - raft::resources const& res, - raft::mdspan, raft::row_major, d_accessor> dataset, - raft::mdspan, raft::row_major, g_accessor> knn_graph) -{ - RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), - "dataset size is expected to have the same number of graph index size"); - const uint32_t dataset_size = dataset.extent(0); - const uint32_t dataset_dim = dataset.extent(1); - const DataT* dataset_ptr = dataset.data_handle(); - - const IdxT graph_size = dataset_size; - const uint32_t input_graph_degree = knn_graph.extent(1); - IdxT* const input_graph_ptr = knn_graph.data_handle(); - - auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); - - // - // Sorting kNN graph - // - const double time_sort_start = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); - - auto d_dataset = raft::make_device_matrix(res, dataset_size, dataset_dim); - raft::copy(d_dataset.data_handle(), - dataset_ptr, - dataset_size * dataset_dim, - resource::get_cuda_stream(res)); - - raft::copy(d_input_graph.data_handle(), - input_graph_ptr, - graph_size * input_graph_degree, - resource::get_cuda_stream(res)); - - void (*kernel_sort)( - const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); - if (input_graph_degree <= 32) { - 
constexpr int numElementsPerThread = 1; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 64) { - constexpr int numElementsPerThread = 2; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 128) { - constexpr int numElementsPerThread = 4; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 256) { - constexpr int numElementsPerThread = 8; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 512) { - constexpr int numElementsPerThread = 16; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 1024) { - constexpr int numElementsPerThread = 32; - kernel_sort = kern_sort; - } else { - RAFT_FAIL( - "The degree of input knn graph is too large (%u). " - "It must be equal to or smaller than %d.", - input_graph_degree, - 1024); - } - const auto block_size = 256; - const auto num_warps_per_block = block_size / raft::WarpSize; - const auto grid_size = (graph_size + num_warps_per_block - 1) / num_warps_per_block; - - RAFT_LOG_DEBUG("."); - kernel_sort<<>>( - d_dataset.data_handle(), - dataset_size, - dataset_dim, - d_input_graph.data_handle(), - graph_size, - input_graph_degree); - resource::sync_stream(res); - RAFT_LOG_DEBUG("."); - raft::copy(input_graph_ptr, - d_input_graph.data_handle(), - graph_size * input_graph_degree, - resource::get_cuda_stream(res)); - RAFT_LOG_DEBUG("\n"); - - const double time_sort_end = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start); -} - -template , memory_type::host>> -void optimize( - raft::resources const& res, - raft::mdspan, raft::row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph) -{ - RAFT_LOG_DEBUG( - "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); - - RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), - "Each input array is expected to have the same number of rows"); - RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1), - "output graph cannot have more columns 
than input graph"); - const uint32_t input_graph_degree = knn_graph.extent(1); - const uint32_t output_graph_degree = new_graph.extent(1); - auto input_graph_ptr = knn_graph.data_handle(); - auto output_graph_ptr = new_graph.data_handle(); - const IdxT graph_size = new_graph.extent(0); - - { - // - // Prune kNN graph - // - auto d_detour_count = - raft::make_device_matrix(res, graph_size, input_graph_degree); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - graph_size * input_graph_degree * sizeof(uint8_t), - resource::get_cuda_stream(res))); - - auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - resource::get_cuda_stream(res))); - - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - - // - // Prune unimportant edges. - // - // The edge to be retained is determined without explicitly considering - // distance or angle. Suppose the edge is the k-th edge of some node-A to - // node-B (A->B). Among the edges originating at node-A, there are k-1 edges - // shorter than the edge A->B. Each of these k-1 edges are connected to a - // different k-1 nodes. Among these k-1 nodes, count the number of nodes with - // edges to node-B, which is the number of 2-hop detours for the edge A->B. - // Once the number of 2-hop detours has been counted for all edges, the - // specified number of edges are picked up for each node, starting with the - // edge with the lowest number of 2-hop detours. 
- // - const double time_prune_start = cur_time(); - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - - // Copy input_graph_ptr over to device if necessary - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view(input_graph_ptr, graph_size, input_graph_degree)); - - constexpr int MAX_DEGREE = 1024; - if (input_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%u). " - "It must be equal to or smaller than %d.", - input_graph_degree, - 1024); - } - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); - - RAFT_CUDA_TRY(cudaMemsetAsync( - dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, resource::get_cuda_stream(res))); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - input_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - host_matrix_view_from_device detour_count(res, d_detour_count.view()); - - raft::copy( - host_stats.data_handle(), dev_stats.data_handle(), 2, resource::get_cuda_stream(res)); - const auto num_keep = host_stats.data_handle()[0]; - const auto num_full = host_stats.data_handle()[1]; - - // Create pruned kNN graph - uint32_t max_detour = 0; -#pragma omp parallel for reduction(max : max_detour) - for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - for (uint32_t num_detour = 0; num_detour < output_graph_degree; num_detour++) { - if (max_detour < num_detour) { 
max_detour = num_detour; /* stats */ } - for (uint64_t k = 0; k < input_graph_degree; k++) { - if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; } - output_graph_ptr[pk + (output_graph_degree * i)] = - input_graph_ptr[k + (input_graph_degree * i)]; - pk += 1; - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - } - assert(pk == output_graph_degree); - } - // RAFT_LOG_DEBUG("# max_detour: %u\n", max_detour); - - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG( - "# Pruning time: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%\n", - time_prune_end - time_prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } - - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); - - { - // - // Make reverse graph - // - const double time_make_start = cur_time(); - - device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), - 0xff, - graph_size * output_graph_degree * sizeof(IdxT), - resource::get_cuda_stream(res))); - - auto d_rev_graph_count = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - resource::get_cuda_stream(res))); - - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = raft::make_device_vector(res, graph_size); - - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - } - resource::sync_stream(res); - - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - resource::get_cuda_stream(res)); - - dim3 threads(256, 1, 1); - dim3 
blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - - resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - if (d_rev_graph.allocated_memory()) { - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, - resource::get_cuda_stream(res)); - } - raft::copy(rev_graph_count.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - resource::get_cuda_stream(res)); - - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); - } - - { - // - // Replace some edges with reverse edges - // - const double time_replace_start = cur_time(); - - const uint64_t num_protected_edges = output_graph_degree / 2; - RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges); - - constexpr int _omp_chunk = 1024; -#pragma omp parallel for schedule(dynamic, _omp_chunk) - for (uint64_t j = 0; j < graph_size; j++) { - uint64_t k = std::min(rev_graph_count.data_handle()[j], output_graph_degree); - while (k) { - k--; - uint64_t i = rev_graph.data_handle()[k + (output_graph_degree * j)]; - - uint64_t pos = - pos_in_array(i, output_graph_ptr + (output_graph_degree * j), output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos == output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j), - num_shift); - output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i; - } - if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) { - RAFT_LOG_DEBUG("# Replacing reverse edges: %lu / %lu ", j, graph_size); - } - } - RAFT_LOG_DEBUG("\n"); - - const double 
time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start); - - /* stats */ - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } - } - RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); - } -} - -} // namespace graph -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/hashmap.hpp b/cpp/include/cuvs/neighbors/detail/cagra/hashmap.hpp deleted file mode 100644 index 2ac7438a9..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/hashmap.hpp +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "utils.hpp" -#include -#include -#include - -// #pragma GCC diagnostic push -// #pragma GCC diagnostic ignored -// #pragma GCC diagnostic pop -namespace cuvs::neighbors::cagra::detail { -namespace hashmap { - -_RAFT_HOST_DEVICE inline uint32_t get_size(const uint32_t bitlen) { return 1U << bitlen; } - -template -_RAFT_DEVICE inline void init(IdxT* const table, const unsigned bitlen, unsigned FIRST_TID = 0) -{ - if (threadIdx.x < FIRST_TID) return; - for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { - table[i] = utils::get_max_value(); - } -} - -template -_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) -{ - // Open addressing is used for collision resolution - const uint32_t size = get_size(bitlen); - const uint32_t bit_mask = size - 1; -#if 1 - // Linear probing - IdxT index = (key ^ (key >> bitlen)) & bit_mask; - constexpr uint32_t stride = 1; -#else - // Double hashing - uint32_t index = key & bit_mask; - const uint32_t stride = (key >> bitlen) * 2 + 1; -#endif - for (unsigned i = 0; i < size; i++) { - const IdxT old = atomicCAS(&table[index], ~static_cast(0), key); - if (old == ~static_cast(0)) { - return 1; - } else if (old == key) { - return 0; - } - index = (index + stride) & bit_mask; - } - return 0; -} - -template -_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) -{ - IdxT ret = 0; - if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } - for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { - ret |= __shfl_xor_sync(0xffffffff, ret, offset); - } - return ret; -} - -} // namespace hashmap -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta.cuh deleted file mode 100644 index 2cb11e343..000000000 --- 
a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta.cuh +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_multi_cta_kernel.cuh" -#include "search_plan.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible -#include "utils.hpp" -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -namespace cuvs::neighbors::cagra::detail { -namespace multi_cta_search { - -template - -struct search : public search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::max_dim; - using search_plan_impl::dim; - 
using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; - - uint32_t num_cta_per_query; - rmm::device_uvector intermediate_indices; - rmm::device_uvector intermediate_distances; - size_t topk_workspace_size; - rmm::device_uvector topk_workspace; - - search(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) - : search_plan_impl( - res, params, dim, graph_degree, topk), - intermediate_indices(0, resource::get_cuda_stream(res)), - intermediate_distances(0, resource::get_cuda_stream(res)), - topk_workspace(0, resource::get_cuda_stream(res)) - - { - set_params(res, params); - } - - void set_params(raft::resources const& res, const search_params& params) - { - constexpr unsigned muti_cta_itopk_size = 32; - this->itopk_size = muti_cta_itopk_size; - search_width = 1; - num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size); - result_buffer_size = itopk_size + search_width * graph_degree; - typedef raft::Pow2<32> AlignBytes; - unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); - // constexpr unsigned max_result_buffer_size = 256; - RAFT_EXPECTS(result_buffer_size_32 <= 256, "Result buffer size cannot exceed 256"); - - smem_size = sizeof(float) * max_dim + - (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(uint32_t) * search_width + sizeof(uint32_t); - RAFT_LOG_DEBUG("# smem_size: %u", smem_size); - - // - // Determine the thread block size - // - constexpr unsigned 
min_block_size = 64; - constexpr unsigned max_block_size = 1024; - uint32_t block_size = thread_block_size; - if (block_size == 0) { - block_size = min_block_size; - - // Increase block size according to shared memory requirements. - // If block size is 32, upper limit of shared memory size per - // thread block is set to 4096. This is GPU generation dependent. - constexpr unsigned ulimit_smem_size_cta32 = 4096; - while (smem_size > ulimit_smem_size_cta32 / 32 * block_size) { - block_size *= 2; - } - - // Increase block size to improve GPU occupancy when total number of - // CTAs (= num_cta_per_query * max_queries) is small. - cudaDeviceProp deviceProp = resource::get_device_properties(res); - RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); - while ((block_size < max_block_size) && - (graph_degree * search_width * team_size >= block_size * 2) && - (num_cta_per_query * max_queries <= - (1024 / (block_size * 2)) * deviceProp.multiProcessorCount)) { - block_size *= 2; - } - } - RAFT_LOG_DEBUG("# thread_block_size: %u", block_size); - RAFT_EXPECTS(block_size >= min_block_size, - "block_size cannot be smaller than min_block size, %u", - min_block_size); - RAFT_EXPECTS(block_size <= max_block_size, - "block_size cannot be larger than max_block size %u", - max_block_size); - thread_block_size = block_size; - - // - // Allocate memory for intermediate buffer and workspace. 
- // - uint32_t num_intermediate_results = num_cta_per_query * itopk_size; - intermediate_indices.resize(num_intermediate_results * max_queries, - resource::get_cuda_stream(res)); - intermediate_distances.resize(num_intermediate_results * max_queries, - resource::get_cuda_stream(res)); - - hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); - - topk_workspace_size = _cuann_find_topk_bufferSize( - topk, max_queries, num_intermediate_results, utils::get_cuda_data_type()); - RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); - topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); - } - - void check(const uint32_t topk) override - { - RAFT_EXPECTS(num_cta_per_query * 32 >= topk, - "`num_cta_per_query` (%u) * 32 must be equal to or greater than " - "`topk` (%u) when 'search_mode' is \"multi-cta\". " - "(`num_cta_per_query`=max(`search_width`, `itopk_size`/32))", - num_cta_per_query, - topk); - } - - ~search() {} - - void operator()(raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - SAMPLE_FILTER_T sample_filter) - { - cudaStream_t stream = resource::get_cuda_stream(res); - - select_and_run( - dataset, - graph, - intermediate_indices.data(), - intermediate_distances.data(), - queries_ptr, - num_queries, - dev_seed_ptr, - num_executed_iterations, - topk, - thread_block_size, - result_buffer_size, - smem_size, - hash_bitlen, - hashmap.data(), - num_cta_per_query, - num_random_samplings, - rand_xor_mask, - num_seeds, - itopk_size, - search_width, - min_iterations, - max_iterations, - sample_filter, - stream); - 
RAFT_CUDA_TRY(cudaPeekAtLastError()); - - // Select the top-k results from the intermediate results - const uint32_t num_intermediate_results = num_cta_per_query * itopk_size; - _cuann_find_topk(topk, - num_queries, - num_intermediate_results, - intermediate_distances.data(), - num_intermediate_results, - intermediate_indices.data(), - num_intermediate_results, - topk_distances_ptr, - topk, - topk_indices_ptr, - topk, - topk_workspace.data(), - true, - NULL, - stream); - } -}; - -} // namespace multi_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh deleted file mode 100644 index 27e07ae5a..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include // none_cagra_sample_filter -#include // RAFT_EXPLICIT - -namespace cuvs::neighbors::cagra::detail { -namespace multi_cta_search { - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -template -void select_and_run(raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, - DISTANCE_T* const topk_distances_ptr, - const DATA_T* const queries_ptr, - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, - uint32_t* const num_executed_iterations, - uint32_t topk, - uint32_t block_size, - uint32_t result_buffer_size, - uint32_t smem_size, - int64_t hash_bitlen, - INDEX_T* hashmap_ptr, - uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - cudaStream_t stream) RAFT_EXPLICIT; -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); - -instantiate_kernel_selection( - 32, 1024, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); 
-instantiate_kernel_selection( - 8, 128, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection -} // namespace multi_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh deleted file mode 100644 index 60dc34d47..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ /dev/null @@ -1,530 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible -#include "utils.hpp" -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -namespace cuvs::neighbors::cagra::detail { -namespace multi_cta_search { - -// #define _CLK_BREAKDOWN - -template -__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [search_width] - const uint32_t search_width, - INDEX_T* const itopk_indices, // [num_itopk] - const size_t num_itopk, - uint32_t* const terminate_flag) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const unsigned warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; - for (uint32_t i = lane_id; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - uint32_t max_itopk = num_itopk; - if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } - uint32_t num_new_parents = 0; - for (uint32_t j = lane_id; j < max_itopk; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < num_itopk) { - index = itopk_indices[j]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } - } - const uint32_t 
ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = j; - itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } - } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } -} - -template -__device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] - INDEX_T* indices, // [num_elements] - const uint32_t num_elements, - const uint32_t num_itopk // num_itopk <= num_elements -) -{ - const unsigned warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; - constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; - float key[N]; - INDEX_T val[N]; - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_elements) { - key[i] = distances[j]; - val[i] = indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store itopk sorted results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - distances[j] = key[i]; - indices[j] = val[i]; - } - } -} - -// -// multiple CTAs per single query -// -template -__launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( - INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] - DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const size_t dataset_dim, - const size_t dataset_size, - const size_t dataset_ld, - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const uint32_t graph_degree, - const unsigned 
num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const uint32_t hash_bitlen, - const uint32_t itopk_size, - const uint32_t search_width, - const uint32_t min_iteration, - const uint32_t max_iteration, - uint32_t* const num_executed_iterations, /* stats */ - SAMPLE_FILTER_T sample_filter) -{ - assert(dataset_dim <= MAX_DATASET_DIM); - - const auto num_queries = gridDim.y; - const auto query_id = blockIdx.y; - const auto num_cta_per_query = gridDim.x; - const auto cta_id = blockIdx.x; // local CTA ID - -#ifdef _CLK_BREAKDOWN - uint64_t clk_init = 0; - uint64_t clk_compute_1st_distance = 0; - uint64_t clk_topk = 0; - uint64_t clk_pickup_parents = 0; - uint64_t clk_compute_distance = 0; - uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ uint32_t smem[]; - - // Layout of result_buffer - // +----------------+------------------------------+---------+ - // | internal_top_k | neighbors of parent nodes | padding | - // | | | upto 32 | - // +----------------+------------------------------+---------+ - // |<--- result_buffer_size --->| - uint32_t result_buffer_size = itopk_size + (search_width * graph_degree); - uint32_t result_buffer_size_32 = result_buffer_size; - if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } - assert(result_buffer_size_32 <= MAX_ELEMENTS); - - auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); - auto result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto parent_indices_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto terminate_flag = 
reinterpret_cast(parent_indices_buffer + search_width); - -#if 0 - /* debug */ - for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += BLOCK_SIZE) { - result_indices_buffer[i] = utils::get_max_value(); - result_distances_buffer[i] = utils::get_max_value(); - } -#endif - const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id); - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dataset_dim) { - query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); - } else { - query_buffer[j] = 0.0; - } - } - if (threadIdx.x == 0) { terminate_flag[0] = 0; } - INDEX_T* const local_visited_hashmap_ptr = - visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - uint32_t block_id = cta_id + (num_cta_per_query * query_id); - uint32_t num_blocks = num_cta_per_query * num_queries; - device::compute_distance_to_random_nodes( - result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen, - block_id, - num_blocks); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - uint32_t iter = 0; - while (1) { - // topk with bitonic sort - _CLK_START(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); - _CLK_REC(clk_topk); - - if (iter + 1 == max_iteration) { - __syncthreads(); - break; - } - - // pick up next parents - _CLK_START(); - pickup_next_parents( - parent_indices_buffer, search_width, result_indices_buffer, itopk_size, terminate_flag); - _CLK_REC(clk_pickup_parents); - - 
__syncthreads(); - if (*terminate_flag && iter >= min_iteration) { break; } - - // compute the norms between child nodes and query node - _CLK_START(); - // constexpr unsigned max_n_frags = 16; - constexpr unsigned max_n_frags = 0; - device::compute_distance_to_child_nodes( - result_indices_buffer + itopk_size, - result_distances_buffer + itopk_size, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_indices_buffer, - result_indices_buffer, - search_width); - _CLK_REC(clk_compute_distance); - __syncthreads(); - - // Filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { - if (parent_indices_buffer[p] != invalid_index) { - const auto parent_id = - result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; - if (!sample_filter(query_id, parent_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); - result_indices_buffer[parent_indices_buffer[p]] = invalid_index; - } - } - } - __syncthreads(); - } - - iter++; - } - - // Post process for filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < itopk_size + search_width * graph_degree; i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[i] = utils::get_max_value(); - result_indices_buffer[i] = invalid_index; - } 
- } - - __syncthreads(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); - __syncthreads(); - } - - for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) { - uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit - } - - if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } - -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && (blockIdx.x == 0) && - ((query_id * 3) % gridDim.y < 3)) { - RAFT_LOG_DEBUG( - "query, %d, thread, %d" - ", init, %d" - ", 1st_distance, %lu" - ", topk, %lu" - ", pickup_parents, %lu" - ", distance, %lu" - "\n", - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_pickup_parents, - clk_compute_distance); - } -#endif -} - -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - -template -struct search_kernel_config { - // 
Search kernel function type. Note that the actual values for the template value - // parameters do not matter, because they are not part of the function signature. The - // second to fourth value parameters will be selected by the choose_* functions below. - using kernel_t = decltype(&search_kernel); - - static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t - { - if (result_buffer_size <= 64) { - return search_kernel; - } else if (result_buffer_size <= 128) { - return search_kernel; - } else if (result_buffer_size <= 256) { - return search_kernel; - } - THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); - } -}; - -template -void select_and_run( // raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - // multi_cta_search (params struct) - uint32_t block_size, // - uint32_t result_buffer_size, - uint32_t smem_size, - int64_t hash_bitlen, - INDEX_T* hashmap_ptr, - uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - cudaStream_t stream) -{ - auto kernel = - search_kernel_config:: - choose_buffer_size(result_buffer_size, block_size); - - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // Initialize hash table - const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch( - hashmap_ptr, hash_size, utils::get_max_value(), hash_size, num_queries, 
stream); - - dim3 block_dims(block_size, 1, 1); - dim3 grid_dims(num_cta_per_query, num_queries, 1); - RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %u smem", - block_size, - num_cta_per_query, - num_queries, - smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), - queries_ptr, - graph.data_handle(), - graph.extent(1), - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - hash_bitlen, - itopk_size, - search_width, - min_iterations, - max_iterations, - num_executed_iterations, - sample_filter); -} - -} // namespace multi_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel.cuh deleted file mode 100644 index e00390729..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "search_multi_cta_kernel-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "search_multi_cta_kernel-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_multi_kernel.cuh deleted file mode 100644 index 622a6a825..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_multi_kernel.cuh +++ /dev/null @@ -1,862 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "fragment.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel -#include "utils.hpp" -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -namespace cuvs::neighbors::cagra::detail { -namespace multi_kernel_search { - -template -RAFT_KERNEL set_value_kernel(T* const dev_ptr, const T val) -{ - *dev_ptr = val; -} - -template -RAFT_KERNEL set_value_kernel(T* const dev_ptr, const T val, const std::size_t count) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count) { return; } - dev_ptr[tid] = val; -} - -template -void set_value(T* const dev_ptr, const T val, cudaStream_t cuda_stream) -{ - set_value_kernel<<<1, 1, 0, cuda_stream>>>(dev_ptr, val); -} - -template -void set_value(T* const dev_ptr, const T val, const std::size_t count, cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count + block_size - 1) / block_size; - set_value_kernel<<>>(dev_ptr, val, count); -} - -template -RAFT_KERNEL get_value_kernel(T* const host_ptr, const T* const dev_ptr) -{ - *host_ptr = *dev_ptr; -} - -template -void get_value(T* const host_ptr, const T* const dev_ptr, cudaStream_t cuda_stream) -{ - get_value_kernel<<<1, 1, 0, cuda_stream>>>(host_ptr, dev_ptr); -} - -// MAX_DATASET_DIM : must equal to or greater than dataset_dim -template -RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::size_t num_pickup, - const unsigned 
num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::uint32_t ldr, // (*) ldr >= num_pickup - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen) -{ - const auto ldb = hashmap::get_size(hash_bitlen); - const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) / TEAM_SIZE; - const uint32_t query_id = blockIdx.y; - if (global_team_index >= num_pickup) { return; } - // Load a query - device::fragment query_frag; - device::load_vector_sync(query_frag, queries_ptr + query_id * dataset_dim, dataset_dim); - - INDEX_T best_index_team_local; - DISTANCE_T best_norm2_team_local = utils::get_max_value(); - for (unsigned i = 0; i < num_distilation; i++) { - INDEX_T seed_index; - if (seed_ptr && (global_team_index < num_seeds)) { - seed_index = seed_ptr[global_team_index + (num_seeds * query_id)]; - } else { - // Chose a seed node randomly - seed_index = device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_size; - } - device::fragment random_data_frag; - device::load_vector_sync( - random_data_frag, dataset_ptr + (dataset_ld * seed_index), dataset_dim); - - // Compute the norm of two data - const auto norm2 = device::norm2( - query_frag, - random_data_frag, - static_cast(1.0 / spatial::knn::detail::utils::config::kDivisor) - /*, scale*/ - ); - - if (norm2 < best_norm2_team_local) { - best_norm2_team_local = norm2; - best_index_team_local = seed_index; - } - } - - const auto store_gmem_index = global_team_index + (ldr * query_id); - if (threadIdx.x % TEAM_SIZE == 0) { - if (hashmap::insert( - visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) { - result_distances_ptr[store_gmem_index] = best_norm2_team_local; - result_indices_ptr[store_gmem_index] = 
best_index_team_local; - } else { - result_distances_ptr[store_gmem_index] = utils::get_max_value(); - result_indices_ptr[store_gmem_index] = utils::get_max_value(); - } - } -} - -// MAX_DATASET_DIM : must be equal to or greater than dataset_dim -template -void random_pickup(const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::size_t num_queries, - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::size_t ldr, // (*) ldr >= num_pickup - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen, - cudaStream_t const cuda_stream = 0) -{ - const auto block_size = 256u; - const auto num_teams_per_threadblock = block_size / TEAM_SIZE; - const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, - num_queries); - - random_pickup_kernel - <<>>(dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - queries_ptr, - num_pickup, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - result_indices_ptr, - result_distances_ptr, - ldr, - visited_hashmap_ptr, - hash_bitlen); -} - -template -RAFT_KERNEL pickup_next_parents_kernel( - INDEX_T* const parent_candidates_ptr, // [num_queries, raft::lds] - const std::size_t raft::lds, // (*) raft::lds >= parent_candidates_size - const std::uint32_t parent_candidates_size, // - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::size_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - INDEX_T* const parent_list_ptr, // [num_queries, ldd] - const std::size_t ldd, // (*) 
ldd >= parent_list_size - const std::size_t parent_list_size, // - std::uint32_t* const terminate_flag) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - const std::size_t ldb = hashmap::get_size(hash_bitlen); - const uint32_t query_id = blockIdx.x; - if (threadIdx.x < 32) { - // pickup next parents with single warp - for (std::uint32_t i = threadIdx.x; i < parent_list_size; i += 32) { - parent_list_ptr[i + (ldd * query_id)] = utils::get_max_value(); - } - std::uint32_t parent_candidates_size_max = parent_candidates_size; - if (parent_candidates_size % 32) { - parent_candidates_size_max += 32 - (parent_candidates_size % 32); - } - std::uint32_t num_new_parents = 0; - for (std::uint32_t j = threadIdx.x; j < parent_candidates_size_max; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < parent_candidates_size) { - index = parent_candidates_ptr[j + (lds * query_id)]; - if ((index & index_msb_1_mask) == 0) { // check most significant bit - new_parent = 1; - } - } - const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; - if (i < parent_list_size) { - parent_list_ptr[i + (ldd * query_id)] = j; - parent_candidates_ptr[j + (lds * query_id)] |= - index_msb_1_mask; // set most significant bit as used node - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= parent_list_size) { break; } - } - if ((num_new_parents > 0) && (threadIdx.x == 0)) { *terminate_flag = 0; } - } else if (small_hash_bitlen) { - // reset small-hash - hashmap::init(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, 32); - } - - if (small_hash_bitlen) { - __syncthreads(); - // insert internal-topk indices into small-hash - for (unsigned i = threadIdx.x; i < parent_candidates_size; i += blockDim.x) { - auto key = parent_candidates_ptr[i + (lds * query_id)] & - ~index_msb_1_mask; // clear most significant bit - 
hashmap::insert(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, key); - } - } -} - -template -void pickup_next_parents(INDEX_T* const parent_candidates_ptr, // [num_queries, raft::lds] - const std::size_t raft::lds, // (*) raft::lds >= parent_candidates_size - const std::size_t parent_candidates_size, // - const std::size_t num_queries, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::size_t hash_bitlen, - const std::size_t small_hash_bitlen, - INDEX_T* const parent_list_ptr, // [num_queries, ldd] - const std::size_t ldd, // (*) ldd >= parent_list_size - const std::size_t parent_list_size, // - std::uint32_t* const terminate_flag, - cudaStream_t cuda_stream = 0) -{ - std::uint32_t block_size = 32; - if (small_hash_bitlen) { - block_size = 128; - while (parent_candidates_size > block_size) { - block_size *= 2; - } - block_size = min(block_size, (uint32_t)512); - } - pickup_next_parents_kernel - <<>>(parent_candidates_ptr, - raft::lds, - parent_candidates_size, - visited_hashmap_ptr, - hash_bitlen, - small_hash_bitlen, - parent_list_ptr, - ldd, - parent_list_size, - terminate_flag); -} - -template -RAFT_KERNEL compute_distance_to_child_nodes_kernel( - const INDEX_T* const parent_node_list, // [num_queries, search_width] - INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] - DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] - const std::size_t raft::lds, - const std::uint32_t search_width, - const DATA_T* const dataset_ptr, // [dataset_size, data_dim] - const std::uint32_t data_dim, - const std::uint32_t dataset_size, - const std::uint32_t dataset_ld, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - 
DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - SAMPLE_FILTER_T sample_filter) -{ - const uint32_t ldb = hashmap::get_size(hash_bitlen); - const auto tid = threadIdx.x + blockDim.x * blockIdx.x; - const auto global_team_id = tid / TEAM_SIZE; - const auto query_id = blockIdx.y; - - if (global_team_id >= search_width * graph_degree) { return; } - - const std::size_t parent_list_index = - parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; - - if (parent_list_index == utils::get_max_value()) { return; } - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto parent_index = - parent_candidates_ptr[parent_list_index + (lds * query_id)] & ~index_msb_1_mask; - - if (parent_index == utils::get_max_value()) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); - return; - } - const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index); - - const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; - - if (hashmap::insert( - visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id)) { - device::fragment frag_target; - device::load_vector_sync(frag_target, dataset_ptr + (dataset_ld * child_id), data_dim); - - device::fragment frag_query; - device::load_vector_sync(frag_query, query_ptr + blockIdx.y * data_dim, data_dim); - - const auto norm2 = device::norm2( - frag_target, - frag_query, - static_cast(1.0 / spatial::knn::detail::utils::config::kDivisor)); - - if (threadIdx.x % TEAM_SIZE == 0) { - result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; - result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; - } - } else { - if (threadIdx.x % TEAM_SIZE == 0) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); - } - } - - if constexpr (!std::is_same::value) { - if (!sample_filter(query_id, 
parent_index)) { - parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); - parent_distance_ptr[parent_list_index + (lds * query_id)] = - utils::get_max_value(); - } - } -} - -template -void compute_distance_to_child_nodes( - const INDEX_T* const parent_node_list, // [num_queries, search_width] - INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] - DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] - const std::size_t raft::lds, - const uint32_t search_width, - const DATA_T* const dataset_ptr, // [dataset_size, data_dim] - const std::uint32_t data_dim, - const std::uint32_t dataset_size, - const std::uint32_t dataset_ld, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] - const std::uint32_t num_queries, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - SAMPLE_FILTER_T sample_filter, - cudaStream_t cuda_stream = 0) -{ - const auto block_size = 128; - const dim3 grid_size( - (search_width * graph_degree + (block_size / TEAM_SIZE) - 1) / (block_size / TEAM_SIZE), - num_queries); - compute_distance_to_child_nodes_kernel - <<>>(parent_node_list, - parent_candidates_ptr, - parent_distance_ptr, - raft::lds, - search_width, - dataset_ptr, - data_dim, - dataset_size, - dataset_ld, - neighbor_graph_ptr, - graph_degree, - query_ptr, - visited_hashmap_ptr, - hash_bitlen, - result_indices_ptr, - result_distances_ptr, - ldd, - sample_filter); -} - -template -RAFT_KERNEL remove_parent_bit_kernel(const std::uint32_t num_queries, - const std::uint32_t num_topk, - INDEX_T* const topk_indices_ptr, // [ld, num_queries] - const std::uint32_t ld) -{ - 
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - uint32_t i_query = blockIdx.x; - if (i_query >= num_queries) return; - - for (unsigned i = threadIdx.x; i < num_topk; i += blockDim.x) { - topk_indices_ptr[i + (ld * i_query)] &= ~index_msb_1_mask; // clear most significant bit - } -} - -template -void remove_parent_bit(const std::uint32_t num_queries, - const std::uint32_t num_topk, - INDEX_T* const topk_indices_ptr, // [ld, num_queries] - const std::uint32_t ld, - cudaStream_t cuda_stream = 0) -{ - const std::size_t grid_size = num_queries; - const std::size_t block_size = 256; - remove_parent_bit_kernel<<>>( - num_queries, num_topk, topk_indices_ptr, ld); -} - -// This function called after the `remove_parent_bit` function -template -RAFT_KERNEL apply_filter_kernel(INDEX_T* const result_indices_ptr, - DISTANCE_T* const result_distances_ptr, - const std::size_t raft::lds, - const std::uint32_t result_buffer_size, - const std::uint32_t num_queries, - const INDEX_T query_id_offset, - SAMPLE_FILTER_T sample_filter) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= result_buffer_size * num_queries) { return; } - const auto i = tid % result_buffer_size; - const auto j = tid / result_buffer_size; - const auto index = i + j * raft::lds; - - if (result_indices_ptr[index] != ~index_msb_1_mask && - !sample_filter(query_id_offset + j, result_indices_ptr[index])) { - result_indices_ptr[index] = utils::get_max_value(); - result_distances_ptr[index] = utils::get_max_value(); - } -} - -template -void apply_filter(INDEX_T* const result_indices_ptr, - DISTANCE_T* const result_distances_ptr, - const std::size_t raft::lds, - const std::uint32_t result_buffer_size, - const std::uint32_t num_queries, - const INDEX_T query_id_offset, - SAMPLE_FILTER_T sample_filter, - cudaStream_t cuda_stream) -{ - const std::uint32_t block_size = 256; - const std::uint32_t 
grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); - - apply_filter_kernel<<>>(result_indices_ptr, - result_distances_ptr, - raft::lds, - result_buffer_size, - num_queries, - query_id_offset, - sample_filter); -} - -template -RAFT_KERNEL batched_memcpy_kernel(T* const dst, // [batch_size, ld_dst] - const uint64_t ld_dst, - const T* const src, // [batch_size, ld_src] - const uint64_t ld_src, - const uint64_t count, - const uint64_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto i = tid % count; - const auto j = tid / count; - dst[i + (ld_dst * j)] = src[i + (ld_src * j)]; -} - -template -void batched_memcpy(T* const dst, // [batch_size, ld_dst] - const uint64_t ld_dst, - const T* const src, // [batch_size, ld_src] - const uint64_t ld_src, - const uint64_t count, - const uint64_t batch_size, - cudaStream_t cuda_stream) -{ - assert(ld_dst >= count); - assert(ld_src >= count); - constexpr uint32_t block_size = 256; - const auto grid_size = (batch_size * count + block_size - 1) / block_size; - batched_memcpy_kernel - <<>>(dst, ld_dst, src, ld_src, count, batch_size); -} - -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - -// result_buffer (work buffer) for 
"multi-kernel" -// +--------------------+------------------------------+-------------------+ -// | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | -// | | | | -// +--------------------+------------------------------+-------------------+ -// |<--- result_buffer_allocation_size --->| -// |<--- result_buffer_size --->| // Double buffer (A) -// |<--- result_buffer_size --->| // Double buffer (B) -template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; - - size_t result_buffer_allocation_size; - rmm::device_uvector result_indices; // results_indices_buffer - rmm::device_uvector result_distances; // result_distances_buffer - rmm::device_uvector parent_node_list; - rmm::device_uvector topk_hint; - rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; - rmm::device_uvector topk_workspace; - - 
search(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) - : search_plan_impl( - res, params, dim, graph_degree, topk), - result_indices(0, resource::get_cuda_stream(res)), - result_distances(0, resource::get_cuda_stream(res)), - parent_node_list(0, resource::get_cuda_stream(res)), - topk_hint(0, resource::get_cuda_stream(res)), - topk_workspace(0, resource::get_cuda_stream(res)), - terminate_flag(raft::resource::get_cuda_stream(res)) - { - set_params(res); - } - - void set_params(raft::resources const& res) - { - // - // Allocate memory for intermediate buffer and workspace. - // - result_buffer_size = itopk_size + (search_width * graph_degree); - result_buffer_allocation_size = result_buffer_size + itopk_size; - result_indices.resize(result_buffer_allocation_size * max_queries, - resource::get_cuda_stream(res)); - result_distances.resize(result_buffer_allocation_size * max_queries, - resource::get_cuda_stream(res)); - - parent_node_list.resize(max_queries * search_width, resource::get_cuda_stream(res)); - topk_hint.resize(max_queries, resource::get_cuda_stream(res)); - - size_t topk_workspace_size = _cuann_find_topk_bufferSize( - itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type()); - RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); - topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); - - hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); - } - - ~search() {} - - void operator()(raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - 
SAMPLE_FILTER_T sample_filter) - { - // Init hashmap - cudaStream_t stream = resource::get_cuda_stream(res); - const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch( - hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); - // Init topk_hint - if (topk_hint.size() > 0) { set_value(topk_hint.data(), 0xffffffffu, num_queries, stream); } - - // Choose initial entry point candidates at random - random_pickup( - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), - queries_ptr, - num_queries, - result_buffer_size, - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - result_indices.data(), - result_distances.data(), - result_buffer_allocation_size, - hashmap.data(), - hash_bitlen, - stream); - - unsigned iter = 0; - while (1) { - // Make an index list of internal top-k nodes - _cuann_find_topk(itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - topk_workspace.data(), - true, - topk_hint.data(), - stream); - - // termination (1) - if ((iter + 1 == max_iterations)) { - iter++; - break; - } - - if (iter + 1 >= min_iterations) { set_value(terminate_flag.data(), 1, stream); } - - // pickup parent nodes - uint32_t _small_hash_bitlen = 0; - if ((iter + 1) % small_hash_reset_interval == 0) { _small_hash_bitlen = small_hash_bitlen; } - pickup_next_parents(result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - itopk_size, - num_queries, - hashmap.data(), - hash_bitlen, - _small_hash_bitlen, - parent_node_list.data(), - search_width, - search_width, - 
terminate_flag.data(), - stream); - - // termination (2) - if (iter + 1 >= min_iterations && terminate_flag.value(stream)) { - iter++; - break; - } - - // Compute distance to child nodes that are adjacent to the parent node - compute_distance_to_child_nodes( - parent_node_list.data(), - result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - search_width, - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), - graph.data_handle(), - graph.extent(1), - queries_ptr, - num_queries, - hashmap.data(), - hash_bitlen, - result_indices.data() + itopk_size, - result_distances.data() + itopk_size, - result_buffer_allocation_size, - sample_filter, - stream); - - iter++; - } // while ( 1 ) - auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size; - auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size; - - if constexpr (!std::is_same::value) { - // Remove parent bit in search results - remove_parent_bit(num_queries, - result_buffer_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - stream); - - apply_filter( - result_indices.data() + (iter & 0x1) * itopk_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_buffer_size, - num_queries, - 0, - sample_filter, - stream); - - result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size; - result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size; - _cuann_find_topk(itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances_ptr, - result_buffer_allocation_size, - result_indices_ptr, - 
result_buffer_allocation_size, - topk_workspace.data(), - true, - topk_hint.data(), - stream); - } else { - // Remove parent bit in search results - remove_parent_bit( - num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream); - } - - // Copy results from working buffer to final buffer - batched_memcpy(topk_indices_ptr, - topk, - result_indices_ptr, - result_buffer_allocation_size, - topk, - num_queries, - stream); - if (topk_distances_ptr) { - batched_memcpy(topk_distances_ptr, - topk, - result_distances_ptr, - result_buffer_allocation_size, - topk, - num_queries, - stream); - } - - if (num_executed_iterations) { - for (std::uint32_t i = 0; i < num_queries; i++) { - num_executed_iterations[i] = iter; - } - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } -}; - -} // namespace multi_kernel_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_plan.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_plan.cuh deleted file mode 100644 index f83418b5c..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_plan.cuh +++ /dev/null @@ -1,331 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "hashmap.hpp" -#include -// #include "search_single_cta.cuh" -// #include "topk_for_cagra/topk_core.cuh" - -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra::detail { - -struct search_plan_impl_base : public search_params { - int64_t max_dim; - int64_t dim; - int64_t graph_degree; - uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk) - { - set_max_dim_team(dim); - if (algo == search_algo::AUTO) { - const size_t num_sm = raft::getMultiProcessorCount(); - if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { - algo = search_algo::SINGLE_CTA; - RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); - } else { - algo = search_algo::MULTI_CTA; - RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); - } - } - } - - void set_max_dim_team(int64_t dim) - { - max_dim = 128; - while (max_dim < dim && max_dim <= 1024) - max_dim *= 2; - // To keep binary size in check we limit only one team size specialization for each max_dim. - // TODO(tfeher): revise this decision. - switch (max_dim) { - case 128: team_size = 8; break; - case 256: team_size = 16; break; - case 512: team_size = 32; break; - case 1024: team_size = 32; break; - default: RAFT_LOG_DEBUG("Dataset dimension is too large (%lu)\n", dim); - } - } -}; - -template -struct search_plan_impl : public search_plan_impl_base { - int64_t hash_bitlen; - - size_t small_hash_bitlen; - size_t small_hash_reset_interval; - size_t hashmap_size; - uint32_t dataset_size; - uint32_t result_buffer_size; - - uint32_t smem_size; - uint32_t topk; - uint32_t num_seeds; - - rmm::device_uvector hashmap; - rmm::device_uvector num_executed_iterations; // device or managed? 
- rmm::device_uvector dev_seed; - - search_plan_impl(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) - : search_plan_impl_base(params, dim, graph_degree, topk), - hashmap(0, raft::resource::get_cuda_stream(res)), - num_executed_iterations(0, raft::resource::get_cuda_stream(res)), - dev_seed(0, raft::resource::get_cuda_stream(res)), - num_seeds(0) - { - adjust_search_params(); - check_params(); - calc_hashmap_params(res); - set_max_dim_team(dim); - num_executed_iterations.resize(max_queries, raft::resource::get_cuda_stream(res)); - RAFT_LOG_DEBUG("# algo = %d", static_cast(algo)); - } - - virtual ~search_plan_impl() {} - - virtual void operator()( - raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const result_indices_ptr, // [num_queries, topk] - DISTANCE_T* const result_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk, - SAMPLE_FILTER_T sample_filter){}; - - void adjust_search_params() - { - uint32_t _max_iterations = max_iterations; - if (max_iterations == 0) { - if (algo == search_algo::MULTI_CTA) { - _max_iterations = 1 + std::min(32 * 1.1, 32 + 10.0); // TODO(anaruse) - } else { - _max_iterations = - 1 + std::min((itopk_size / search_width) * 1.1, (itopk_size / search_width) + 10.0); - } - } - if (max_iterations < min_iterations) { _max_iterations = min_iterations; } - if (max_iterations < _max_iterations) { - RAFT_LOG_DEBUG( - "# max_iterations is increased from %lu to %u.", max_iterations, _max_iterations); - max_iterations = _max_iterations; - } - if (itopk_size % 32) { - uint32_t itopk32 = itopk_size; - itopk32 += 32 - (itopk_size % 32); - RAFT_LOG_DEBUG("# internal_topk is increased from %lu to %u, as it must be 
multiple of 32.", - itopk_size, - itopk32); - itopk_size = itopk32; - } - } - - // defines hash_bitlen, small_hash_bitlen, small_hash_reset interval, hash_size - inline void calc_hashmap_params(raft::resources const& res) - { - // for multipel CTA search - uint32_t mc_num_cta_per_query = 0; - uint32_t mc_search_width = 0; - uint32_t mc_itopk_size = 0; - if (algo == search_algo::MULTI_CTA) { - mc_itopk_size = 32; - mc_search_width = 1; - mc_num_cta_per_query = max(search_width, itopk_size / 32); - RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); - RAFT_LOG_DEBUG("# mc_search_width: %u", mc_search_width); - RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); - } - - // Determine hash size (bit length) - hashmap_size = 0; - hash_bitlen = 0; - small_hash_bitlen = 0; - small_hash_reset_interval = 1024 * 1024; - float max_fill_rate = hashmap_max_fill_rate; - while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { - // - // The small-hash reduces hash table size by initializing the hash table - // for each iteraton and re-registering only the nodes that should not be - // re-visited in that iteration. Therefore, the size of small-hash should - // be determined based on the internal topk size and the number of nodes - // visited per iteration. - // - const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); - unsigned min_bitlen = 8; // 256 - unsigned max_bitlen = 13; // 8K - if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } - hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { - hash_bitlen += 1; - } - if (hash_bitlen > max_bitlen) { - // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. 
- if (hashmap_mode == hash_mode::AUTO) { - hash_bitlen = 0; - break; - } else { - RAFT_FAIL( - "small-hash cannot be used because the required hash size exceeds the limit (%u)", - hashmap::get_size(max_bitlen)); - } - } - small_hash_bitlen = hash_bitlen; - // - // Sincc the hash table size is limited to a power of 2, the requirement, - // the maximum fill rate, may be satisfied even if the frequency of hash - // table reset is reduced to once every 2 or more iterations without - // changing the hash table size. In that case, reduce the reset frequency. - // - small_hash_reset_interval = 1; - while (1) { - const auto max_visited_nodes = - itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); - if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } - small_hash_reset_interval += 1; - } - break; - } - if (hash_bitlen == 0) { - // - // The size of hash table is determined based on the maximum number of - // nodes that may be visited before the search is completed and the - // maximum fill rate of the hash table. 
- // - uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); - if (algo == search_algo::MULTI_CTA) { - max_visited_nodes = mc_itopk_size + (mc_search_width * graph_degree * max_iterations); - max_visited_nodes *= mc_num_cta_per_query; - } - unsigned min_bitlen = 11; // 2K - if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } - hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { - hash_bitlen += 1; - } - RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be largen than 20 (1M)"); - } - - RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size); - RAFT_LOG_DEBUG("# parent size = %lu", search_width); - RAFT_LOG_DEBUG("# min_iterations = %lu", min_iterations); - RAFT_LOG_DEBUG("# max_iterations = %lu", max_iterations); - RAFT_LOG_DEBUG("# max_queries = %lu", max_queries); - RAFT_LOG_DEBUG("# hashmap mode = %s%s-%u", - (small_hash_bitlen > 0 ? "small-" : ""), - "hash", - hashmap::get_size(hash_bitlen)); - if (small_hash_bitlen > 0) { - RAFT_LOG_DEBUG("# small_hash_reset_interval = %lu", small_hash_reset_interval); - } - hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); - RAFT_LOG_DEBUG("# hashmap size: %lu", hashmap_size); - if (hashmap_size >= 1024 * 1024 * 1024) { - RAFT_LOG_DEBUG(" (%.2f GiB)", (double)hashmap_size / (1024 * 1024 * 1024)); - } else if (hashmap_size >= 1024 * 1024) { - RAFT_LOG_DEBUG(" (%.2f MiB)", (double)hashmap_size / (1024 * 1024)); - } else if (hashmap_size >= 1024) { - RAFT_LOG_DEBUG(" (%.2f KiB)", (double)hashmap_size / (1024)); - } - } - - virtual void check(const uint32_t topk) - { - // For single-CTA and multi kernel - RAFT_EXPECTS(topk <= itopk_size, "topk must be smaller than itopk_size = %lu", itopk_size); - } - - inline void check_params() - { - std::string error_message = ""; - - if (itopk_size > 1024) { - if (algo == search_algo::MULTI_CTA) { - } else { - error_message += std::string("- `internal_topk` 
(" + std::to_string(itopk_size) + - ") must be smaller or equal to 1024"); - } - } - if (algo != search_algo::SINGLE_CTA && algo != search_algo::MULTI_CTA && - algo != search_algo::MULTI_KERNEL) { - error_message += "An invalid kernel mode has been given: " + std::to_string((int)algo) + ""; - } - if (team_size != 0 && team_size != 4 && team_size != 8 && team_size != 16 && team_size != 32) { - error_message += - "`team_size` must be 0, 4, 8, 16 or 32. " + std::to_string(team_size) + " has been given."; - } - if (thread_block_size != 0 && thread_block_size != 64 && thread_block_size != 128 && - thread_block_size != 256 && thread_block_size != 512 && thread_block_size != 1024) { - error_message += "`thread_block_size` must be 0, 64, 128, 256 or 512. " + - std::to_string(thread_block_size) + " has been given."; - } - if (hashmap_min_bitlen > 20) { - error_message += "`hashmap_min_bitlen` must be equal to or smaller than 20. " + - std::to_string(hashmap_min_bitlen) + " has been given."; - } - if (hashmap_max_fill_rate < 0.1 || hashmap_max_fill_rate >= 0.9) { - error_message += - "`hashmap_max_fill_rate` must be equal to or greater than 0.1 and smaller than 0.9. 
" + - std::to_string(hashmap_max_fill_rate) + " has been given."; - } - if constexpr (!std::is_same::value) { - if (hashmap_mode == hash_mode::SMALL) { - error_message += "`SMALL` hash is not available when filtering"; - } else { - hashmap_mode = hash_mode::HASH; - } - } - if (algo == search_algo::MULTI_CTA) { - if (hashmap_mode == hash_mode::SMALL) { - error_message += "`small_hash` is not available when 'search_mode' is \"multi-cta\""; - } else { - hashmap_mode = hash_mode::HASH; - } - } - - if (error_message.length() != 0) { THROW("[CAGRA Error] %s", error_message.c_str()); } - } -}; - -// template -// struct search_plan { -// search_plan(raft::resources const& res, -// search_params param, -// int64_t dim, -// int64_t graph_degree) -// : plan(res, param, dim, graph_degree) -// { -// } -// void check(uint32_t topk) { plan.check(topk); } - -// // private: -// detail::search_plan_impl plan; -// }; -/** @} */ // end group cagra - -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta.cuh deleted file mode 100644 index 7a2a9392c..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta.cuh +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "search_single_cta_kernel.cuh" -#include "topk_by_radix.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk -#include "utils.hpp" -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -namespace cuvs::neighbors::cagra::detail { -namespace single_cta_search { - -template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; - - uint32_t num_itopk_candidates; - - search(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) 
- : search_plan_impl( - res, params, dim, graph_degree, topk) - { - set_params(res); - } - - ~search() {} - - inline void set_params(raft::resources const& res) - { - num_itopk_candidates = search_width * graph_degree; - result_buffer_size = itopk_size + num_itopk_candidates; - - typedef raft::Pow2<32> AlignBytes; - unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); - - constexpr unsigned max_itopk = 512; - RAFT_EXPECTS(itopk_size <= max_itopk, "itopk_size cannot be larger than %u", max_itopk); - - RAFT_LOG_DEBUG("# num_itopk_candidates: %u", num_itopk_candidates); - RAFT_LOG_DEBUG("# num_itopk: %lu", itopk_size); - // - // Determine the thread block size - // - constexpr unsigned min_block_size = 64; // 32 or 64 - constexpr unsigned min_block_size_radix = 256; - constexpr unsigned max_block_size = 1024; - // - const std::uint32_t topk_ws_size = 3; - const std::uint32_t base_smem_size = - sizeof(float) * max_dim + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * search_width + - sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); - smem_size = base_smem_size; - if (num_itopk_candidates > 256) { - // Tentatively calculate the required share memory size when radix - // sort based topk is used, assuming the block size is the maximum. - if (itopk_size <= 256) { - smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t); - } else { - smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t); - } - } - - uint32_t block_size = thread_block_size; - if (block_size == 0) { - block_size = min_block_size; - - if (num_itopk_candidates > 256) { - // radix-based topk is used. - block_size = min_block_size_radix; - - // Internal topk values per thread must be equlal to or less than 4 - // when radix-sort block_topk is used. 
- while ((block_size < max_block_size) && (max_itopk / block_size > 4)) { - block_size *= 2; - } - } - - // Increase block size according to shared memory requirements. - // If block size is 32, upper limit of shared memory size per - // thread block is set to 4096. This is GPU generation dependent. - constexpr unsigned ulimit_smem_size_cta32 = 4096; - while (smem_size > ulimit_smem_size_cta32 / 32 * block_size) { - block_size *= 2; - } - - // Increase block size to improve GPU occupancy when batch size - // is small, that is, number of queries is low. - cudaDeviceProp deviceProp = resource::get_device_properties(res); - RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); - while ((block_size < max_block_size) && - (graph_degree * search_width * team_size >= block_size * 2) && - (max_queries <= (1024 / (block_size * 2)) * deviceProp.multiProcessorCount)) { - block_size *= 2; - } - } - RAFT_LOG_DEBUG("# thread_block_size: %u", block_size); - RAFT_EXPECTS(block_size >= min_block_size, - "block_size cannot be smaller than min_block size, %u", - min_block_size); - RAFT_EXPECTS(block_size <= max_block_size, - "block_size cannot be larger than max_block size %u", - max_block_size); - thread_block_size = block_size; - - if (num_itopk_candidates <= 256) { - RAFT_LOG_DEBUG("# bitonic-sort based topk routine is used"); - } else { - RAFT_LOG_DEBUG("# radix-sort based topk routine is used"); - smem_size = base_smem_size; - if (itopk_size <= 256) { - constexpr unsigned MAX_ITOPK = 256; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } else { - constexpr unsigned MAX_ITOPK = 512; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } - } - RAFT_LOG_DEBUG("# smem_size: %u", smem_size); - hashmap_size = 0; - if (small_hash_bitlen == 0) { - hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); - hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); - } - RAFT_LOG_DEBUG("# 
hashmap_size: %lu", hashmap_size); - } - - void operator()(raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const result_indices_ptr, // [num_queries, topk] - DISTANCE_T* const result_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk, - SAMPLE_FILTER_T sample_filter) - { - cudaStream_t stream = resource::get_cuda_stream(res); - select_and_run( - dataset, - graph, - result_indices_ptr, - result_distances_ptr, - queries_ptr, - num_queries, - dev_seed_ptr, - num_executed_iterations, - topk, - num_itopk_candidates, - static_cast(thread_block_size), - smem_size, - hash_bitlen, - hashmap.data(), - small_hash_bitlen, - small_hash_reset_interval, - num_random_samplings, - rand_xor_mask, - num_seeds, - itopk_size, - search_width, - min_iterations, - max_iterations, - sample_filter, - stream); - } -}; - -} // namespace single_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh deleted file mode 100644 index 615007a9e..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include // RAFT_EXPLICIT - -namespace cuvs::neighbors::cagra::detail { -namespace single_cta_search { - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -template -void select_and_run( // raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, - uint32_t smem_size, - int64_t hash_bitlen, - INDEX_T* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - cudaStream_t stream) RAFT_EXPLICIT; - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - 
uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); - -instantiate_single_cta_select_and_run( - 32, 1024, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, float, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, int8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 
32, 512, uint8_t, uint32_t, float, cuvs::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_select_and_run - -} // namespace single_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh deleted file mode 100644 index 8aec44dfa..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ /dev/null @@ -1,956 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_by_radix.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk -#include "utils.hpp" -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -namespace cuvs::neighbors::cagra::detail { -namespace single_cta_search { - -// #define _CLK_BREAKDOWN - -template -__device__ void pickup_next_parents(std::uint32_t* const terminate_flag, - INDEX_T* const next_parent_indices, - INDEX_T* const internal_topk_indices, - const std::size_t internal_topk_size, - const std::size_t dataset_size, - const std::uint32_t search_width) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - // if (threadIdx.x >= 32) return; - - for (std::uint32_t i = threadIdx.x; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - std::uint32_t itopk_max = internal_topk_size; - if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } - std::uint32_t num_new_parents = 0; - for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { - std::uint32_t jj = j; - if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } - INDEX_T index; - int new_parent = 0; - if (j < internal_topk_size) { - index = internal_topk_indices[jj]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } - } - const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = jj; - // set most significant bit as used node - internal_topk_indices[jj] |= 
index_msb_1_mask; - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } - } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } -} - -template -__device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk, - unsigned MULTI_WARPS = 0) -{ - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; - float key[N]; - IdxT val[N]; - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } else { - // Use two warps (64 threads) - constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; - constexpr unsigned N = (max_candidates_per_warp + 31) / 32; - float key[N]; - IdxT val[N]; - if (warp_id < 2) { - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = lane_id + (32 * i); - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - 
unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && jl < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - __syncthreads(); - - unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; - if (warp_id < num_warps_used) { - /* Temp_candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned kl = max_candidates_per_warp - 1 - jl; - unsigned j = jl + (max_candidates_per_warp * warp_id); - unsigned k = MAX_CANDIDATES - 1 - j; - if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; - float temp_key = candidate_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - if (warp_id < num_warps_used) { - /* Merge */ - bitonic::warp_merge(key, val, 32); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - } -} - -template -__device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first, - unsigned MULTI_WARPS = 0) -{ - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_ITOPK 
+ 31) / 32; - float key[N]; - IdxT val[N]; - if (first) { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - } else { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - key[i] = itopk_distances[device::swizzling(j)]; - val[i] = itopk_indices[device::swizzling(j)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - } - /* Merge candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; // [0:MAX_ITOPK-1] - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk || k >= num_candidates) continue; - float candidate_key = candidate_distances[device::swizzling(k)]; - if (key[i] > candidate_key) { - key[i] = candidate_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } else { - // Use two warps (64 threads) or more - constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; - constexpr unsigned N = (max_itopk_per_warp + 31) / 32; - float key[N]; - IdxT val[N]; - if (first) { - /* Load itop results (not sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i) + (max_itopk_per_warp * warp_id); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store 
intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - __syncthreads(); - if (warp_id < 2) { - /* Load intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk) continue; - float temp_key = itopk_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = itopk_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - } - __syncthreads(); - /* Store itopk results (sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } - const uint32_t num_itopk_div2 = num_itopk / 2; - if (threadIdx.x < 3) { - // work_buf is used to obtain turning points in 1st and 2nd half of itopk afer merge. 
- work_buf[threadIdx.x] = num_itopk_div2; - } - __syncthreads(); - - // Merge candidates (using whole threads) - for (unsigned k = threadIdx.x; k < min(num_candidates, num_itopk); k += blockDim.x) { - const unsigned j = num_itopk - 1 - k; - const float itopk_key = itopk_distances[device::swizzling(j)]; - const float candidate_key = candidate_distances[device::swizzling(k)]; - if (itopk_key > candidate_key) { - itopk_distances[device::swizzling(j)] = candidate_key; - itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; - if (j < num_itopk_div2) { - atomicMin(work_buf + 2, j); - } else { - atomicMin(work_buf + 1, j - num_itopk_div2); - } - } - } - __syncthreads(); - - // Merge 1st and 2nd half of itopk (using whole threads) - for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { - const unsigned k = j + num_itopk_div2; - float key_0 = itopk_distances[device::swizzling(j)]; - float key_1 = itopk_distances[device::swizzling(k)]; - if (key_0 > key_1) { - itopk_distances[device::swizzling(j)] = key_1; - itopk_distances[device::swizzling(k)] = key_0; - IdxT val_0 = itopk_indices[device::swizzling(j)]; - IdxT val_1 = itopk_indices[device::swizzling(k)]; - itopk_indices[device::swizzling(j)] = val_1; - itopk_indices[device::swizzling(k)] = val_0; - atomicMin(work_buf + 0, j); - } - } - if (threadIdx.x == blockDim.x - 1) { - if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } - } - __syncthreads(); - // if ((blockIdx.x == 0) && (threadIdx.x == 0)) { - // RAFT_LOG_DEBUG( "work_buf: %u, %u, %u\n", work_buf[0], work_buf[1], work_buf[2] ); - // } - - // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. 
- if (warp_id < 2) { - // Load intermedidate itopk results - const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 - for (unsigned i = 0; i < N; i++) { - unsigned k = num_itopk; - unsigned j = (N * lane_id) + i; - if (j < turning_point) { - k = j + (num_itopk_div2 * warp_id); - } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { - j -= (MAX_ITOPK / 2 - num_itopk_div2); - if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } - } - if (k < num_itopk) { - key[i] = itopk_distances[device::swizzling(k)]; - val[i] = itopk_indices[device::swizzling(k)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - const unsigned j = (N * lane_id) + i; - if (j < num_itopk_div2) { - unsigned k = j + (num_itopk_div2 * warp_id); - itopk_distances[device::swizzling(k)] = key[i]; - itopk_indices[device::swizzling(k)] = val[i]; - } - } - } - } -} - -template -__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first, - const unsigned MULTI_WARPS_1, - const unsigned MULTI_WARPS_2) -{ - // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_1st( - candidate_distances, candidate_indices, num_candidates, num_itopk, MULTI_WARPS_1); - - // The results sorted above are merged with the internal intermediate top-k - // results so far using bitonic merge. 
- topk_by_bitonic_sort_2nd(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first, - MULTI_WARPS_2); -} - -template -__device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, - const size_t hashmap_bitlen, - const INDEX_T* itopk_indices, - const uint32_t itopk_size, - const uint32_t first_tid = 0) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - if (threadIdx.x < first_tid) return; - for (unsigned i = threadIdx.x - first_tid; i < itopk_size; i += blockDim.x - first_tid) { - auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit - hashmap::insert(hashmap_ptr, hashmap_bitlen, key); - } -} - -template -__device__ inline void set_value_device(T* const ptr, const T fill, const std::uint32_t count) -{ - for (std::uint32_t i = threadIdx.x; i < count; i += BLOCK_SIZE) { - ptr[i] = fill; - } -} - -// One query one thread block -template -__launch_bounds__(1024, 1) RAFT_KERNEL - search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] - DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] - const std::uint32_t top_k, - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, // stride of dataset - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t internal_topk, - const std::uint32_t search_width, - const std::uint32_t min_iteration, - const std::uint32_t max_iteration, - std::uint32_t* const num_executed_iterations, // [num_queries] - const std::uint32_t 
hash_bitlen, - const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval, - SAMPLE_FILTER_T sample_filter) -{ - using LOAD_T = device::LOAD_128BIT_T; - const auto query_id = blockIdx.y; - -#ifdef _CLK_BREAKDOWN - std::uint64_t clk_init = 0; - std::uint64_t clk_compute_1st_distance = 0; - std::uint64_t clk_topk = 0; - std::uint64_t clk_reset_hash = 0; - std::uint64_t clk_pickup_parents = 0; - std::uint64_t clk_restore_hash = 0; - std::uint64_t clk_compute_distance = 0; - std::uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ std::uint32_t smem[]; - - // Layout of result_buffer - // +----------------------+------------------------------+---------+ - // | internal_top_k | neighbors of internal_top_k | padding | - // | | | upto 32 | - // +----------------------+------------------------------+---------+ - // |<--- result_buffer_size --->| - std::uint32_t result_buffer_size = internal_topk + (search_width * graph_degree); - std::uint32_t result_buffer_size_32 = result_buffer_size; - if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } - const auto small_hash_size = hashmap::get_size(small_hash_bitlen); - auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); - auto result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto visited_hash_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); - auto topk_ws = reinterpret_cast(parent_list_buffer + search_width); - auto terminate_flag = reinterpret_cast(topk_ws + 3); - auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); - - // A flag for filtering. 
- auto filter_flag = terminate_flag; - - const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dataset_dim) { - query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); - } else { - query_buffer[j] = 0.0; - } - } - if (threadIdx.x == 0) { - terminate_flag[0] = 0; - topk_ws[0] = ~0u; - } - - // Init hashmap - INDEX_T* local_visited_hashmap_ptr; - if (small_hash_bitlen) { - local_visited_hashmap_ptr = visited_hash_buffer; - } else { - local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); - } - hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - device::compute_distance_to_random_nodes( - result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - std::uint32_t iter = 0; - while (1) { - // sort - if constexpr (TOPK_BY_BITONIC_SORT) { - // [Notice] - // It is good to use multiple warps in topk_by_bitonic_sort() when - // batch size is small (short-latency), but it might not be always good - // when batch size is large (high-throughput). - // topk_by_bitonic_sort() consists of two operations: - // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; - // if MAX_ITOPK is greater than 256, the second operation used two warps. - const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; - const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 
1 : 0; - - // reset small-hash table. - if ((iter + 1) % small_hash_reset_interval == 0) { - // Depending on the block size and the number of warps used in - // topk_by_bitonic_sort(), determine which warps are used to reset - // the small hash and whether they are performed in overlap with - // topk_by_bitonic_sort(). - _CLK_START(); - unsigned hash_start_tid; - if (blockDim.x == 32) { - hash_start_tid = 0; - } else if (blockDim.x == 64) { - if (multi_warps_1 || multi_warps_2) { - hash_start_tid = 0; - } else { - hash_start_tid = 32; - } - } else { - if (multi_warps_1 || multi_warps_2) { - hash_start_tid = 64; - } else { - hash_start_tid = 32; - } - } - hashmap::init(local_visited_hashmap_ptr, hash_bitlen, hash_start_tid); - _CLK_REC(clk_reset_hash); - } - - // topk with bitonic sort - _CLK_START(); - if (std::is_same::value || - *filter_flag == 0) { - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - search_width * graph_degree, - topk_ws, - (iter == 0), - multi_warps_1, - multi_warps_2); - __syncthreads(); - } else { - topk_by_bitonic_sort_1st( - result_distances_buffer, - result_indices_buffer, - internal_topk + search_width * graph_degree, - internal_topk, - false); - if (threadIdx.x == 0) { *terminate_flag = 0; } - } - _CLK_REC(clk_topk); - } else { - _CLK_START(); - // topk with radix block sort - topk_by_radix_sort{}( - internal_topk, - gridDim.x, - result_buffer_size, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - nullptr, - topk_ws, - true, - reinterpret_cast(smem_working_ptr)); - _CLK_REC(clk_topk); - - // reset small-hash table - if ((iter + 1) % small_hash_reset_interval == 0) { - _CLK_START(); - hashmap::init(local_visited_hashmap_ptr, hash_bitlen); - _CLK_REC(clk_reset_hash); - } - } - __syncthreads(); - - if (iter + 1 == 
max_iteration) { break; } - - // pick up next parents - if (threadIdx.x < 32) { - _CLK_START(); - pickup_next_parents(terminate_flag, - parent_list_buffer, - result_indices_buffer, - internal_topk, - dataset_size, - search_width); - _CLK_REC(clk_pickup_parents); - } - - // restore small-hash table by putting internal-topk indices in it - if ((iter + 1) % small_hash_reset_interval == 0) { - const unsigned first_tid = ((blockDim.x <= 32) ? 0 : 32); - _CLK_START(); - hashmap_restore( - local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk, first_tid); - _CLK_REC(clk_restore_hash); - } - __syncthreads(); - - if (*terminate_flag && iter >= min_iteration) { break; } - - // compute the norms between child nodes and query node - _CLK_START(); - constexpr unsigned max_n_frags = 16; - device::compute_distance_to_child_nodes( - result_indices_buffer + internal_topk, - result_distances_buffer + internal_topk, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_list_buffer, - result_indices_buffer, - search_width); - __syncthreads(); - _CLK_REC(clk_compute_distance); - - // Filtering - if constexpr (!std::is_same::value) { - if (threadIdx.x == 0) { *filter_flag = 0; } - __syncthreads(); - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { - if (parent_list_buffer[p] != invalid_index) { - const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; - if (!sample_filter(query_id, parent_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); - result_indices_buffer[parent_list_buffer[p]] = invalid_index; - *filter_flag = 1; - } - } - } - __syncthreads(); - } - - iter++; - } - - // 
Post process for filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; - i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { - result_distances_buffer[i] = utils::get_max_value(); - result_indices_buffer[i] = invalid_index; - } - } - - __syncthreads(); - topk_by_bitonic_sort_1st( - result_distances_buffer, - result_indices_buffer, - internal_topk + search_width * graph_degree, - top_k, - false); - __syncthreads(); - } - - for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { - unsigned j = i + (top_k * query_id); - unsigned ii = i; - if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit - } - if (threadIdx.x == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && ((query_id * 3) % gridDim.y < 3)) { - RAFT_LOG_DEBUG( - "query, %d, thread, %d" - ", init, %d" - ", 1st_distance, %lu" - ", topk, %lu" - ", reset_hash, %lu" - ", pickup_parents, %lu" - ", restore_hash, %lu" - ", distance, %lu" - "\n", - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_reset_hash, - clk_pickup_parents, - clk_restore_hash, - clk_compute_distance); - } -#endif -} - -template -struct search_kernel_config { - using kernel_t = - decltype(&search_kernel); - - template - static auto choose_search_kernel(unsigned 
itopk_size) -> kernel_t - { - if (itopk_size <= 64) { - return search_kernel; - } else if (itopk_size <= 128) { - return search_kernel; - } else if (itopk_size <= 256) { - return search_kernel; - } else if (itopk_size <= 512) { - return search_kernel; - } - THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); - } - - static auto choose_itopk_and_mx_candidates(unsigned itopk_size, - unsigned num_itopk_candidates, - unsigned block_size) -> kernel_t - { - if (num_itopk_candidates <= 64) { - // use bitonic sort based topk - return choose_search_kernel<64, 1>(itopk_size); - } else if (num_itopk_candidates <= 128) { - return choose_search_kernel<128, 1>(itopk_size); - } else if (num_itopk_candidates <= 256) { - return choose_search_kernel<256, 1>(itopk_size); - } else { - // Radix-based topk is used - constexpr unsigned max_candidates = 32; // to avoid build failure - if (itopk_size <= 256) { - return search_kernel; - } else if (itopk_size <= 512) { - return search_kernel; - } - } - THROW("No kernel for parametels itopk_size %u, num_itopk_candidates %u", - itopk_size, - num_itopk_candidates); - } -}; - -template -void select_and_run( // raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, // - uint32_t smem_size, - int64_t hash_bitlen, - INDEX_T* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t 
max_iterations, - SAMPLE_FILTER_T sample_filter, - cudaStream_t stream) -{ - auto kernel = - search_kernel_config:: - choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 thread_dims(block_size, 1, 1); - dim3 block_dims(1, num_queries, 1); - RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - topk, - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), - queries_ptr, - graph.data_handle(), - graph.extent(1), - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - itopk_size, - search_width, - min_iterations, - max_iterations, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - sample_filter); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} -} // namespace single_cta_search -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel.cuh b/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel.cuh deleted file mode 100644 index 1d8fd8e30..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/search_single_cta_kernel.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "search_single_cta_kernel-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "search_single_cta_kernel-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/detail/cagra/topk_by_radix.cuh b/cpp/include/cuvs/neighbors/detail/cagra/topk_by_radix.cuh deleted file mode 100644 index 67173026b..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/topk_by_radix.cuh +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "topk_for_cagra/topk_core.cuh" - -namespace cuvs::neighbors::cagra::detail { -namespace single_cta_search { - -template -struct topk_by_radix_sort_base { - static constexpr std::uint32_t smem_size = MAX_INTERNAL_TOPK * 2 + 2048 + 8; - static constexpr std::uint32_t state_bit_lenght = 0; - static constexpr std::uint32_t vecLen = 2; // TODO -}; -template -struct topk_by_radix_sort : topk_by_radix_sort_base {}; - -template -struct topk_by_radix_sort> - : topk_by_radix_sort_base { - __device__ void operator()(uint32_t topk, - uint32_t batch_size, - uint32_t len_x, - const uint32_t* _x, - const IdxT* _in_vals, - uint32_t* _y, - IdxT* _out_vals, - uint32_t* work, - uint32_t* _hints, - bool sort, - uint32_t* _smem) - { - std::uint8_t* const state = reinterpret_cast(work); - topk_cta_11_core::state_bit_lenght, - topk_by_radix_sort_base::vecLen, - 64, - 32, - IdxT>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); - } -}; - -#define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ - template \ - struct topk_by_radix_sort< \ - MAX_INTERNAL_TOPK, \ - IdxT, \ - std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ - : topk_by_radix_sort_base { \ - __device__ void operator()(uint32_t topk, \ - uint32_t batch_size, \ - uint32_t len_x, \ - const uint32_t* _x, \ - const IdxT* _in_vals, \ - uint32_t* _y, \ - IdxT* _out_vals, \ - uint32_t* work, \ - uint32_t* _hints, \ - bool sort, \ - uint32_t* _smem) \ - { \ - assert(blockDim.x >= V / 4); \ - std::uint8_t* state = (std::uint8_t*)work; \ - topk_cta_11_core::state_bit_lenght, \ - topk_by_radix_sort_base::vecLen, \ - V, \ - V / 4, \ - IdxT>( \ - topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ - } \ - }; -TOP_FUNC_PARTIAL_SPECIALIZATION(128); -TOP_FUNC_PARTIAL_SPECIALIZATION(256); -TOP_FUNC_PARTIAL_SPECIALIZATION(512); -TOP_FUNC_PARTIAL_SPECIALIZATION(1024); - -} // namespace single_cta_search -} // namespace cuvs::neighbors::cagra::detail diff 
--git a/cpp/include/cuvs/neighbors/detail/cagra/topk_for_cagra/topk.h b/cpp/include/cuvs/neighbors/detail/cagra/topk_for_cagra/topk.h deleted file mode 100644 index 41141ac27..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/topk_for_cagra/topk.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -namespace cuvs::neighbors::cagra::detail { - -// -size_t _cuann_find_topk_bufferSize(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - cudaDataType_t sampleDtype = CUDA_R_32F); - -// -template -void _cuann_find_topk(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const ValT* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - ValT* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK - void* workspace, - bool sort = false, - uint32_t* hint = NULL, - cudaStream_t stream = 0); - -#ifdef __CUDA_ARCH__ -#define CUDA_DEVICE_HOST_FUNC __device__ -#else -#define CUDA_DEVICE_HOST_FUNC -#endif -// -CUDA_DEVICE_HOST_FUNC inline size_t _cuann_aligned(size_t size, size_t unit = 128) -{ - if (size % unit) { size += unit - (size % unit); } - return size; -} -} // namespace cuvs::neighbors::cagra::detail diff --git 
a/cpp/include/cuvs/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/cuvs/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh deleted file mode 100644 index a57fda93b..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh +++ /dev/null @@ -1,1038 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include "topk.h" -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra::detail { -// -__device__ inline uint32_t convert(uint32_t x) -{ - if (x & 0x80000000) { - return x ^ 0xffffffff; - } else { - return x ^ 0x80000000; - } -} - -// -__device__ inline uint16_t convert(uint16_t x) -{ - if (x & 0x8000) { - return x ^ 0xffff; - } else { - return x ^ 0x8000; - } -} - -// -struct u32_vector { - uint1 x1; - uint2 x2; - uint4 x4; - ulonglong4 x8; -}; - -// -struct u16_vector { - ushort1 x1; - ushort2 x2; - ushort4 x4; - uint4 x8; -}; - -// -template -__device__ inline void load_u32_vector(struct u32_vector& vec, const uint32_t* x, int i) -{ - if (vecLen == 1) { - vec.x1 = ((uint1*)(x + i))[0]; - } else if (vecLen == 2) { - vec.x2 = ((uint2*)(x + i))[0]; - } else if (vecLen == 4) { - vec.x4 = ((uint4*)(x + i))[0]; - } else if (vecLen == 8) { - vec.x8 = ((ulonglong4*)(x + i))[0]; - } -} - -// -template -__device__ inline void load_u16_vector(struct u16_vector& vec, const uint16_t* x, int i) -{ - if (vecLen == 1) { - vec.x1 
= ((ushort1*)(x + i))[0]; - } else if (vecLen == 2) { - vec.x2 = ((ushort2*)(x + i))[0]; - } else if (vecLen == 4) { - vec.x4 = ((ushort4*)(x + i))[0]; - } else if (vecLen == 8) { - vec.x8 = ((uint4*)(x + i))[0]; - } -} - -// -template -__device__ inline uint32_t get_element_from_u32_vector(struct u32_vector& vec, int i) -{ - uint32_t xi; - if (vecLen == 1) { - xi = convert(vec.x1.x); - } else if (vecLen == 2) { - if (i == 0) - xi = convert(vec.x2.x); - else - xi = convert(vec.x2.y); - } else if (vecLen == 4) { - if (i == 0) - xi = convert(vec.x4.x); - else if (i == 1) - xi = convert(vec.x4.y); - else if (i == 2) - xi = convert(vec.x4.z); - else - xi = convert(vec.x4.w); - } else if (vecLen == 8) { - if (i == 0) - xi = convert((uint32_t)(vec.x8.x & 0xffffffff)); - else if (i == 1) - xi = convert((uint32_t)(vec.x8.x >> 32)); - else if (i == 2) - xi = convert((uint32_t)(vec.x8.y & 0xffffffff)); - else if (i == 3) - xi = convert((uint32_t)(vec.x8.y >> 32)); - else if (i == 4) - xi = convert((uint32_t)(vec.x8.z & 0xffffffff)); - else if (i == 5) - xi = convert((uint32_t)(vec.x8.z >> 32)); - else if (i == 6) - xi = convert((uint32_t)(vec.x8.w & 0xffffffff)); - else - xi = convert((uint32_t)(vec.x8.w >> 32)); - } - return xi; -} - -// -template -__device__ inline uint16_t get_element_from_u16_vector(struct u16_vector& vec, int i) -{ - uint16_t xi; - if (vecLen == 1) { - xi = convert(vec.x1.x); - } else if (vecLen == 2) { - if (i == 0) - xi = convert(vec.x2.x); - else - xi = convert(vec.x2.y); - } else if (vecLen == 4) { - if (i == 0) - xi = convert(vec.x4.x); - else if (i == 1) - xi = convert(vec.x4.y); - else if (i == 2) - xi = convert(vec.x4.z); - else - xi = convert(vec.x4.w); - } else if (vecLen == 8) { - if (i == 0) - xi = convert((uint16_t)(vec.x8.x & 0xffff)); - else if (i == 1) - xi = convert((uint16_t)(vec.x8.x >> 16)); - else if (i == 2) - xi = convert((uint16_t)(vec.x8.y & 0xffff)); - else if (i == 3) - xi = convert((uint16_t)(vec.x8.y >> 16)); - else if (i == 
4) - xi = convert((uint16_t)(vec.x8.z & 0xffff)); - else if (i == 5) - xi = convert((uint16_t)(vec.x8.z >> 16)); - else if (i == 6) - xi = convert((uint16_t)(vec.x8.w & 0xffff)); - else - xi = convert((uint16_t)(vec.x8.w >> 16)); - } - return xi; -} - -template -__device__ inline void block_scan(const T input, T& output) -{ - switch (blockDim.x) { - case 32: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 64: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 128: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 256: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 512: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 1024: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - default: break; - } -} - -// -template -__device__ inline void update_histogram(int itr, - uint32_t thread_id, - uint32_t num_threads, - uint32_t hint, - uint32_t threshold, - uint32_t& num_bins, - uint32_t& shift, - const T* x, // [nx,] - uint32_t nx, - uint32_t* hist, // [num_bins] - uint8_t* state, - uint32_t* output, // [topk] - uint32_t* output_count) -{ - if (sizeof(T) == 4) { - // 32-bit (uint32_t) - // itr:0, calculate histogram with 11 bits from bit-21 to bit-31 - // itr:1, calculate histogram with 11 bits from bit-10 to bit-20 - // itr:2, calculate histogram with 
10 bits from bit-0 to bit-9 - if (itr == 0) { - shift = 21; - num_bins = 2048; - } else if (itr == 1) { - shift = 10; - num_bins = 2048; - } else { - shift = 0; - num_bins = 1024; - } - } else if (sizeof(T) == 2) { - // 16-bit (uint16_t) - // itr:0, calculate histogram with 8 bits from bit-8 to bit-15 - // itr:1, calculate histogram with 8 bits from bit-0 to bit-7 - if (itr == 0) { - shift = 8; - num_bins = 256; - } else { - shift = 0; - num_bins = 256; - } - } else { - return; - } - if (itr > 0) { - for (int i = threadIdx.x; i < num_bins; i += blockDim.x) { - hist[i] = 0; - } - __syncthreads(); - } - - // (*) Note that 'thread_id' may be different from 'threadIdx.x', - // and 'num_threads' may be different from 'blockDim.x' - int ii = 0; - for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { - uint8_t iState = 0; - if ((stateBitLen == 8) && (itr > 0)) { - iState = state[thread_id + (num_threads * ii)]; - if (iState == (uint8_t)0xff) continue; - } -#pragma unroll - for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { - const int iv = i + (num_threads * v); - if (iv >= nx) break; - - struct u32_vector x_u32_vec; - struct u16_vector x_u16_vec; - if (sizeof(T) == 4) { - load_u32_vector(x_u32_vec, (const uint32_t*)x, iv); - } else { - load_u16_vector(x_u16_vec, (const uint16_t*)x, iv); - } -#pragma unroll - for (int u = 0; u < vecLen; u++) { - const int ivu = iv + u; - if (ivu >= nx) break; - - uint8_t mask = (uint8_t)0x1 << (v + u); - if ((stateBitLen == 8) && (iState & mask)) continue; - - uint32_t xi; - if (sizeof(T) == 4) { - xi = get_element_from_u32_vector(x_u32_vec, u); - } else { - xi = get_element_from_u16_vector(x_u16_vec, u); - } - if ((xi > hint) && (itr == 0)) { - if (stateBitLen == 8) { iState |= mask; } - } else if (xi < threshold) { - if (stateBitLen == 8) { - // If the condition is already met, record the index. 
- output[atomicAdd(output_count, 1)] = ivu; - iState |= mask; - } - } else { - const uint32_t k = (xi - threshold) >> shift; // 0 <= k - if (k >= num_bins) { - if (stateBitLen == 8) { iState |= mask; } - } else if (k + 1 < num_bins) { - // Update histogram - atomicAdd(&(hist[k + 1]), 1); - } - } - } - } - if (stateBitLen == 8) { state[thread_id + (num_threads * ii)] = iState; } - } - __syncthreads(); -} - -template -__device__ inline void select_best_index_for_next_threshold_core(uint32_t& my_index, - uint32_t& my_csum, - const unsigned num_bins, - const uint32_t* const hist, - const uint32_t nx_below_threshold, - const uint32_t max_threshold, - const uint32_t threshold, - const uint32_t shift, - const uint32_t topk) -{ - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - if (num_bins == 2048) { - constexpr int n_data = 2048 / blockDim_x; - uint32_t csum[n_data]; - for (int i = 0; i < n_data; i++) { - csum[i] = hist[i + (n_data * threadIdx.x)]; - } - BlockScanT(temp_storage).InclusiveSum(csum, csum); - for (int i = n_data - 1; i >= 0; i--) { - if (nx_below_threshold + csum[i] > topk) continue; - const uint32_t index = i + (n_data * threadIdx.x); - if (threshold + (index << shift) > max_threshold) continue; - my_index = index; - my_csum = csum[i]; - break; - } - } else if (num_bins == 1024) { - constexpr int n_data = 1024 / blockDim_x; - uint32_t csum[n_data]; - for (int i = 0; i < n_data; i++) { - csum[i] = hist[i + (n_data * threadIdx.x)]; - } - BlockScanT(temp_storage).InclusiveSum(csum, csum); - for (int i = n_data - 1; i >= 0; i--) { - if (nx_below_threshold + csum[i] > topk) continue; - const uint32_t index = i + (n_data * threadIdx.x); - if (threshold + (index << shift) > max_threshold) continue; - my_index = index; - my_csum = csum[i]; - break; - } - } -} - -// -__device__ inline void select_best_index_for_next_threshold( - const uint32_t topk, - const uint32_t threshold, - const uint32_t max_threshold, - const 
uint32_t nx_below_threshold, - const uint32_t num_bins, - const uint32_t shift, - const uint32_t* const hist, // [num_bins] - uint32_t* const best_index, - uint32_t* const best_csum) -{ - // Scan the histogram ('hist') and compute csum. Then, find the largest - // index under the condition that the sum of the number of elements found - // so far ('nx_below_threshold') and the csum value does not exceed the - // topk value. - uint32_t my_index = 0xffffffff; - uint32_t my_csum = 0; - if (num_bins <= blockDim.x) { - uint32_t csum = 0; - if (threadIdx.x < num_bins) { csum = hist[threadIdx.x]; } - detail::block_scan(csum, csum); - if (threadIdx.x < num_bins) { - const uint32_t index = threadIdx.x; - if ((nx_below_threshold + csum <= topk) && (threshold + (index << shift) <= max_threshold)) { - my_index = index; - my_csum = csum; - } - } - } else { - switch (blockDim.x) { - case 64: - select_best_index_for_next_threshold_core<64>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 128: - select_best_index_for_next_threshold_core<128>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 256: - select_best_index_for_next_threshold_core<256>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 512: - select_best_index_for_next_threshold_core<512>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 1024: - select_best_index_for_next_threshold_core<1024>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - } - } - if (threadIdx.x < num_bins) { - const int laneid = 31 - __clz(__ballot_sync(0xffffffff, (my_index != 0xffffffff))); - if ((threadIdx.x & 0x1f) == laneid) { - const uint32_t old_index = 
atomicMax(best_index, my_index); - if (old_index < my_index) { atomicMax(best_csum, my_csum); } - } - } - __syncthreads(); -} - -// -template -__device__ inline void output_index_below_threshold(const uint32_t topk, - const uint32_t thread_id, - const uint32_t num_threads, - const uint32_t threshold, - const uint32_t nx_below_threshold, - const T* const x, // [nx,] - const uint32_t nx, - const uint8_t* state, - uint32_t* const output, // [topk] - uint32_t* const output_count, - uint32_t* const output_count_eq) -{ - int ii = 0; - for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { - uint8_t iState = 0; - if (stateBitLen == 8) { - iState = state[thread_id + (num_threads * ii)]; - if (iState == (uint8_t)0xff) continue; - } -#pragma unroll - for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { - const int iv = i + (num_threads * v); - if (iv >= nx) break; - - struct u32_vector u32_vec; - struct u16_vector u16_vec; - if (sizeof(T) == 4) { - load_u32_vector(u32_vec, (const uint32_t*)x, iv); - } else { - load_u16_vector(u16_vec, (const uint16_t*)x, iv); - } -#pragma unroll - for (int u = 0; u < vecLen; u++) { - const int ivu = iv + u; - if (ivu >= nx) break; - - const uint8_t mask = (uint8_t)0x1 << (v + u); - if ((stateBitLen == 8) && (iState & mask)) continue; - - uint32_t xi; - if (sizeof(T) == 4) { - xi = get_element_from_u32_vector(u32_vec, u); - } else { - xi = get_element_from_u16_vector(u16_vec, u); - } - if (xi < threshold) { - output[atomicAdd(output_count, 1)] = ivu; - } else if (xi == threshold) { - // (*) If the value is equal to the threshold, the index - // processed first is recorded. Cause of non-determinism. 
- if (nx_below_threshold + atomicAdd(output_count_eq, 1) < topk) { - output[atomicAdd(output_count, 1)] = ivu; - } - } - } - } - } -} - -// -template -__device__ inline void swap(T& val1, T& val2) -{ - const T val0 = val1; - val1 = val2; - val2 = val0; -} - -// -template -__device__ inline bool swap_if_needed(K& key1, K& key2) -{ - if (key1 > key2) { - swap(key1, key2); - return true; - } - return false; -} - -// -template -__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2) -{ - if (key1 > key2) { - swap(key1, key2); - swap(val1, val2); - return true; - } - return false; -} - -// -template -__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool ascending) -{ - if (key1 == key2) { return false; } - if ((key1 > key2) == ascending) { - swap(key1, key2); - swap(val1, val2); - return true; - } - return false; -} - -// -template -__device__ inline T max_value_of(); -template <> -__device__ inline float max_value_of() -{ - return FLT_MAX; -} -template <> -__device__ inline uint32_t max_value_of() -{ - return ~0u; -} - -template -__device__ __host__ inline uint32_t get_state_size(uint32_t len_x) -{ -#ifdef __CUDA_ARCH__ - const uint32_t num_threads = blockDim.x; -#else - const uint32_t num_threads = BLOCK_SIZE; -#endif - if (stateBitLen == 8) { - uint32_t numElements_perThread = (len_x + num_threads - 1) / num_threads; - uint32_t numState_perThread = (numElements_perThread + stateBitLen - 1) / stateBitLen; - return numState_perThread * num_threads; - } - return 0; -} - -// -template -__device__ inline void topk_cta_11_core(uint32_t topk, - uint32_t len_x, - const uint32_t* _x, // [size_batch, ld_x,] - const ValT* _in_vals, // [size_batch, ld_iv,] - uint32_t* _y, // [size_batch, ld_y,] - ValT* _out_vals, // [size_batch, ld_ov,] - uint8_t* _state, // [size_batch, ...,] - uint32_t* _hint, - bool sort, - uint32_t* _smem) -{ - uint32_t* const smem_out_vals = _smem; - uint32_t* const hist = &(_smem[2 * maxTopk]); - uint32_t* const 
best_index = &(_smem[2 * maxTopk + 2048]); - uint32_t* const best_csum = &(_smem[2 * maxTopk + 2048 + 3]); - - const uint32_t num_threads = blockDim.x; - const uint32_t thread_id = threadIdx.x; - uint32_t nx = len_x; - const uint32_t* const x = _x; - const ValT* in_vals = NULL; - if (_in_vals) { in_vals = _in_vals; } - uint32_t* y = NULL; - if (_y) { y = _y; } - ValT* out_vals = NULL; - if (_out_vals) { out_vals = _out_vals; } - uint8_t* state = _state; - const uint32_t hint = (_hint == NULL ? ~0u : *_hint); - - // Initialize shared memory - for (int i = 2 * maxTopk + thread_id; i < 2 * maxTopk + 2048 + 8; i += num_threads) { - _smem[i] = 0; - } - uint32_t* const output_count = &(_smem[2 * maxTopk + 2048 + 6]); - uint32_t* const output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); - uint32_t threshold = 0; - uint32_t nx_below_threshold = 0; - __syncthreads(); - - // - // Search for the maximum threshold that satisfies "(x < threshold).sum() <= topk". - // -#pragma unroll - for (int j = 0; j < 3; j += 1) { - uint32_t num_bins; - uint32_t shift; - - update_histogram(j, - thread_id, - num_threads, - hint, - threshold, - num_bins, - shift, - x, - nx, - hist, - state, - smem_out_vals, - output_count); - select_best_index_for_next_threshold(topk, - threshold, - hint, - nx_below_threshold, - num_bins, - shift, - hist, - best_index + j, - best_csum + j); - - threshold += (best_index[j] << shift); - nx_below_threshold += best_csum[j]; - if (nx_below_threshold == topk) break; - } - - if ((_hint != NULL) && (thread_id == 0)) { *_hint = min(threshold, hint); } - - // - // Output index that satisfies "x[i] < threshold". 
- // - output_index_below_threshold(topk, - thread_id, - num_threads, - threshold, - nx_below_threshold, - x, - nx, - state, - smem_out_vals, - output_count, - output_count_eq); - __syncthreads(); - -#ifdef CUANN_DEBUG - if (thread_id == 0 && output_count[0] < topk) { - RAFT_LOG_DEBUG( - "# i_batch:%d, topk:%d, output_count:%d, nx_below_threshold:%d, threshold:%08x\n", - i_batch, - topk, - output_count[0], - nx_below_threshold, - threshold); - } -#endif - - if (!sort) { - for (int k = thread_id; k < topk; k += blockDim.x) { - const uint32_t i = smem_out_vals[k]; - if (y) { y[k] = x[i]; } - if (out_vals) { - if (in_vals) { - out_vals[k] = in_vals[i]; - } else { - out_vals[k] = i; - } - } - } - return; - } - - constexpr int numTopkPerThread = maxTopk / numSortThreads; - float my_keys[numTopkPerThread]; - ValT my_vals[numTopkPerThread]; - - // Read keys and values to registers - if (thread_id < numSortThreads) { - for (int i = 0; i < numTopkPerThread; i++) { - const int k = thread_id + (numSortThreads * i); - if (k < topk) { - const int j = smem_out_vals[k]; - my_keys[i] = ((float*)x)[j]; - if (in_vals) { - my_vals[i] = in_vals[j]; - } else { - my_vals[i] = j; - } - } else { - my_keys[i] = FLT_MAX; - my_vals[i] = ~static_cast(0); - } - } - } - - uint32_t mask = 1; - - // Sorting by thread - if (thread_id < numSortThreads) { - const bool ascending = ((thread_id & mask) == 0); - if (numTopkPerThread == 3) { - swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); - swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); - swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); - } else { - for (int j = 0; j < numTopkPerThread / 2; j += 1) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i += 2) { - swap_if_needed( - my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); - } -#pragma unroll - for (int i = 1; i < numTopkPerThread - 1; i += 2) { - swap_if_needed( - my_keys[i], my_keys[i + 1], 
my_vals[i], my_vals[i + 1], ascending); - } - } - } - } - - // Bitonic Sorting - while (mask < numSortThreads) { - uint32_t next_mask = mask << 1; - - for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { - const bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0); - if (curr_mask >= 32) { - // inter warp - ValT* const smem_vals = reinterpret_cast(_smem); // [maxTopk] - float* const smem_keys = - reinterpret_cast(smem_vals + maxTopk); // [numTopkPerThread, numSortThreads] - __syncthreads(); - if (thread_id < numSortThreads) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - smem_keys[thread_id + (numSortThreads * i)] = my_keys[i]; - smem_vals[thread_id + (numSortThreads * i)] = my_vals[i]; - } - } - __syncthreads(); - if (thread_id < numSortThreads) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = smem_keys[(thread_id ^ curr_mask) + (numSortThreads * i)]; - ValT opp_val = smem_vals[(thread_id ^ curr_mask) + (numSortThreads * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); - } - } - } else { - // intra warp - if (thread_id < numSortThreads) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); - ValT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); - } - } - } - } - - if (thread_id < numSortThreads) { - const bool ascending = ((thread_id & next_mask) == 0); - if (numTopkPerThread == 3) { - swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); - swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); - swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); - } else { -#pragma unroll - for (uint32_t curr_mask = numTopkPerThread / 2; curr_mask > 0; curr_mask >>= 1) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - 
const int j = i ^ curr_mask; - if (i > j) continue; - swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); - } - } - } - } - mask = next_mask; - } - - // Write sorted keys and values - if (thread_id < numSortThreads) { - for (int i = 0; i < numTopkPerThread; i++) { - const int k = i + (numTopkPerThread * thread_id); - if (k < topk) { - if (y) { y[k] = reinterpret_cast(my_keys)[i]; } - if (out_vals) { out_vals[k] = my_vals[i]; } - } - } - } -} - -namespace { - -// -constexpr std::uint32_t NUM_THREADS = 1024; // DO NOT CHANGE -constexpr std::uint32_t STATE_BIT_LENGTH = 8; // 0: state not used, 8: state used -constexpr std::uint32_t MAX_VEC_LENGTH = 4; // 1, 2, 4 or 8 - -// -// -int _get_vecLen(uint32_t maxSamples, int maxVecLen = MAX_VEC_LENGTH) -{ - int vecLen = min(maxVecLen, (int)MAX_VEC_LENGTH); - while ((maxSamples % vecLen) != 0) { - vecLen /= 2; - } - return vecLen; -} -} // unnamed namespace - -template -__launch_bounds__(1024, 1) RAFT_KERNEL - kern_topk_cta_11(uint32_t topk, - uint32_t size_batch, - uint32_t len_x, - const uint32_t* _x, // [size_batch, ld_x,] - uint32_t ld_x, - const ValT* _in_vals, // [size_batch, ld_iv,] - uint32_t ld_iv, - uint32_t* _y, // [size_batch, ld_y,] - uint32_t ld_y, - ValT* _out_vals, // [size_batch, ld_ov,] - uint32_t ld_ov, - uint8_t* _state, // [size_batch, ...,] - uint32_t* _hints, // [size_batch,] - bool sort) -{ - const uint32_t i_batch = blockIdx.x; - if (i_batch >= size_batch) return; - - constexpr uint32_t smem_len = 2 * maxTopk + 2048 + 8; - static_assert(maxTopk * (1 + utils::size_of() / utils::size_of()) <= smem_len, - "maxTopk * sizeof(ValT) must be smaller or equal to 8192 byte"); - __shared__ uint32_t _smem[smem_len]; - - topk_cta_11_core( - topk, - len_x, - (_x == NULL ? NULL : _x + i_batch * ld_x), - (_in_vals == NULL ? NULL : _in_vals + i_batch * ld_iv), - (_y == NULL ? NULL : _y + i_batch * ld_y), - (_out_vals == NULL ? NULL : _out_vals + i_batch * ld_ov), - (_state == NULL ? 
NULL : _state + i_batch * get_state_size(len_x)), - (_hints == NULL ? NULL : _hints + i_batch), - sort, - _smem); -} - -// -size_t inline _cuann_find_topk_bufferSize(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - cudaDataType_t sampleDtype) -{ - constexpr int numThreads = NUM_THREADS; - constexpr int stateBitLen = STATE_BIT_LENGTH; - assert(stateBitLen == 0 || stateBitLen == 8); - - size_t workspaceSize = 1; - // state - if (stateBitLen == 8) { - workspaceSize = _cuann_aligned( - sizeof(uint8_t) * get_state_size(numElements) * sizeBatch); - } - - return workspaceSize; -} - -template -inline void _cuann_find_topk(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const ValT* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - ValT* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK - void* workspace, - bool sort, - uint32_t* hints, - cudaStream_t stream) -{ - assert(ldIK >= numElements); - assert(ldIV >= numElements); - assert(ldOK >= topK); - assert(ldOV >= topK); - - constexpr int numThreads = NUM_THREADS; - constexpr int stateBitLen = STATE_BIT_LENGTH; - assert(stateBitLen == 0 || stateBitLen == 8); - - uint8_t* state = NULL; - if (stateBitLen == 8) { state = (uint8_t*)workspace; } - - dim3 threads(numThreads, 1, 1); - dim3 blocks(sizeBatch, 1, 1); - - void (*cta_kernel)(uint32_t, - uint32_t, - uint32_t, - const uint32_t*, - uint32_t, - const ValT*, - uint32_t, - uint32_t*, - uint32_t, - ValT*, - uint32_t, - uint8_t*, - uint32_t*, - bool) = nullptr; - - // V:vecLen, K:maxTopk, T:numSortThreads -#define SET_KERNEL_VKT(V, K, T, ValT) \ - do { \ - assert(numThreads >= T); \ - assert((K % T) == 0); \ - assert((K / T) <= 4); \ - cta_kernel = kern_topk_cta_11; \ - } while (0) - - // V: vecLen -#define SET_KERNEL_V(V, ValT) \ - 
do { \ - if (topK <= 32) { \ - SET_KERNEL_VKT(V, 32, 32, ValT); \ - } else if (topK <= 64) { \ - SET_KERNEL_VKT(V, 64, 32, ValT); \ - } else if (topK <= 96) { \ - SET_KERNEL_VKT(V, 96, 32, ValT); \ - } else if (topK <= 128) { \ - SET_KERNEL_VKT(V, 128, 32, ValT); \ - } else if (topK <= 192) { \ - SET_KERNEL_VKT(V, 192, 64, ValT); \ - } else if (topK <= 256) { \ - SET_KERNEL_VKT(V, 256, 64, ValT); \ - } else if (topK <= 384) { \ - SET_KERNEL_VKT(V, 384, 128, ValT); \ - } else if (topK <= 512) { \ - SET_KERNEL_VKT(V, 512, 128, ValT); \ - } else if (topK <= 768) { \ - SET_KERNEL_VKT(V, 768, 256, ValT); \ - } else if (topK <= 1024) { \ - SET_KERNEL_VKT(V, 1024, 256, ValT); \ - } \ - /* else if (topK <= 1536) { SET_KERNEL_VKT(V, 1536, 512); } */ \ - /* else if (topK <= 2048) { SET_KERNEL_VKT(V, 2048, 512); } */ \ - /* else if (topK <= 3072) { SET_KERNEL_VKT(V, 3072, 1024); } */ \ - /* else if (topK <= 4096) { SET_KERNEL_VKT(V, 4096, 1024); } */ \ - else { \ - RAFT_FAIL("topk must be lower than or equal to 1024"); \ - } \ - } while (0) - - int _vecLen = _get_vecLen(ldIK, 2); - if (_vecLen == 2) { - SET_KERNEL_V(2, ValT); - } else if (_vecLen == 1) { - SET_KERNEL_V(1, ValT); - } - - cta_kernel<<>>(topK, - sizeBatch, - numElements, - (const uint32_t*)inputKeys, - ldIK, - inputVals, - ldIV, - (uint32_t*)outputKeys, - ldOK, - outputVals, - ldOV, - state, - hints, - sort); - - return; -} -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/cagra/utils.hpp b/cpp/include/cuvs/neighbors/detail/cagra/utils.hpp deleted file mode 100644 index e1cbcc878..000000000 --- a/cpp/include/cuvs/neighbors/detail/cagra/utils.hpp +++ /dev/null @@ -1,289 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::cagra::detail { -namespace utils { -template -inline cudaDataType_t get_cuda_data_type(); -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_32F; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_16F; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_8I; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_8U; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_32U; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_64U; -} - -template -constexpr unsigned size_of(); -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 1; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 1; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 2; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 4; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 8; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 16; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 32; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 4; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 2; -} - -// max values for data types -template -union fp_conv { - BS_T bs; - 
FP_T fp; -}; -template -_RAFT_HOST_DEVICE inline T get_max_value(); -template <> -_RAFT_HOST_DEVICE inline float get_max_value() -{ - return FLT_MAX; -}; -template <> -_RAFT_HOST_DEVICE inline half get_max_value() -{ - return fp_conv{.bs = 0x7aff}.fp; -}; -template <> -_RAFT_HOST_DEVICE inline std::uint32_t get_max_value() -{ - return 0xffffffffu; -}; -template <> -_RAFT_HOST_DEVICE inline std::uint64_t get_max_value() -{ - return 0xfffffffffffffffflu; -}; - -template -struct constexpr_max { - static const int value = A; -}; - -template -struct constexpr_max A), bool>> { - static const int value = B; -}; - -template -struct gen_index_msb_1_mask { - static constexpr IdxT value = static_cast(1) << (utils::size_of() * 8 - 1); -}; -} // namespace utils - -/** - * Utility to sync memory from a host_matrix_view to a raft::device_matrix_view - * - * In certain situations (UVM/HMM/ATS) host memory might be directly accessible on the - * device, and no extra allocations need to be performed. This class checks - * if the host_matrix_view is already accessible on the device, and only creates device - * memory and copies over if necessary. 
In memory limited situations this is preferable - * to having both a host and device copy - * TODO: once the mdbuffer changes here https://github.com/wphicks/raft/blob/fea-mdbuffer - * have been merged, we should remove this class and switch over to using mdbuffer for this - */ -template -class device_matrix_view_from_host { - public: - device_matrix_view_from_host(raft::resources const& res, - raft::host_matrix_view host_view) - : host_view_(host_view) - { - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); - device_ptr = reinterpret_cast(attr.devicePointer); - if (device_ptr == NULL) { - // allocate memory and copy over - device_mem_.emplace( - raft::make_device_matrix(res, host_view.extent(0), host_view.extent(1))); - raft::copy(device_mem_->data_handle(), - host_view.data_handle(), - host_view.extent(0) * host_view.extent(1), - raft::resource::get_cuda_stream(res)); - device_ptr = device_mem_->data_handle(); - } - } - - raft::device_matrix_view view() - { - return raft::make_device_matrix_view( - device_ptr, host_view_.extent(0), host_view_.extent(1)); - } - - T* data_handle() { return device_ptr; } - - bool allocated_memory() const { return device_mem_.has_value(); } - - private: - std::optional> device_mem_; - raft::host_matrix_view host_view_; - T* device_ptr; -}; - -/** - * Utility to sync memory from a raft::device_matrix_view to a host_matrix_view - * - * In certain situations (UVM/HMM/ATS) device memory might be directly accessible on the - * host, and no extra allocations need to be performed. This class checks - * if the raft::device_matrix_view is already accessible on the host, and only creates host - * memory and copies over if necessary. 
In memory limited situations this is preferable - * to having both a host and device copy - * TODO: once the mdbuffer changes here https://github.com/wphicks/raft/blob/fea-mdbuffer - * have been merged, we should remove this class and switch over to using mdbuffer for this - */ -template -class host_matrix_view_from_device { - public: - host_matrix_view_from_device(raft::resources const& res, - raft::device_matrix_view device_view) - : device_view_(device_view) - { - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, device_view.data_handle())); - host_ptr = reinterpret_cast(attr.hostPointer); - if (host_ptr == NULL) { - // allocate memory and copy over - host_mem_.emplace( - raft::make_host_matrix(device_view.extent(0), device_view.extent(1))); - raft::copy(host_mem_->data_handle(), - device_view.data_handle(), - device_view.extent(0) * device_view.extent(1), - raft::resource::get_cuda_stream(res)); - host_ptr = host_mem_->data_handle(); - } - } - - raft::host_matrix_view view() - { - return raft::make_host_matrix_view( - host_ptr, device_view_.extent(0), device_view_.extent(1)); - } - - T* data_handle() { return host_ptr; } - - bool allocated_memory() const { return host_mem_.has_value(); } - - private: - std::optional> host_mem_; - raft::device_matrix_view device_view_; - T* host_ptr; -}; - -// Copy matrix src to dst. pad rows with 0 if necessary to make them 16 byte aligned. 
-template -void copy_with_padding( - raft::resources const& res, - raft::device_matrix& dst, - raft::mdspan, raft::row_major, data_accessor> src, - rmm::mr::device_memory_resource* mr = nullptr) -{ - if (!mr) { mr = rmm::mr::get_current_device_resource(); } - size_t padded_dim = raft::round_up_safe(src.extent(1) * sizeof(T), 16) / sizeof(T); - - if ((dst.extent(0) != src.extent(0)) || (static_cast(dst.extent(1)) != padded_dim)) { - // clear existing memory before allocating to prevent OOM errors on large datasets - if (dst.size()) { dst = raft::make_device_matrix(res, 0, 0); } - dst = - raft::make_device_mdarray(res, mr, raft::make_extents(src.extent(0), padded_dim)); - } - if (dst.extent(1) == src.extent(1)) { - raft::copy( - dst.data_handle(), src.data_handle(), src.size(), raft::resource::get_cuda_stream(res)); - } else { - // copy with padding - RAFT_CUDA_TRY(cudaMemsetAsync( - dst.data_handle(), 0, dst.size() * sizeof(T), raft::resource::get_cuda_stream(res))); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), - sizeof(T) * dst.extent(1), - src.data_handle(), - sizeof(T) * src.extent(1), - sizeof(T) * src.extent(1), - src.extent(0), - cudaMemcpyDefault, - raft::resource::get_cuda_stream(res))); - } -} -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/include/cuvs/neighbors/detail/div_utils.hpp b/cpp/include/cuvs/neighbors/detail/div_utils.hpp deleted file mode 100644 index 805bb1304..000000000 --- a/cpp/include/cuvs/neighbors/detail/div_utils.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef _RAFT_HAS_CUDA -#include -#else -#include -#endif - -/** - * @brief A simple wrapper for raft::Pow2 which uses raft::Pow2 utils only when available and - * regular integer division otherwise. This is done to allow a common interface for division - * arithmetic for non CUDA headers. - * - * @tparam Value_ a compile-time value representable as a power-of-two. - */ -namespace cuvs::neighbors::detail { -template -struct div_utils { - typedef decltype(Value_) Type; - static constexpr Type Value = Value_; - - template - static constexpr _RAFT_HOST_DEVICE inline auto roundDown(T x) - { -#if defined(_RAFT_HAS_CUDA) - return raft::Pow2::roundDown(x); -#else - return raft::round_down_safe(x, Value_); -#endif - } - - template - static constexpr _RAFT_HOST_DEVICE inline auto mod(T x) - { -#if defined(_RAFT_HAS_CUDA) - return raft::Pow2::mod(x); -#else - return x % Value_; -#endif - } - - template - static constexpr _RAFT_HOST_DEVICE inline auto div(T x) - { -#if defined(_RAFT_HAS_CUDA) - return raft::Pow2::div(x); -#else - return x / Value_; -#endif - } -}; -} // namespace cuvs::neighbors::detail \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/Comparators.cuh b/cpp/include/cuvs/neighbors/detail/faiss_select/Comparators.cuh deleted file mode 100644 index 9ced61e13..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/Comparators.cuh +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. 
- * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -#include -#include - -namespace cuvs::neighbors::detail::faiss_select { - -template -struct Comparator { - __device__ static inline bool lt(T a, T b) { return a < b; } - - __device__ static inline bool gt(T a, T b) { return a > b; } -}; - -template <> -struct Comparator { - __device__ static inline bool lt(half a, half b) { return __hlt(a, b); } - - __device__ static inline bool gt(half a, half b) { return __hgt(a, b); } -}; - -} // namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/DistanceUtils.h b/cpp/include/cuvs/neighbors/detail/faiss_select/DistanceUtils.h deleted file mode 100644 index e8a41c1aa..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/DistanceUtils.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -namespace cuvs::neighbors::detail::faiss_select { -// If the inner size (dim) of the vectors is small, we want a larger query tile -// size, like 1024 -inline void chooseTileSize(size_t numQueries, - size_t numCentroids, - size_t dim, - size_t elementSize, - size_t totalMem, - size_t& tileRows, - size_t& tileCols) -{ - // The matrix multiplication should be large enough to be efficient, but if - // it is too large, we seem to lose efficiency as opposed to - // double-streaming. Each tile size here defines 1/2 of the memory use due - // to double streaming. We ignore available temporary memory, as that is - // adjusted independently by the user and can thus meet these requirements - // (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs, - // prefer 768 MB of usage. Otherwise, prefer 1 GB of usage. 
- size_t targetUsage = 0; - - if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) { - targetUsage = 512 * 1024 * 1024; - } else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) { - targetUsage = 768 * 1024 * 1024; - } else { - targetUsage = 1024 * 1024 * 1024; - } - - targetUsage /= 2 * elementSize; - - // 512 seems to be a batch size sweetspot for float32. - // If we are on float16, increase to 512. - // If the k size (vec dim) of the matrix multiplication is small (<= 32), - // increase to 1024. - size_t preferredTileRows = 512; - if (dim <= 32) { preferredTileRows = 1024; } - - tileRows = std::min(preferredTileRows, numQueries); - - // tileCols is the remainder size - tileCols = std::min(targetUsage / preferredTileRows, numCentroids); -} -} // namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkBlock.cuh b/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkBlock.cuh deleted file mode 100644 index 14a56cfe1..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkBlock.cuh +++ /dev/null @@ -1,276 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. 
- * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -#include -#include -#include - -namespace cuvs::neighbors::detail::faiss_select { - -// Merge pairs of lists smaller than blockDim.x (NumThreads) -template -inline __device__ void blockMergeSmall(K* listK, V* listV) -{ - static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); - static_assert(utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2"); - static_assert(L <= NumThreads, "merge list size must be <= NumThreads"); - - // Which pair of lists we are merging - int mergeId = threadIdx.x / L; - - // Which thread we are within the merge - int tid = threadIdx.x % L; - - // listK points to a region of size N * 2 * L - listK += 2 * L * mergeId; - listV += 2 * L * mergeId; - - // It's not a bitonic merge, both lists are in the same direction, - // so handle the first swap assuming the second list is reversed - int pos = L - 1 - tid; - int stride = 2 * tid + 1; - - if (AllThreads || (threadIdx.x < N * L)) { - K ka = listK[pos]; - K kb = listK[pos + stride]; - - bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - listK[pos] = swap ? kb : ka; - listK[pos + stride] = swap ? ka : kb; - - V va = listV[pos]; - V vb = listV[pos + stride]; - listV[pos] = swap ? vb : va; - listV[pos + stride] = swap ? va : vb; - - // FIXME: is this a CUDA 9 compiler bug? - // K& ka = listK[pos]; - // K& kb = listK[pos + stride]; - - // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - // swap(s, ka, kb); - - // V& va = listV[pos]; - // V& vb = listV[pos + stride]; - // swap(s, va, vb); - } - - __syncthreads(); - -#pragma unroll - for (int stride = L / 2; stride > 0; stride /= 2) { - int pos = 2 * tid - (tid & (stride - 1)); - - if (AllThreads || (threadIdx.x < N * L)) { - K ka = listK[pos]; - K kb = listK[pos + stride]; - - bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - listK[pos] = swap ? 
kb : ka; - listK[pos + stride] = swap ? ka : kb; - - V va = listV[pos]; - V vb = listV[pos + stride]; - listV[pos] = swap ? vb : va; - listV[pos + stride] = swap ? va : vb; - - // FIXME: is this a CUDA 9 compiler bug? - // K& ka = listK[pos]; - // K& kb = listK[pos + stride]; - - // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - // swap(s, ka, kb); - - // V& va = listV[pos]; - // V& vb = listV[pos + stride]; - // swap(s, va, vb); - } - - __syncthreads(); - } -} - -// Merge pairs of sorted lists larger than blockDim.x (NumThreads) -template -inline __device__ void blockMergeLarge(K* listK, V* listV) -{ - static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); - static_assert(L >= raft::WarpSize, "merge list size must be >= 32"); - static_assert(utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2"); - static_assert(L >= NumThreads, "merge list size must be >= NumThreads"); - - // For L > NumThreads, each thread has to perform more work - // per each stride. - constexpr int kLoopPerThread = L / NumThreads; - - // It's not a bitonic merge, both lists are in the same direction, - // so handle the first swap assuming the second list is reversed -#pragma unroll - for (int loop = 0; loop < kLoopPerThread; ++loop) { - int tid = loop * NumThreads + threadIdx.x; - int pos = L - 1 - tid; - int stride = 2 * tid + 1; - - K ka = listK[pos]; - K kb = listK[pos + stride]; - - bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - listK[pos] = swap ? kb : ka; - listK[pos + stride] = swap ? ka : kb; - - V va = listV[pos]; - V vb = listV[pos + stride]; - listV[pos] = swap ? vb : va; - listV[pos + stride] = swap ? va : vb; - - // FIXME: is this a CUDA 9 compiler bug? - // K& ka = listK[pos]; - // K& kb = listK[pos + stride]; - - // bool s = Dir ? 
Comp::gt(ka, kb) : Comp::lt(ka, kb); - // swap(s, ka, kb); - - // V& va = listV[pos]; - // V& vb = listV[pos + stride]; - // swap(s, va, vb); - } - - __syncthreads(); - - constexpr int kSecondLoopPerThread = FullMerge ? kLoopPerThread : kLoopPerThread / 2; - -#pragma unroll - for (int stride = L / 2; stride > 0; stride /= 2) { -#pragma unroll - for (int loop = 0; loop < kSecondLoopPerThread; ++loop) { - int tid = loop * NumThreads + threadIdx.x; - int pos = 2 * tid - (tid & (stride - 1)); - - K ka = listK[pos]; - K kb = listK[pos + stride]; - - bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - listK[pos] = swap ? kb : ka; - listK[pos + stride] = swap ? ka : kb; - - V va = listV[pos]; - V vb = listV[pos + stride]; - listV[pos] = swap ? vb : va; - listV[pos + stride] = swap ? va : vb; - - // FIXME: is this a CUDA 9 compiler bug? - // K& ka = listK[pos]; - // K& kb = listK[pos + stride]; - - // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - // swap(s, ka, kb); - - // V& va = listV[pos]; - // V& vb = listV[pos + stride]; - // swap(s, va, vb); - } - - __syncthreads(); - } -} - -/// Class template to prevent static_assert from firing for -/// mixing smaller/larger than block cases -template -struct BlockMerge {}; - -/// Merging lists smaller than a block -template -struct BlockMerge { - static inline __device__ void merge(K* listK, V* listV) - { - constexpr int kNumParallelMerges = NumThreads / L; - constexpr int kNumIterations = N / kNumParallelMerges; - - static_assert(L <= NumThreads, "list must be <= NumThreads"); - static_assert((N < kNumParallelMerges) || (kNumIterations * kNumParallelMerges == N), - "improper selection of N and L"); - - if (N < kNumParallelMerges) { - // We only need L threads per each list to perform the merge - blockMergeSmall(listK, listV); - } else { - // All threads participate -#pragma unroll - for (int i = 0; i < kNumIterations; ++i) { - int start = i * kNumParallelMerges * 2 * L; - - blockMergeSmall(listK + start, - listV + 
start); - } - } - } -}; - -/// Merging lists larger than a block -template -struct BlockMerge { - static inline __device__ void merge(K* listK, V* listV) - { - // Each pair of lists is merged sequentially -#pragma unroll - for (int i = 0; i < N; ++i) { - int start = i * 2 * L; - - blockMergeLarge(listK + start, listV + start); - } - } -}; - -template -inline __device__ void blockMerge(K* listK, V* listV) -{ - constexpr bool kSmallerThanBlock = (L <= NumThreads); - - BlockMerge::merge(listK, listV); -} - -} // namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkUtils.cuh b/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkUtils.cuh deleted file mode 100644 index 7f7796fad..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkUtils.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -namespace cuvs::neighbors::detail::faiss_select { - -template -inline __device__ void swap(bool swap, T& x, T& y) -{ - T tmp = x; - x = swap ? y : x; - y = swap ? tmp : y; -} - -template -inline __device__ void assign(bool assign, T& x, T y) -{ - x = assign ? y : x; -} -} // namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkWarp.cuh b/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkWarp.cuh deleted file mode 100644 index cf97d99ca..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/MergeNetworkWarp.cuh +++ /dev/null @@ -1,520 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. 
- * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -#include -#include - -#include - -namespace cuvs::neighbors::detail::faiss_select { - -// -// This file contains functions to: -// -// -perform bitonic merges on pairs of sorted lists, held in -// registers. Each list contains N *raft::WarpSize (multiple of 32) -// elements for some N. -// The bitonic merge is implemented for arbitrary sizes; -// sorted list A of size N1 *raft::WarpSize registers -// sorted list B of size N2 *raft::WarpSize registers => -// sorted list C if size (N1 + N2) *raft::WarpSize registers. N1 and N2 -// are >= 1 and don't have to be powers of 2. -// -// -perform bitonic sorts on a set of N *raft::WarpSize key/value pairs -// held in registers, by using the above bitonic merge as a -// primitive. -// N can be an arbitrary N >= 1; i.e., the bitonic sort here supports -// odd sizes and doesn't require the input to be a power of 2. -// -// The sort or merge network is completely statically instantiated via -// template specialization / expansion and constexpr, and it uses warp -// shuffles to exchange values between warp lanes. -// -// A note about comparisons: -// -// For a sorting network of keys only, we only need one -// comparison (a < b). However, what we really need to know is -// if one lane chooses to exchange a value, then the -// corresponding lane should also do the exchange. -// Thus, if one just uses the negation !(x < y) in the higher -// lane, this will also include the case where (x == y). Thus, one -// lane in fact performs an exchange and the other doesn't, but -// because the only value being exchanged is equivalent, nothing has -// changed. -// So, you can get away with just one comparison and its negation. -// -// If we're sorting keys and values, where equivalent keys can -// exist, then this is a problem, since we want to treat (x, v1) -// as not equivalent to (x, v2). 
-// -// To remedy this, you can either compare with a lexicographic -// ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which since -// we're predicating all of the choices results in 3 comparisons -// being executed, or we can invert the selection so that there is no -// middle choice of equality; the other lane will likewise -// check that (b.k > a.k) (the higher lane has the values -// swapped). Then, the first lane swaps if and only if the -// second lane swaps; if both lanes have equivalent keys, no -// swap will be performed. This results in only two comparisons -// being executed. -// -// If you don't consider values as well, then this does not produce a -// consistent ordering among (k, v) pairs with equivalent keys but -// different values; for us, we don't really care about ordering or -// stability here. -// -// I have tried both re-arranging the order in the higher lane to get -// away with one comparison or adding the value to the check; both -// result in greater register consumption or lower speed than just -// performing both < and > comparisons with the variables, so I just -// stick with this. - -// This function mergesraft::WarpSize / 2L lists in parallel using warp -// shuffles. -// It works on at most size-16 lists, as we need 32 threads for this -// shuffle merge. -// -// If IsBitonic is false, the first stage is reversed, so we don't -// need to sort directionally. It's still technically a bitonic sort. -template -inline __device__ void warpBitonicMergeLE16(K& k, V& v) -{ - static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); - static_assert(L <= raft::WarpSize / 2, "merge list size must be <= 16"); - - int laneId = raft::laneId(); - - if (!IsBitonic) { - // Reverse the first comparison stage. - // For example, merging a list of size 8 has the exchanges: - // 0 <-> 15, 1 <-> 14, ... 
- K otherK = raft::shfl_xor(k, 2 * L - 1); - V otherV = raft::shfl_xor(v, 2 * L - 1); - - // Whether we are the lesser thread in the exchange - bool small = !(laneId & L); - - if (Dir) { - // See the comment above how performing both of these - // comparisons in the warp seems to win out over the - // alternatives in practice - bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); - assign(s, k, otherK); - assign(s, v, otherV); - - } else { - bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); - assign(s, k, otherK); - assign(s, v, otherV); - } - } - -#pragma unroll - for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) { - K otherK = raft::shfl_xor(k, stride); - V otherV = raft::shfl_xor(v, stride); - - // Whether we are the lesser thread in the exchange - bool small = !(laneId & stride); - - if (Dir) { - bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); - assign(s, k, otherK); - assign(s, v, otherV); - - } else { - bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); - assign(s, k, otherK); - assign(s, v, otherV); - } - } -} - -// Template for performing a bitonic merge of an arbitrary set of -// registers -template -struct BitonicMergeStep {}; - -// -// Power-of-2 merge specialization -// - -// All merges eventually call this -template -struct BitonicMergeStep { - static inline __device__ void merge(K k[1], V v[1]) - { - // Use warp shuffles - warpBitonicMergeLE16(k[0], v[0]); - } -}; - -template -struct BitonicMergeStep { - static inline __device__ void merge(K k[N], V v[N]) - { - static_assert(utils::isPowerOf2(N), "must be power of 2"); - static_assert(N > 1, "must be N > 1"); - -#pragma unroll - for (int i = 0; i < N / 2; ++i) { - K& ka = k[i]; - V& va = v[i]; - - K& kb = k[i + N / 2]; - V& vb = v[i + N / 2]; - - bool s = Dir ? 
Comp::gt(ka, kb) : Comp::lt(ka, kb); - swap(s, ka, kb); - swap(s, va, vb); - } - - { - K newK[N / 2]; - V newV[N / 2]; - -#pragma unroll - for (int i = 0; i < N / 2; ++i) { - newK[i] = k[i]; - newV[i] = v[i]; - } - - BitonicMergeStep::merge(newK, newV); - -#pragma unroll - for (int i = 0; i < N / 2; ++i) { - k[i] = newK[i]; - v[i] = newV[i]; - } - } - - { - K newK[N / 2]; - V newV[N / 2]; - -#pragma unroll - for (int i = 0; i < N / 2; ++i) { - newK[i] = k[i + N / 2]; - newV[i] = v[i + N / 2]; - } - - BitonicMergeStep::merge(newK, newV); - -#pragma unroll - for (int i = 0; i < N / 2; ++i) { - k[i + N / 2] = newK[i]; - v[i + N / 2] = newV[i]; - } - } - } -}; - -// -// Non-power-of-2 merge specialization -// - -// Low recursion -template -struct BitonicMergeStep { - static inline __device__ void merge(K k[N], V v[N]) - { - static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); - static_assert(N >= 3, "must be N >= 3"); - - constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N); - -#pragma unroll - for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { - K& ka = k[i]; - V& va = v[i]; - - K& kb = k[i + kNextHighestPowerOf2 / 2]; - V& vb = v[i + kNextHighestPowerOf2 / 2]; - - bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - swap(s, ka, kb); - swap(s, va, vb); - } - - constexpr int kLowSize = N - kNextHighestPowerOf2 / 2; - constexpr int kHighSize = kNextHighestPowerOf2 / 2; - { - K newK[kLowSize]; - V newV[kLowSize]; - -#pragma unroll - for (int i = 0; i < kLowSize; ++i) { - newK[i] = k[i]; - newV[i] = v[i]; - } - - constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); - // FIXME: compiler doesn't like this expression? compiler bug? 
- // constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); - BitonicMergeStep::merge(newK, newV); - -#pragma unroll - for (int i = 0; i < kLowSize; ++i) { - k[i] = newK[i]; - v[i] = newV[i]; - } - } - - { - K newK[kHighSize]; - V newV[kHighSize]; - -#pragma unroll - for (int i = 0; i < kHighSize; ++i) { - newK[i] = k[i + kLowSize]; - newV[i] = v[i + kLowSize]; - } - - constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kNextHighestPowerOf2 / 2); - // FIXME: compiler doesn't like this expression? compiler bug? - // constexpr bool kHighIsPowerOf2 = - // utils::isPowerOf2(kHighSize); - BitonicMergeStep::merge(newK, newV); - -#pragma unroll - for (int i = 0; i < kHighSize; ++i) { - k[i + kLowSize] = newK[i]; - v[i + kLowSize] = newV[i]; - } - } - } -}; - -// High recursion -template -struct BitonicMergeStep { - static inline __device__ void merge(K k[N], V v[N]) - { - static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); - static_assert(N >= 3, "must be N >= 3"); - - constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N); - -#pragma unroll - for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { - K& ka = k[i]; - V& va = v[i]; - - K& kb = k[i + kNextHighestPowerOf2 / 2]; - V& vb = v[i + kNextHighestPowerOf2 / 2]; - - bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); - swap(s, ka, kb); - swap(s, va, vb); - } - - constexpr int kLowSize = kNextHighestPowerOf2 / 2; - constexpr int kHighSize = N - kNextHighestPowerOf2 / 2; - { - K newK[kLowSize]; - V newV[kLowSize]; - -#pragma unroll - for (int i = 0; i < kLowSize; ++i) { - newK[i] = k[i]; - newV[i] = v[i]; - } - - constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kNextHighestPowerOf2 / 2); - // FIXME: compiler doesn't like this expression? compiler bug? 
- // constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); - BitonicMergeStep::merge(newK, newV); - -#pragma unroll - for (int i = 0; i < kLowSize; ++i) { - k[i] = newK[i]; - v[i] = newV[i]; - } - } - - { - K newK[kHighSize]; - V newV[kHighSize]; - -#pragma unroll - for (int i = 0; i < kHighSize; ++i) { - newK[i] = k[i + kLowSize]; - newV[i] = v[i + kLowSize]; - } - - constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); - // FIXME: compiler doesn't like this expression? compiler bug? - // constexpr bool kHighIsPowerOf2 = - // utils::isPowerOf2(kHighSize); - BitonicMergeStep::merge(newK, newV); - -#pragma unroll - for (int i = 0; i < kHighSize; ++i) { - k[i + kLowSize] = newK[i]; - v[i + kLowSize] = newV[i]; - } - } - } -}; - -/// Merges two sets of registers across the warp of any size; -/// i.e., merges a sorted k/v list of sizeraft::WarpSize * N1 with a -/// sorted k/v list of sizeraft::WarpSize * N2, where N1 and N2 are any -/// value >= 1 -template -inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1], K k2[N2], V v2[N2]) -{ - constexpr int kSmallestN = N1 < N2 ? N1 : N2; - -#pragma unroll - for (int i = 0; i < kSmallestN; ++i) { - K& ka = k1[N1 - 1 - i]; - V& va = v1[N1 - 1 - i]; - - K& kb = k2[i]; - V& vb = v2[i]; - - K otherKa; - V otherVa; - - if (FullMerge) { - // We need the other values - otherKa = raft::shfl_xor(ka, raft::WarpSize - 1); - otherVa = raft::shfl_xor(va, raft::WarpSize - 1); - } - - K otherKb = raft::shfl_xor(kb, raft::WarpSize - 1); - V otherVb = raft::shfl_xor(vb, raft::WarpSize - 1); - - // ka is always first in the list, so we needn't use our lane - // in this comparison - bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb); - assign(swapa, ka, otherKb); - assign(swapa, va, otherVb); - - // kb is always second in the list, so we needn't use our lane - // in this comparison - if (FullMerge) { - bool swapb = Dir ? 
Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa); - assign(swapb, kb, otherKa); - assign(swapb, vb, otherVa); - - } else { - // We don't care about updating elements in the second list - } - } - - BitonicMergeStep::merge(k1, v1); - if (FullMerge) { - // Only if we care about N2 do we need to bother merging it fully - BitonicMergeStep::merge(k2, v2); - } -} - -// Recursive template that uses the above bitonic merge to perform a -// bitonic sort -template -struct BitonicSortStep { - static inline __device__ void sort(K k[N], V v[N]) - { - static_assert(N > 1, "did not hit specialized case"); - - // Sort recursively - constexpr int kSizeA = N / 2; - constexpr int kSizeB = N - kSizeA; - - K aK[kSizeA]; - V aV[kSizeA]; - -#pragma unroll - for (int i = 0; i < kSizeA; ++i) { - aK[i] = k[i]; - aV[i] = v[i]; - } - - BitonicSortStep::sort(aK, aV); - - K bK[kSizeB]; - V bV[kSizeB]; - -#pragma unroll - for (int i = 0; i < kSizeB; ++i) { - bK[i] = k[i + kSizeA]; - bV[i] = v[i + kSizeA]; - } - - BitonicSortStep::sort(bK, bV); - - // Merge halves - warpMergeAnyRegisters(aK, aV, bK, bV); - -#pragma unroll - for (int i = 0; i < kSizeA; ++i) { - k[i] = aK[i]; - v[i] = aV[i]; - } - -#pragma unroll - for (int i = 0; i < kSizeB; ++i) { - k[i + kSizeA] = bK[i]; - v[i + kSizeA] = bV[i]; - } - } -}; - -// Single warp (N == 1) sorting specialization -template -struct BitonicSortStep { - static inline __device__ void sort(K k[1], V v[1]) - { - // Update this code if this changes - // should go from 1 ->raft::WarpSize in multiples of 2 - static_assert(raft::WarpSize == 32, "unexpected warp size"); - - warpBitonicMergeLE16(k[0], v[0]); - warpBitonicMergeLE16(k[0], v[0]); - warpBitonicMergeLE16(k[0], v[0]); - warpBitonicMergeLE16(k[0], v[0]); - warpBitonicMergeLE16(k[0], v[0]); - } -}; - -/// Sort a list ofraft::WarpSize * N elements in registers, where N is an -/// arbitrary >= 1 -template -inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) -{ - BitonicSortStep::sort(k, v); -} - -} // 
namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/Select.cuh b/cpp/include/cuvs/neighbors/detail/faiss_select/Select.cuh deleted file mode 100644 index 796a841a4..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/Select.cuh +++ /dev/null @@ -1,570 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -#include -#include -#include - -#include -#include - -namespace cuvs::neighbors::detail::faiss_select { - -// Specialization for block-wide monotonic merges producing a merge sort -// since what we really want is a constexpr loop expansion -template -struct FinalBlockMerge {}; - -template -struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> { - static inline __device__ void merge(K* sharedK, V* sharedV) - { - // no merge required; single warp - } -}; - -template -struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> { - static inline __device__ void merge(K* sharedK, V* sharedV) - { - // Final merge doesn't need to fully merge the second list - blockMerge( - sharedK, sharedV); - } -}; - -template -struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> { - static inline __device__ void merge(K* sharedK, V* sharedV) - { - blockMerge(sharedK, - sharedV); - // Final merge doesn't need to fully merge the second list - blockMerge(sharedK, sharedV); - } -}; - -template -struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> { - static inline __device__ void merge(K* sharedK, V* sharedV) - { - blockMerge(sharedK, - sharedV); - blockMerge( - sharedK, sharedV); - // Final merge doesn't need to fully merge the second list - blockMerge(sharedK, sharedV); - } -}; - -// `Dir` true, produce largest values. -// `Dir` false, produce smallest values. 
-template -struct BlockSelect { - static constexpr int kNumWarps = ThreadsPerBlock / raft::WarpSize; - static constexpr int kTotalWarpSortSize = NumWarpQ; - - __device__ inline BlockSelect(K initKVal, V initVVal, K* smemK, V* smemV, int k) - : initK(initKVal), - initV(initVVal), - numVals(0), - warpKTop(initKVal), - sharedK(smemK), - sharedV(smemV), - kMinus1(k - 1) - { - static_assert(utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); - static_assert(utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); - - // Fill the per-thread queue keys with the default value -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i] = initV; - } - - int laneId = raft::laneId(); - int warpId = threadIdx.x / raft::WarpSize; - warpK = sharedK + warpId * kTotalWarpSortSize; - warpV = sharedV + warpId * kTotalWarpSortSize; - - // Fill warp queue (only the actual queue space is fine, not where - // we write the per-thread queues for merging) - for (int i = laneId; i < NumWarpQ; i += raft::WarpSize) { - warpK[i] = initK; - warpV[i] = initV; - } - - raft::warpFence(); - } - - __device__ inline void addThreadQ(K k, V v) - { - if (Dir ? 
Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { - // Rotate right -#pragma unroll - for (int i = NumThreadQ - 1; i > 0; --i) { - threadK[i] = threadK[i - 1]; - threadV[i] = threadV[i - 1]; - } - - threadK[0] = k; - threadV[0] = v; - ++numVals; - } - } - - __device__ inline void checkThreadQ() - { - bool needSort = (numVals == NumThreadQ); - -#if CUDA_VERSION >= 9000 - needSort = __any_sync(0xffffffff, needSort); -#else - needSort = __any(needSort); -#endif - - if (!needSort) { - // no lanes have triggered a sort - return; - } - - // This has a trailing raft::warpFence - mergeWarpQ(); - - // Any top-k elements have been merged into the warp queue; we're - // free to reset the thread queues - numVals = 0; - -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i] = initV; - } - - // We have to beat at least this element - warpKTop = warpK[kMinus1]; - - raft::warpFence(); - } - - /// This function handles sorting and merging together the - /// per-thread queues with the warp-wide queue, creating a sorted - /// list across both - __device__ inline void mergeWarpQ() - { - int laneId = raft::laneId(); - - // Sort all of the per-thread queues - warpSortAnyRegisters(threadK, threadV); - - constexpr int kNumWarpQRegisters = NumWarpQ / raft::WarpSize; - K warpKRegisters[kNumWarpQRegisters]; - V warpVRegisters[kNumWarpQRegisters]; - -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpKRegisters[i] = warpK[i * raft::WarpSize + laneId]; - warpVRegisters[i] = warpV[i * raft::WarpSize + laneId]; - } - - raft::warpFence(); - - // The warp queue is already sorted, and now that we've sorted the - // per-thread queue, merge both sorted lists together, producing - // one sorted list - warpMergeAnyRegisters( - warpKRegisters, warpVRegisters, threadK, threadV); - - // Write back out the warp queue -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpK[i * raft::WarpSize + laneId] = warpKRegisters[i]; - warpV[i * 
raft::WarpSize + laneId] = warpVRegisters[i]; - } - - raft::warpFence(); - } - - /// WARNING: all threads in a warp must participate in this. - /// Otherwise, you must call the constituent parts separately. - __device__ inline void add(K k, V v) - { - addThreadQ(k, v); - checkThreadQ(); - } - - __device__ inline void reduce() - { - // Have all warps dump and merge their queues; this will produce - // the final per-warp results - mergeWarpQ(); - - // block-wide dep; thus far, all warps have been completely - // independent - __syncthreads(); - - // All warp queues are contiguous in smem. - // Now, we have kNumWarps lists of NumWarpQ elements. - // This is a power of 2. - FinalBlockMerge::merge(sharedK, sharedV); - - // The block-wide merge has a trailing syncthreads - } - - // Default element key - const K initK; - - // Default element value - const V initV; - - // Number of valid elements in our thread queue - int numVals; - - // The k-th highest (Dir) or lowest (!Dir) element - K warpKTop; - - // Thread queue values - K threadK[NumThreadQ]; - V threadV[NumThreadQ]; - - // Queues for all warps - K* sharedK; - V* sharedV; - - // Our warp's queue (points into sharedK/sharedV) - // warpK[0] is highest (Dir) or lowest (!Dir) - K* warpK; - V* warpV; - - // This is a cached k-1 value - int kMinus1; -}; - -/// Specialization for k == 1 (NumWarpQ == 1) -template -struct BlockSelect { - static constexpr int kNumWarps = ThreadsPerBlock / raft::WarpSize; - - __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) - : threadK(initK), threadV(initV), sharedK(smemK), sharedV(smemV) - { - } - - __device__ inline void addThreadQ(K k, V v) - { - bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); - threadK = swap ? k : threadK; - threadV = swap ? 
v : threadV; - } - - __device__ inline void checkThreadQ() - { - // We don't need to do anything here, since the warp doesn't - // cooperate until the end - } - - __device__ inline void add(K k, V v) { addThreadQ(k, v); } - - __device__ inline void reduce() - { - // Reduce within the warp - raft::KeyValuePair pair(threadK, threadV); - - if (Dir) { - pair = warpReduce(pair, raft::max_op{}); - } else { - pair = warpReduce(pair, raft::min_op{}); - } - - // Each warp writes out a single value - int laneId = raft::laneId(); - int warpId = threadIdx.x / raft::WarpSize; - - if (laneId == 0) { - sharedK[warpId] = pair.key; - sharedV[warpId] = pair.value; - } - - __syncthreads(); - - // We typically use this for small blocks (<= 128), just having the - // first thread in the block perform the reduction across warps is - // faster - if (threadIdx.x == 0) { - threadK = sharedK[0]; - threadV = sharedV[0]; - -#pragma unroll - for (int i = 1; i < kNumWarps; ++i) { - K k = sharedK[i]; - V v = sharedV[i]; - - bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); - threadK = swap ? k : threadK; - threadV = swap ? v : threadV; - } - - // Hopefully a thread's smem reads/writes are ordered wrt - // itself, so no barrier needed :) - sharedK[0] = threadK; - sharedV[0] = threadV; - } - - // In case other threads wish to read this value - __syncthreads(); - } - - // threadK is lowest (Dir) or highest (!Dir) - K threadK; - V threadV; - - // Where we reduce in smem - K* sharedK; - V* sharedV; -}; - -// -// per-warp WarpSelect -// - -// `Dir` true, produce largest values. -// `Dir` false, produce smallest values. 
-template -struct WarpSelect { - static constexpr int kNumWarpQRegisters = NumWarpQ / raft::WarpSize; - - __device__ inline WarpSelect(K initKVal, V initVVal, int k) - : initK(initKVal), - initV(initVVal), - numVals(0), - warpKTop(initKVal), - kLane((k - 1) % raft::WarpSize) - { - static_assert(utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); - static_assert(utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); - - // Fill the per-thread queue keys with the default value -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i] = initV; - } - - // Fill the warp queue with the default value -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpK[i] = initK; - warpV[i] = initV; - } - } - - __device__ inline void addThreadQ(K k, V v) - { - if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { - // Rotate right -#pragma unroll - for (int i = NumThreadQ - 1; i > 0; --i) { - threadK[i] = threadK[i - 1]; - threadV[i] = threadV[i - 1]; - } - - threadK[0] = k; - threadV[0] = v; - ++numVals; - } - } - - __device__ inline void checkThreadQ() - { - bool needSort = (numVals == NumThreadQ); - -#if CUDA_VERSION >= 9000 - needSort = __any_sync(0xffffffff, needSort); -#else - needSort = __any(needSort); -#endif - - if (!needSort) { - // no lanes have triggered a sort - return; - } - - mergeWarpQ(); - - // Any top-k elements have been merged into the warp queue; we're - // free to reset the thread queues - numVals = 0; - -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i] = initV; - } - - // We have to beat at least this element - warpKTop = raft::shfl(warpK[kNumWarpQRegisters - 1], kLane); - } - - /// This function handles sorting and merging together the - /// per-thread queues with the warp-wide queue, creating a sorted - /// list across both - __device__ inline void mergeWarpQ() - { - // Sort all of the per-thread queues - warpSortAnyRegisters(threadK, 
threadV); - - // The warp queue is already sorted, and now that we've sorted the - // per-thread queue, merge both sorted lists together, producing - // one sorted list - warpMergeAnyRegisters( - warpK, warpV, threadK, threadV); - } - - /// WARNING: all threads in a warp must participate in this. - /// Otherwise, you must call the constituent parts separately. - __device__ inline void add(K k, V v) - { - addThreadQ(k, v); - checkThreadQ(); - } - - __device__ inline void reduce() - { - // Have all warps dump and merge their queues; this will produce - // the final per-warp results - mergeWarpQ(); - } - - /// Dump final k selected values for this warp out - __device__ inline void writeOut(K* outK, V* outV, int k) - { - int laneId = raft::laneId(); - -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - int idx = i * raft::WarpSize + laneId; - - if (idx < k) { - outK[idx] = warpK[i]; - outV[idx] = warpV[i]; - } - } - } - - // Default element key - const K initK; - - // Default element value - const V initV; - - // Number of valid elements in our thread queue - int numVals; - - // The k-th highest (Dir) or lowest (!Dir) element - K warpKTop; - - // Thread queue values - K threadK[NumThreadQ]; - V threadV[NumThreadQ]; - - // warpK[0] is highest (Dir) or lowest (!Dir) - K warpK[kNumWarpQRegisters]; - V warpV[kNumWarpQRegisters]; - - // This is what lane we should load an approximation (>=k) to the - // kth element from the last register in the warp queue (i.e., - // warpK[kNumWarpQRegisters - 1]). - int kLane; -}; - -/// Specialization for k == 1 (NumWarpQ == 1) -template -struct WarpSelect { - static constexpr int kNumWarps = ThreadsPerBlock / raft::WarpSize; - - __device__ inline WarpSelect(K initK, V initV, int k) : threadK(initK), threadV(initV) {} - - __device__ inline void addThreadQ(K k, V v) - { - bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); - threadK = swap ? k : threadK; - threadV = swap ? 
v : threadV; - } - - __device__ inline void checkThreadQ() - { - // We don't need to do anything here, since the warp doesn't - // cooperate until the end - } - - __device__ inline void add(K k, V v) { addThreadQ(k, v); } - - __device__ inline void reduce() - { - // Reduce within the warp - raft::KeyValuePair pair(threadK, threadV); - - if (Dir) { - pair = warpReduce(pair, raft::max_op{}); - } else { - pair = warpReduce(pair, raft::min_op{}); - } - - threadK = pair.key; - threadV = pair.value; - } - - /// Dump final k selected values for this warp out - __device__ inline void writeOut(K* outK, V* outV, int k) - { - if (raft::laneId() == 0) { - *outK = threadK; - *outV = threadV; - } - } - - // threadK is lowest (Dir) or highest (!Dir) - K threadK; - V threadV; -}; - -} // namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/StaticUtils.h b/cpp/include/cuvs/neighbors/detail/faiss_select/StaticUtils.h deleted file mode 100644 index 6f53cf7f8..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/StaticUtils.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -#include - -// allow usage for non-CUDA files -#ifndef __host__ -#define __host__ -#define __device__ -#endif - -namespace cuvs::neighbors::detail::faiss_select::utils { - -template -constexpr __host__ __device__ bool isPowerOf2(T v) -{ - return (v && !(v & (v - 1))); -} - -static_assert(isPowerOf2(2048), "isPowerOf2"); -static_assert(!isPowerOf2(3333), "isPowerOf2"); - -template -constexpr __host__ __device__ T nextHighestPowerOf2(T v) -{ - return (isPowerOf2(v) ? 
(T)2 * v : ((T)1 << (log2(v) + 1))); -} - -static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); -static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); -static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); -static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); - -static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); -static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); -static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); - -static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, "nextHighestPowerOf2"); -static_assert(nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, - "nextHighestPowerOf2"); - -} // namespace cuvs::neighbors::detail::faiss_select::utils diff --git a/cpp/include/cuvs/neighbors/detail/faiss_select/key_value_block_select.cuh b/cpp/include/cuvs/neighbors/detail/faiss_select/key_value_block_select.cuh deleted file mode 100644 index 14484435b..000000000 --- a/cpp/include/cuvs/neighbors/detail/faiss_select/key_value_block_select.cuh +++ /dev/null @@ -1,224 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file thirdparty/LICENSES/LICENSE.faiss - */ - -#pragma once - -#include -#include - -// TODO: Need to think further about the impact (and new boundaries created) on the registers -// because this will change the max k that can be processed. One solution might be to break -// up k into multiple batches for larger k. - -namespace cuvs::neighbors::detail::faiss_select { - -// `Dir` true, produce largest values. -// `Dir` false, produce smallest values. 
-template -struct KeyValueBlockSelect { - static constexpr int kNumWarps = ThreadsPerBlock / raft::WarpSize; - static constexpr int kTotalWarpSortSize = NumWarpQ; - - __device__ inline KeyValueBlockSelect( - K initKVal, K initVKey, V initVVal, K* smemK, KeyValuePair* smemV, int k) - : initK(initKVal), - initVk(initVKey), - initVv(initVVal), - numVals(0), - warpKTop(initKVal), - warpKTopRDist(initKVal), - sharedK(smemK), - sharedV(smemV), - kMinus1(k - 1) - { - static_assert(utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); - static_assert(utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); - - // Fill the per-thread queue keys with the default value -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i].key = initVk; - threadV[i].value = initVv; - } - - int laneId = raft::laneId(); - int warpId = threadIdx.x / raft::WarpSize; - warpK = sharedK + warpId * kTotalWarpSortSize; - warpV = sharedV + warpId * kTotalWarpSortSize; - - // Fill warp queue (only the actual queue space is fine, not where - // we write the per-thread queues for merging) - for (int i = laneId; i < NumWarpQ; i += raft::WarpSize) { - warpK[i] = initK; - warpV[i].key = initVk; - warpV[i].value = initVv; - } - - raft::warpFence(); - } - - __device__ inline void addThreadQ(K k, K vk, V vv) - { - if (Dir ? 
Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { - // Rotate right -#pragma unroll - for (int i = NumThreadQ - 1; i > 0; --i) { - threadK[i] = threadK[i - 1]; - threadV[i].key = threadV[i - 1].key; - threadV[i].value = threadV[i - 1].value; - } - - threadK[0] = k; - threadV[0].key = vk; - threadV[0].value = vv; - ++numVals; - } - } - - __device__ inline void checkThreadQ() - { - bool needSort = (numVals == NumThreadQ); - -#if CUDA_VERSION >= 9000 - needSort = __any_sync(0xffffffff, needSort); -#else - needSort = __any(needSort); -#endif - - if (!needSort) { - // no lanes have triggered a sort - return; - } - - // This has a trailing raft::warpFence - mergeWarpQ(); - - // Any top-k elements have been merged into the warp queue; we're - // free to reset the thread queues - numVals = 0; - -#pragma unroll - for (int i = 0; i < NumThreadQ; ++i) { - threadK[i] = initK; - threadV[i].key = initVk; - threadV[i].value = initVv; - } - - // We have to beat at least this element - warpKTop = warpK[kMinus1]; - warpKTopRDist = warpV[kMinus1].key; - - raft::warpFence(); - } - - /// This function handles sorting and merging together the - /// per-thread queues with the warp-wide queue, creating a sorted - /// list across both - __device__ inline void mergeWarpQ() - { - int laneId = raft::laneId(); - - // Sort all of the per-thread queues - warpSortAnyRegisters, NumThreadQ, !Dir, Comp>(threadK, threadV); - - constexpr int kNumWarpQRegisters = NumWarpQ / raft::WarpSize; - K raft::warpKRegisters[kNumWarpQRegisters]; - KeyValuePair warpVRegisters[kNumWarpQRegisters]; - -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - raft::warpKRegisters[i] = warpK[i * raft::WarpSize + laneId]; - warpVRegisters[i].key = warpV[i * raft::WarpSize + laneId].key; - warpVRegisters[i].value = warpV[i * raft::WarpSize + laneId].value; - } - - raft::warpFence(); - - // The warp queue is already sorted, and now that we've sorted the - // per-thread queue, merge both sorted lists together, 
producing - // one sorted list - warpMergeAnyRegisters, kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>( - raft::warpKRegisters, warpVRegisters, threadK, threadV); - - // Write back out the warp queue -#pragma unroll - for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpK[i * raft::WarpSize + laneId] = raft::warpKRegisters[i]; - warpV[i * raft::WarpSize + laneId].key = warpVRegisters[i].key; - warpV[i * raft::WarpSize + laneId].value = warpVRegisters[i].value; - } - - raft::warpFence(); - } - - /// WARNING: all threads in a warp must participate in this. - /// Otherwise, you must call the constituent parts separately. - __device__ inline void add(K k, K vk, V vv) - { - addThreadQ(k, vk, vv); - checkThreadQ(); - } - - __device__ inline void reduce() - { - // Have all warps dump and merge their queues; this will produce - // the final per-warp results - mergeWarpQ(); - - // block-wide dep; thus far, all warps have been completely - // independent - __syncthreads(); - - // All warp queues are contiguous in smem. - // Now, we have kNumWarps lists of NumWarpQ elements. - // This is a power of 2. 
- FinalBlockMerge, NumWarpQ, Dir, Comp>::merge( - sharedK, sharedV); - - // The block-wide merge has a trailing syncthreads - } - - // Default element key - const K initK; - - // Default element value - const K initVk; - const V initVv; - - // Number of valid elements in our thread queue - int numVals; - - // The k-th highest (Dir) or lowest (!Dir) element - K warpKTop; - - K warpKTopRDist; - - // Thread queue values - K threadK[NumThreadQ]; - KeyValuePair threadV[NumThreadQ]; - - // Queues for all warps - K* sharedK; - KeyValuePair* sharedV; - - // Our warp's queue (points into sharedK/sharedV) - // warpK[0] is highest (Dir) or lowest (!Dir) - K* warpK; - KeyValuePair* warpV; - - // This is a cached k-1 value - int kMinus1; -}; - -} // namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_build.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_build.cuh deleted file mode 100644 index 022e5eac5..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_build.cuh +++ /dev/null @@ -1,495 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace cuvs::neighbors::ivf_flat::detail { - -using namespace cuvs::spatial::knn::detail; // NOLINT - -template -auto clone(const raft::resources& res, const index& source) -> index -{ - auto stream = resource::get_cuda_stream(res); - - // Allocate the new index - index target(res, - source.metric(), - source.n_lists(), - source.adaptive_centers(), - source.conservative_memory_allocation(), - source.dim()); - - // Copy the independent parts - copy(target.list_sizes().data_handle(), - source.list_sizes().data_handle(), - source.list_sizes().size(), - stream); - copy(target.centers().data_handle(), - source.centers().data_handle(), - source.centers().size(), - stream); - if (source.center_norms().has_value()) { - target.allocate_center_norms(res); - copy(target.center_norms()->data_handle(), - source.center_norms()->data_handle(), - source.center_norms()->size(), - stream); - } - // Copy shared pointers - target.lists() = source.lists(); - - // Make sure the device pointers point to the new lists - target.recompute_internal_state(res); - - return target; -} - -/** - * @brief Record the dataset into the index, one source row at a time. - * - * The index consists of the dataset rows, grouped by their labels (into clusters/lists). - * Within each cluster (list), the data is grouped into blocks of `WarpSize` interleaved - * vectors. Note, the total index length is slightly larger than the dataset length, because - * each cluster is padded by `WarpSize` elements - * - * CUDA launch grid: - * X dimension must cover the dataset (n_rows), YZ are not used; - * there are no dependencies between threads, hence no constraints on the block size. - * - * @tparam T element type. 
- * @tparam IdxT type of the indices in the source source_vecs - * @tparam LabelT label type - * @tparam gather_src if false, then we build the index from vectors source_vecs[i,:], otherwise - * we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1. - * - * @param[in] labels device pointer to the cluster ids for each row [n_rows] - * @param[in] source_vecs device pointer to the input data [n_rows, dim] - * @param[in] source_ixs device pointer to the input indices [n_rows] - * @param[out] list_data_ptrs device pointer to the index data of size [n_lists][index_size, dim] - * @param[out] list_index_ptrs device pointer to the source ids corr. to the output [n_lists] - * [index_size] - * @param[out] list_sizes_ptr device pointer to the cluster sizes [n_lists]; - * it's used as an atomic counter, and must be initialized with zeros. - * @param n_rows source length - * @param dim the dimensionality of the data - * @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`. - * - */ -template -RAFT_KERNEL build_index_kernel(const LabelT* labels, - const T* source_vecs, - const IdxT* source_ixs, - T** list_data_ptrs, - IdxT** list_index_ptrs, - uint32_t* list_sizes_ptr, - IdxT n_rows, - uint32_t dim, - uint32_t veclen) -{ - const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; - if (i >= n_rows) { return; } - - auto list_id = labels[i]; - auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1); - auto* list_index = list_index_ptrs[list_id]; - auto* list_data = list_data_ptrs[list_id]; - - // Record the source vector id in the index - list_index[inlist_id] = source_ixs == nullptr ? 
i : source_ixs[i]; - - // The data is written in interleaved groups of `index::kGroupSize` vectors - using interleaved_group = raft::Pow2; - auto group_offset = interleaved_group::roundDown(inlist_id); - auto ingroup_id = interleaved_group::mod(inlist_id) * veclen; - - // Point to the location of the interleaved group of vectors - list_data += group_offset * dim; - - // Point to the source vector - if constexpr (gather_src) { - source_vecs += source_ixs[i] * dim; - } else { - source_vecs += i * dim; - } - // Interleave dimensions of the source vector while recording it. - // NB: such `veclen` is selected, that `dim % veclen == 0` - for (uint32_t l = 0; l < dim; l += veclen) { - for (uint32_t j = 0; j < veclen; j++) { - list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j]; - } - } -} - -/** See cuvs::neighbors::ivf_flat::extend docs */ -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - using LabelT = uint32_t; - RAFT_EXPECTS(index != nullptr, "index cannot be empty."); - - auto stream = resource::get_cuda_stream(handle); - auto n_lists = index->n_lists(); - auto dim = index->dim(); - list_spec list_device_spec{index->dim(), - index->conservative_memory_allocation()}; - raft::common::nvtx::range fun_scope( - "ivf_flat::extend(%zu, %u)", size_t(n_rows), dim); - - RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, - "You must pass data indices when the index is non-empty."); - - auto new_labels = raft::make_device_vector(handle, n_rows); - cuvs::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = index->metric(); - auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); - auto orig_centroids_view = - raft::make_device_matrix_view(index->centers().data_handle(), n_lists, dim); - cuvs::cluster::kmeans_balanced::predict(handle, - kmeans_params, - new_vectors_view, - orig_centroids_view, - new_labels.view(), - 
utils::mapping{}); - - auto* list_sizes_ptr = index->list_sizes().data_handle(); - auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); - copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); - - // Calculate the centers and sizes on the new data, starting from the original values - if (index->adaptive_centers()) { - auto centroids_view = raft::make_device_matrix_view( - index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); - auto list_sizes_view = - raft::make_device_vector_view, IdxT>( - list_sizes_ptr, n_lists); - auto const_labels_view = make_const_mdspan(new_labels.view()); - cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, - new_vectors_view, - const_labels_view, - centroids_view, - list_sizes_view, - false, - utils::mapping{}); - } else { - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(list_sizes_ptr), - IdxT(n_lists), - new_labels.data_handle(), - n_rows, - 1, - stream); - raft::linalg::add( - list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); - } - - // Calculate and allocate new list data - std::vector new_list_sizes(n_lists); - std::vector old_list_sizes(n_lists); - { - copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); - copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); - resource::sync_stream(handle); - auto& lists = index->lists(); - for (uint32_t label = 0; label < n_lists; label++) { - ivf::resize_list(handle, - lists[label], - list_device_spec, - new_list_sizes[label], - raft::Pow2::roundUp(old_list_sizes[label])); - } - } - // Update the pointers and the sizes - index->recompute_internal_state(handle); - // Copy the old sizes, so we can start from the current state of the index; - // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. 
- raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); - - // Kernel to insert the new vectors - const dim3 block_dim(256); - const dim3 grid_dim(raft::ceildiv(n_rows, block_dim.x)); - build_index_kernel<<>>(new_labels.data_handle(), - new_vectors, - new_indices, - index->data_ptrs().data_handle(), - index->inds_ptrs().data_handle(), - list_sizes_ptr, - n_rows, - dim, - index->veclen()); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - // Precompute the centers vector norms for L2Expanded distance - if (!index->center_norms().has_value()) { - index->allocate_center_norms(handle); - if (index->center_norms().has_value()) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - raft::linalg::L2Norm, - true, - stream); - RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); - } - } else if (index->center_norms().has_value() && index->adaptive_centers()) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - raft::linalg::L2Norm, - true, - stream); - RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); - } -} - -/** See cuvs::neighbors::ivf_flat::extend docs */ -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - auto ext_index = clone(handle, orig_index); - detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); - return ext_index; -} - -/** See cuvs::neighbors::ivf_flat::build docs */ -template -inline auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - auto stream = resource::get_cuda_stream(handle); - raft::common::nvtx::range fun_scope( - "ivf_flat::build(%zu, %u)", size_t(n_rows), dim); - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "unsupported 
data type"); - RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); - RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); - - index index(handle, params, dim); - utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); - utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); - - // Train the kmeans clustering - { - auto trainset_ratio = std::max( - 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); - auto n_rows_train = n_rows / trainset_ratio; - rmm::device_uvector trainset(n_rows_train * index.dim(), stream); - // TODO: a proper sampling - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); - auto centers_view = raft::make_device_matrix_view( - index.centers().data_handle(), index.n_lists(), index.dim()); - cuvs::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = index.metric(); - cuvs::cluster::kmeans_balanced::fit( - handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); - } - - // add the data if necessary - if (params.add_data_on_build) { - detail::extend(handle, &index, dataset, nullptr, n_rows); - } - return index; -} - -/** - * Build an index that can be used in refinement operation. - * - * See cuvs::neighbors::refine for details on the refinement operation. - * - * The returned index cannot be used for a regular ivf_flat::search. The index misses information - * about coarse clusters. Instead, the neighbor candidates are assumed to form clusters, one for - * each query. 
The candidate vectors are gathered into the index dataset, that can be later used - * in ivfflat_interleaved_scan. - * - * @param[in] handle the raft handle - * @param[inout] refinement_index - * @param[in] dataset device pointer to dataset vectors, size [n_rows, dim]. Note that n_rows is - * not known to this function, but each candidate_idx has to be smaller than n_rows. - * @param[in] candidate_idx device pointer to neighbor candidates, size [n_queries, n_candidates] - * @param[in] n_candidates of neighbor_candidates - */ -template -inline void fill_refinement_index(raft::resources const& handle, - index* refinement_index, - const T* dataset, - const IdxT* candidate_idx, - IdxT n_queries, - uint32_t n_candidates) -{ - using LabelT = uint32_t; - - auto stream = resource::get_cuda_stream(handle); - uint32_t n_lists = n_queries; - raft::common::nvtx::range fun_scope( - "ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries)); - - rmm::device_uvector new_labels(n_queries * n_candidates, stream); - auto new_labels_view = - raft::make_device_vector_view(new_labels.data(), n_queries * n_candidates); - linalg::map_offset( - handle, - new_labels_view, - raft::compose_op(raft::cast_op(), raft::div_const_op(n_candidates))); - - auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); - // We do not fill centers and center norms, since we will not run coarse search. 
- - // Allocate new memory - auto& lists = refinement_index->lists(); - list_spec list_device_spec{refinement_index->dim(), false}; - for (uint32_t label = 0; label < n_lists; label++) { - ivf::resize_list(handle, lists[label], list_device_spec, n_candidates, uint32_t(0)); - } - // Update the pointers and the sizes - refinement_index->recompute_internal_state(handle); - - RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); - - const dim3 block_dim(256); - const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); - build_index_kernel - <<>>(new_labels.data(), - dataset, - candidate_idx, - refinement_index->data_ptrs().data_handle(), - refinement_index->inds_ptrs().data_handle(), - list_sizes_ptr, - n_queries * n_candidates, - refinement_index->dim(), - refinement_index->veclen()); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -RAFT_KERNEL pack_interleaved_list_kernel(const T* codes, - T* list_data, - uint32_t n_rows, - uint32_t dim, - uint32_t veclen, - std::variant offset_or_indices) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const uint32_t dst_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + tid - : std::get(offset_or_indices)[tid]; - if (tid < n_rows) { codepacker::pack_1(codes + tid * dim, list_data, dim, veclen, dst_ix); } -} - -template -RAFT_KERNEL unpack_interleaved_list_kernel( - const T* list_data, - T* codes, - uint32_t n_rows, - uint32_t dim, - uint32_t veclen, - std::variant offset_or_indices) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const uint32_t src_ix = std::holds_alternative(offset_or_indices) - ? 
std::get(offset_or_indices) + tid - : std::get(offset_or_indices)[tid]; - if (tid < n_rows) { codepacker::unpack_1(list_data, codes + tid * dim, dim, veclen, src_ix); } -} - -template -void pack_list_data( - raft::resources const& res, - raft::device_matrix_view codes, - uint32_t veclen, - std::variant offset_or_indices, - raft::device_mdspan::list_extents, raft::row_major> - list_data) -{ - uint32_t n_rows = codes.extent(0); - uint32_t dim = codes.extent(1); - if (n_rows == 0 || dim == 0) return; - static constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto stream = resource::get_cuda_stream(res); - pack_interleaved_list_kernel<<>>( - codes.data_handle(), list_data.data_handle(), n_rows, dim, veclen, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void unpack_list_data( - raft::resources const& res, - raft::device_mdspan::list_extents, raft::row_major> - list_data, - uint32_t veclen, - std::variant offset_or_indices, - raft::device_matrix_view codes) -{ - uint32_t n_rows = codes.extent(0); - uint32_t dim = codes.extent(1); - if (n_rows == 0 || dim == 0) return; - static constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto stream = resource::get_cuda_stream(res); - unpack_interleaved_list_kernel<<>>( - list_data.data_handle(), codes.data_handle(), n_rows, dim, veclen, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh deleted file mode 100644 index cc32ff22a..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // uintX_t -#include // cuvs::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::neighbors::ivf_flat::detail { - -template -void ivfflat_interleaved_scan(const cuvs::neighbors::ivf_flat::index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const uint32_t queries_offset, - const cuvs::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const bool select_min, - IvfSampleFilterT sample_filter, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) RAFT_EXPLICIT; - -} // namespace cuvs::neighbors::ivf_flat::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ - T, AccT, IdxT, IvfSampleFilterT) \ - extern template void \ - cuvs::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const cuvs::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const uint32_t queries_offset, \ - const cuvs::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ - float* 
distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream) - -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, cuvs::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, cuvs::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, cuvs::neighbors::filtering::none_ivf_sample_filter); - -#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh deleted file mode 100644 index 221da924c..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ /dev/null @@ -1,1129 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include // RAFT_LOG_TRACE -#include -#include -#include // RAFT_CUDA_TRY -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::ivf_flat::detail { - -using namespace cuvs::spatial::knn::detail; // NOLINT - -constexpr int kThreadsPerBlock = 128; - -/** - * @brief Copy `n` elements per block from one place to another. 
- * - * @param[out] out target pointer (unique per block) - * @param[in] in source pointer - * @param n number of elements to copy - */ -template -__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) -{ - constexpr int VecElems = VecBytes / sizeof(T); // NOLINT - using align_bytes = raft::Pow2<(size_t)VecBytes>; - if constexpr (VecElems > 1) { - using align_elems = raft::Pow2; - if (!align_bytes::areSameAlignOffsets(out, in)) { - return copy_vectorized<(VecBytes >> 1), T>(out, in, n); - } - { // process unaligned head - uint32_t head = align_bytes::roundUp(in) - in; - if (head > 0) { - copy_vectorized(out, in, head); - n -= head; - in += head; - out += head; - } - } - { // process main part vectorized - using vec_t = typename raft::IOType::Type; - copy_vectorized( - reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); - } - { // process unaligned tail - uint32_t tail = align_elems::mod(n); - if (tail > 0) { - n -= tail; - copy_vectorized(out + n, in + n, tail); - } - } - } - if constexpr (VecElems <= 1) { - for (int i = threadIdx.x; i < n; i += blockDim.x) { - out[i] = in[i]; - } - } -} - -/** - * @brief Load a part of a vector from the index and from query, compute the (part of the) distance - * between them, and aggregate it using the provided Lambda; one structure per thread, per query, - * and per index item. 
- * - * @tparam kUnroll elements per loop (normally, kUnroll = raft::WarpSize / Veclen) - * @tparam Lambda computing the part of the distance for one dimension and aggregating it: - * void (AccT& acc, AccT x, AccT y) - * @tparam Veclen size of the vectorized load - * @tparam T type of the data in the query and the index - * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit - * values) - */ -template -struct loadAndComputeDist { - Lambda compute_dist; - AccT& dist; - - __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in shared memory. - * Every thread here processes exactly kUnroll * Veclen elements independently of others. - */ - template - __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, - const T* query_shared, - IdxT loadIndex, - IdxT shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - raft::ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); - T queryRegs[Veclen]; - raft::lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in the global memory and is different for every - * thread. One warp loads exactly raft::WarpSize query elements at once and then reshuffles them - * into corresponding threads (`raft::WarpSize / (kUnroll * Veclen)` elements per thread at once). 
- */ - template - __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, - const T* query, - IdxT baseLoadIndex, - const int lane_id) - { - T queryReg = query[baseLoadIndex + lane_id]; - constexpr int stride = kUnroll * Veclen; - constexpr int totalIter = raft::WarpSize / stride; - constexpr int gmemStride = stride * kIndexGroupSize; -#pragma unroll - for (int i = 0; i < totalIter; ++i, data += gmemStride) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - raft::ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); - const int d = (i * kUnroll + j) * Veclen; -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), encV[k]); - } - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `raft::WarpSize`. - */ - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) - { - const int loadDim = dimBlocks + lane_id; - T queryReg = loadDim < dim ? 
query[loadDim] : 0; - const int loadDataIdx = lane_id * Veclen; - for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { - T enc[Veclen]; - raft::ldg(enc, data + loadDataIdx); -#pragma unroll - for (int k = 0; k < Veclen; k++) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), enc[k]); - } - } - } -}; - -// This handles uint8_t 8, 16 Veclens -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - loadIndex = loadIndex * veclen_int; -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - raft::ldg( - encV, - reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); - uint32_t queryRegs[veclen_int]; - raft::lds(queryRegs, - reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - uint32_t queryReg = - (lane_id < 8) ? 
reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * uint8_veclen; - -#pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - raft::ldg( - encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen_int = uint8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; - d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { - uint32_t enc[veclen_int]; - raft::ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - uint32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); - compute_dist(dist, q, enc[k]); - } - } - } -}; - -// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while -// using above common template of int2/int4 -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = 
reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 4; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 4; - const int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? 
reinterpret_cast(query)[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - uint32_t queryReg = loadDim < dim ? 
reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = query_shared[shmemIndex + j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = query[baseLoadIndex + lane_id]; - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[lane_id + j * kIndexGroupSize]; - uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 1; - int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? 
query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = data[lane_id]; - uint32_t q = raft::shfl(queryReg, d, raft::WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -// This device function is for int8 veclens 4, 8 and 16 -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - raft::ldg( - encV, - reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); - int32_t queryRegs[veclen_int]; - raft::lds(queryRegs, - reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - - int32_t queryReg = - (lane_id < 8) ? 
reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * int8_veclen; - -#pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - raft::ldg( - encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - int32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); - compute_dist(dist, q, encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen_int = int8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { - int32_t enc[veclen_int]; - raft::ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - int32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); // Here 4 is for 1 - int; - compute_dist(dist, q, enc[k]); - } - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void 
runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - int32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - int32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; - int32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - int32_t queryReg = query[baseLoadIndex + lane_id]; - -#pragma unroll - for 
(int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, - raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize), - data[lane_id + j * kIndexGroupSize]); - } - } - } - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 1; - const int loadDim = dimBlocks + lane_id; - int32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - compute_dist(dist, raft::shfl(queryReg, d, raft::WarpSize), data[lane_id]); - } - } -}; - -/** - * Scan clusters for nearest neighbors of the query vectors. - * See `ivfflat_interleaved_scan` for more information. - * - * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. - * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is - * calculated, and the top-k nearest neighbors are selected. - * - * @param compute_dist distance function - * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a - * block; this number must be a multiple of `raft::WarpSize * Veclen`. 
- * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] - * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] - * @param[in] list_indices index.indices - * @param[in] list_data index.data - * @param[in] list_sizes index.list_sizes - * @param[in] list_offsets index.list_offsets - * @param n_probes - * @param k - * @param dim - * @param sample_filter - * @param[out] neighbors - * @param[out] distances - */ -template -RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) - interleaved_scan_kernel(Lambda compute_dist, - PostLambda post_process, - const uint32_t query_smem_elems, - const T* query, - const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, - const T* const* list_data_ptrs, - const uint32_t* list_sizes, - const uint32_t queries_offset, - const uint32_t n_probes, - const uint32_t k, - const uint32_t dim, - IvfSampleFilterT sample_filter, - IdxT* neighbors, - float* distances) -{ - extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; - // Using shared memory for the (part of the) query; - // This allows to save on global memory bandwidth when reading index and query - // data at the same time. - // Its size is `query_smem_elems`. 
- T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); - // Make the query input and output point to this block's shared query - { - const int query_id = blockIdx.y; - query += query_id * dim; - neighbors += query_id * k * gridDim.x + blockIdx.x * k; - distances += query_id * k * gridDim.x + blockIdx.x * k; - coarse_index += query_id * n_probes; - } - - // Copy a part of the query into shared memory for faster processing - copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); - __syncthreads(); - - using block_sort_t = raft::matrix::detail::select::warpsort::block_sort< - raft::matrix::detail::select::warpsort::warp_sort_filtered, - Capacity, - Ascending, - float, - IdxT>; - block_sort_t queue(k); - - { - using align_warp = raft::Pow2; - const int lane_id = align_warp::mod(threadIdx.x); - - // How many full warps needed to compute the distance (without remainder) - const uint32_t full_warps_along_dim = align_warp::roundDown(dim); - - const uint32_t shm_assisted_dim = - (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; - - // Every CUDA block scans one cluster at a time. - for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - - // The number of vectors in each cluster(list); [nlist] - const uint32_t list_length = list_sizes[list_id]; - - // The number of interleaved groups to be processed - const uint32_t num_groups = - align_warp::div(list_length + align_warp::Mask); // raft::ceildiv by power of 2 - - constexpr int kUnroll = raft::WarpSize / Veclen; - constexpr uint32_t kNumWarps = kThreadsPerBlock / raft::WarpSize; - // Every warp reads raft::WarpSize vectors and computes the distances to them. - // Then, the distances and corresponding ids are distributed among the threads, - // and each thread adds one (id, dist) pair to the filtering queue. 
- for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; - group_id += kNumWarps) { - AccT dist = 0; - // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; - - // This is the vector a given lane/thread handles - const uint32_t vec_id = group_id * raft::WarpSize + lane_id; - const bool valid = - vec_id < list_length && sample_filter(queries_offset + blockIdx.y, list_id, vec_id); - - // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = 0; pos < shm_assisted_dim; - pos += raft::WarpSize, data += kIndexGroupSize * raft::WarpSize) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - - if (dim > query_smem_elems) { - // The default path - using raft::shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += raft::WarpSize) { - lc.runLoadShflAndCompute(data, query, pos, lane_id); - } - lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); - } else { - // when shm_assisted_dim == full_warps_along_dim < dim - if (valid) { - loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); - for (int pos = full_warps_along_dim; pos < dim; - pos += Veclen, data += kIndexGroupSize * Veclen) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - } - - // Enqueue one element per thread - const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? 
static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); - } - } - } - - // finalize and store selected neighbours - __syncthreads(); - queue.done(interleaved_scan_kernel_smem); - queue.store(distances, neighbors, post_process); -} - -/** - * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size - */ -template -uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) -{ - int dev_id; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - int num_sms; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); - - size_t min_grid_size = num_sms * num_blocks_per_sm; - size_t min_grid_x = raft::ceildiv(min_grid_size, numQueries); - return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); -} - -template -void launch_kernel(Lambda lambda, - PostLambda post_process, - const index& index, - const T* queries, - const uint32_t* coarse_index, - const uint32_t num_queries, - const uint32_t queries_offset, - const uint32_t n_probes, - const uint32_t k, - IvfSampleFilterT sample_filter, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - RAFT_EXPECTS(Veclen == index.veclen(), - "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = interleaved_scan_kernel; - const int max_query_smem = 16384; - int query_smem_elems = std::min(max_query_smem / sizeof(T), - raft::Pow2::roundUp(index.dim())); - int smem_size = query_smem_elems * sizeof(T); - constexpr int kSubwarpSize = std::min(Capacity, raft::WarpSize); - auto block_merge_mem = - raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( - kThreadsPerBlock / kSubwarpSize, k); - smem_size += std::max(smem_size, block_merge_mem); - - // power-of-two less than cuda limit (for 
better addr alignment) - constexpr uint32_t kMaxGridY = 32768; - - if (grid_dim_x == 0) { - grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); - return; - } - - for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { - uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); - dim3 grid_dim(grid_dim_x, grid_dim_y, 1); - dim3 block_dim(kThreadsPerBlock); - RAFT_LOG_TRACE( - "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " - "smem_size = %d", - grid_dim.x, - grid_dim.y, - block_dim.x, - n_probes, - smem_size); - kKernel<<>>(lambda, - post_process, - query_smem_elems, - queries, - coarse_index, - index.inds_ptrs().data_handle(), - index.data_ptrs().data_handle(), - index.list_sizes().data_handle(), - queries_offset + query_offset, - n_probes, - k, - index.dim(), - sample_filter, - neighbors, - distances); - queries += grid_dim_y * index.dim(); - neighbors += grid_dim_y * grid_dim_x * k; - distances += grid_dim_y * grid_dim_x * k; - coarse_index += grid_dim_y * n_probes; - } -} - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - const auto diff = x - y; - acc += diff * diff; - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) - { - if constexpr (Veclen > 1) { - const auto diff = __vabsdiffu4(x, y); - acc = raft::dp4a(diff, diff, acc); - } else { - const auto diff = __usad(x, y, 0u); - acc += diff * diff; - } - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) - { - if constexpr (Veclen > 1) { - // Note that we enforce here that the unsigned version of raft::dp4a is used, because the - // difference between two int8 numbers can be greater than 127 and therefore represented as a - // negative number in int8. 
Casting from int8 to int32 would yield incorrect results, while - // casting from uint8 to uint32 is correct. - const auto diff = __vabsdiffs4(x, y); - acc = raft::dp4a(diff, diff, static_cast(acc)); - } else { - const auto diff = x - y; - acc += diff * diff; - } - } -}; - -template -struct inner_prod_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { - acc = raft::dp4a(x, y, acc); - } else { - acc += x * y; - } - } -}; - -/** Select the distance computation function and forward the rest of the arguments. */ -template -void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... args) -{ - switch (metric) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2Unexpanded: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - case cuvs::distance::DistanceType::L2SqrtExpanded: - case cuvs::distance::DistanceType::L2SqrtUnexpanded: - return launch_kernel, - raft::sqrt_op>({}, {}, std::forward(args)...); - case cuvs::distance::DistanceType::InnerProduct: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. - default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); - } -} - -/** - * Lift the `capacity` and `veclen` parameters to the template level, - * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. - */ -template (1, 16 / sizeof(T))> -struct select_interleaved_scan_kernel { - /** - * Recursively reduce the `Capacity` and `Veclen` parameters until they match the - * corresponding runtime arguments. - * By default, this recursive process starts with maximum possible values of the - * two parameters and ends with both values equal to 1. - */ - template - static inline void run(int capacity, int veclen, bool select_min, Args&&... 
args) - { - if constexpr (Capacity > 1) { - if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel::run(capacity, - veclen, - select_min, - std::forward(args)...); - } - } - if constexpr (Veclen > 1) { - if (veclen % Veclen != 0) { - return select_interleaved_scan_kernel::run( - capacity, 1, select_min, std::forward(args)...); - } - } - // NB: this is the limitation of the warpsort structures that use a huge number of - // registers (used in the main kernel here). - RAFT_EXPECTS(capacity == Capacity, - "Capacity must be power-of-two not bigger than the maximum allowed size " - "matrix::detail::select::warpsort::kMaxCapacity (%d).", - raft::matrix::detail::select::warpsort::kMaxCapacity); - RAFT_EXPECTS( - veclen == Veclen, - "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); - if (select_min) { - launch_with_fixed_consts( - std::forward(args)...); - } else { - launch_with_fixed_consts( - std::forward(args)...); - } - } -}; - -/** - * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. - * - * @tparam T value type - * @tparam AccT accumulated type - * @tparam IdxT type of the indices - * - * @param index previously built ivf-flat index - * @param[in] queries device pointer to the query vectors [batch_size, dim] - * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] - * @param n_queries batch size - * @param[in] queries_offset - * An offset of the current query batch. It is used for feeding sample_filter with the - * correct query index. - * @param metric type of the measured distance - * @param n_probes number of nearest clusters to query - * @param k number of nearest neighbors. - * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. - * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given - * metric. 
- * @param[out] neighbors device pointer to the result indices for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[out] distances device pointer to the result distances for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; - * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) - * @param stream - * @param sample_filter - * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to - * provide a green light for every sample. - */ -template -void ivfflat_interleaved_scan(const index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const uint32_t queries_offset, - const cuvs::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const bool select_min, - IvfSampleFilterT sample_filter, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - const int capacity = raft::bound_by_power_of_two(k); - - auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter( - index.inds_ptrs().data_handle(), sample_filter); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - queries_offset, - n_probes, - k, - filter_adapter, - neighbors, - distances, - grid_dim_x, - stream); -} - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan.cuh deleted file mode 100644 index 63f341dd9..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_interleaved_scan.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) -#include "ivf_flat_interleaved_scan-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat_interleaved_scan-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_search-ext.cuh deleted file mode 100644 index 3a8776f7c..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_search-ext.cuh +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // uintX_t -#include // cuvs::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::neighbors::ivf_flat::detail { - -template -void search(raft::resources const& handle, - const search_params& params, - const cuvs::neighbors::ivf_flat::index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; - -} // namespace cuvs::neighbors::ivf_flat::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ - extern template void cuvs::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const cuvs::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr, \ - IvfSampleFilterT sample_filter) - -instantiate_raft_neighbors_ivf_flat_detail_search( - float, int64_t, cuvs::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - int8_t, int64_t, cuvs::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - uint8_t, int64_t, cuvs::neighbors::filtering::none_ivf_sample_filter); - -#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_search-inl.cuh deleted file mode 100644 index 7f613963b..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_search-inl.cuh +++ /dev/null @@ -1,260 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // is_min_close, DistanceType -#include // interleaved_scan -#include // cuvs::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // utils::mapping -#include // RAFT_LOG_TRACE -#include -#include // raft::resources -#include // raft::linalg::gemm -#include // raft::linalg::norm -#include // raft::linalg::unary_op -#include // raft::matrix::detail::select_k -#include // rmm::device_memory_resource - -namespace cuvs::neighbors::ivf_flat::detail { - -using namespace cuvs::spatial::knn::detail; // NOLINT - -template -void search_impl(raft::resources const& handle, - const cuvs::neighbors::ivf_flat::index& index, - const T* queries, - uint32_t n_queries, - uint32_t queries_offset, - uint32_t k, - uint32_t n_probes, - bool select_min, - IdxT* neighbors, - AccT* distances, - rmm::mr::device_memory_resource* search_mr, - IvfSampleFilterT sample_filter) -{ - auto stream = resource::get_cuda_stream(handle); - // The norm of query - rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); - // The distance value of cluster(list) and queries - rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); - // The topk distance value of cluster(list) and queries - rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); - // The topk index of cluster(list) and queries - rmm::device_uvector 
coarse_indices_dev(n_queries * n_probes, stream, search_mr); - // The topk distance value of candidate vectors from each cluster(list) - rmm::device_uvector refined_distances_dev(n_queries * n_probes * k, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector refined_indices_dev(n_queries * n_probes * k, stream, search_mr); - - size_t float_query_size; - if constexpr (std::is_integral_v) { - float_query_size = n_queries * index.dim(); - } else { - float_query_size = 0; - } - rmm::device_uvector converted_queries_dev(float_query_size, stream, search_mr); - float* converted_queries_ptr = converted_queries_dev.data(); - - if constexpr (std::is_same_v) { - converted_queries_ptr = const_cast(queries); - } else { - linalg::unaryOp( - converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); - } - - float alpha = 1.0f; - float beta = 0.0f; - - // todo(lsugy): raft distance? (if performance is similar/better than gemm) - switch (index.metric()) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: { - alpha = -2.0f; - beta = 1.0f; - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - raft::linalg::L2Norm, - true, - stream); - utils::outer_add(query_norm_dev.data(), - (IdxT)n_queries, - index.center_norms()->data_handle(), - (IdxT)index.n_lists(), - distance_buffer_dev.data(), - stream); - RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - break; - } - default: { - alpha = 1.0f; - beta = 0.0f; - } - } - - linalg::gemm(handle, - true, - false, - index.n_lists(), - n_queries, - index.dim(), - &alpha, - index.centers().data_handle(), - index.dim(), - converted_queries_ptr, - index.dim(), - &beta, - distance_buffer_dev.data(), - index.n_lists(), - stream); - - 
RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - raft::matrix::detail::select_k(handle, - distance_buffer_dev.data(), - nullptr, - n_queries, - index.n_lists(), - n_probes, - coarse_distances_dev.data(), - coarse_indices_dev.data(), - select_min, - search_mr); - RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); - RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); - - auto distances_dev_ptr = refined_distances_dev.data(); - auto indices_dev_ptr = refined_indices_dev.data(); - - uint32_t grid_dim_x = 0; - if (n_probes > 1) { - // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( - index, - nullptr, - nullptr, - n_queries, - queries_offset, - index.metric(), - n_probes, - k, - select_min, - sample_filter, - nullptr, - nullptr, - grid_dim_x, - stream); - } else { - grid_dim_x = 1; - } - - if (grid_dim_x == 1) { - distances_dev_ptr = distances; - indices_dev_ptr = neighbors; - } - - ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( - index, - queries, - coarse_indices_dev.data(), - n_queries, - queries_offset, - index.metric(), - n_probes, - k, - select_min, - sample_filter, - indices_dev_ptr, - distances_dev_ptr, - grid_dim_x, - stream); - - RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); - RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); - - // Merge topk values from different blocks - if (grid_dim_x > 1) { - raft::matrix::detail::select_k(handle, - refined_distances_dev.data(), - refined_indices_dev.data(), - n_queries, - k * grid_dim_x, - k, - distances, - neighbors, - select_min, - search_mr); - } -} - -/** See cuvs::neighbors::ivf_flat::search docs */ -template -inline void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = 
IvfSampleFilterT()) -{ - raft::common::nvtx::range fun_scope( - "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); - - RAFT_EXPECTS(params.n_probes > 0, - "n_probes (number of clusters to probe in the search) must be positive."); - auto n_probes = std::min(params.n_probes, index.n_lists()); - - // a batch size heuristic: try to keep the workspace within the specified size - constexpr uint32_t kExpectedWsSize = 1024 * 1024 * 1024; - const uint32_t max_queries = - std::min(n_queries, - raft::div_rounding_up_safe( - kExpectedWsSize, 16ull * uint64_t{n_probes} * k + 4ull * index.dim())); - - auto pool_guard = raft::get_pool_memory_resource(mr, max_queries * n_probes * k * 16); - if (pool_guard) { - RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", - n_queries * n_probes * k * 16ull); - } - - for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { - uint32_t queries_batch = min(max_queries, n_queries - offset_q); - - search_impl(handle, - index, - queries + offset_q * index.dim(), - queries_batch, - offset_q, - k, - n_probes, - cuvs::distance::is_min_close(index.metric()), - neighbors + offset_q * k, - distances + offset_q * k, - mr, - sample_filter); - } -} - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_search.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_search.cuh deleted file mode 100644 index 7b03ebeab..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_search.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "ivf_flat_search-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat_search-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/cuvs/neighbors/detail/ivf_flat_serialize.cuh deleted file mode 100644 index 60d2392be..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_flat_serialize.cuh +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace cuvs::neighbors::ivf_flat::detail { - -// Serialization version -// No backward compatibility yet; that is, can't add additional fields without breaking -// backward compatibility. -// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward -// compatible fashion. -constexpr int serialization_version = 4; - -/** - * Save the index to file. 
- * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index_ IVF-Flat index - * - */ -template -void serialize(raft::resources const& handle, std::ostream& os, const index& index_) -{ - RAFT_LOG_DEBUG( - "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - - std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); - dtype_string.resize(4); - os << dtype_string; - - serialize_scalar(handle, os, serialization_version); - serialize_scalar(handle, os, index_.size()); - serialize_scalar(handle, os, index_.dim()); - serialize_scalar(handle, os, index_.n_lists()); - serialize_scalar(handle, os, index_.metric()); - serialize_scalar(handle, os, index_.adaptive_centers()); - serialize_scalar(handle, os, index_.conservative_memory_allocation()); - serialize_mdspan(handle, os, index_.centers()); - if (index_.center_norms()) { - bool has_norms = true; - serialize_scalar(handle, os, has_norms); - serialize_mdspan(handle, os, *index_.center_norms()); - } else { - bool has_norms = false; - serialize_scalar(handle, os, has_norms); - } - auto sizes_host = raft::make_host_vector(index_.list_sizes().extent(0)); - copy(sizes_host.data_handle(), - index_.list_sizes().data_handle(), - sizes_host.size(), - resource::get_cuda_stream(handle)); - resource::sync_stream(handle); - serialize_mdspan(handle, os, sizes_host.view()); - - list_spec list_store_spec{index_.dim(), true}; - for (uint32_t label = 0; label < index_.n_lists(); label++) { - ivf::serialize_list(handle, - os, - index_.lists()[label], - list_store_spec, - raft::Pow2::roundUp(sizes_host(label))); - } - resource::sync_stream(handle); -} - -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - 
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize(handle, of, index_); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -/** Load an index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * @param[in] index_ IVF-Flat index - * - */ -template -auto deserialize(raft::resources const& handle, std::istream& is) -> index -{ - char dtype_string[4]; - is.read(dtype_string, 4); - - auto ver = deserialize_scalar(handle, is); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto n_rows = deserialize_scalar(handle, is); - auto dim = deserialize_scalar(handle, is); - auto n_lists = deserialize_scalar(handle, is); - auto metric = deserialize_scalar(handle, is); - bool adaptive_centers = deserialize_scalar(handle, is); - bool cma = deserialize_scalar(handle, is); - - index index_ = index(handle, metric, n_lists, adaptive_centers, cma, dim); - - deserialize_mdspan(handle, is, index_.centers()); - bool has_norms = deserialize_scalar(handle, is); - if (has_norms) { - index_.allocate_center_norms(handle); - if (!index_.center_norms()) { - RAFT_FAIL("Error inconsistent center norms"); - } else { - auto center_norms = index_.center_norms().value(); - deserialize_mdspan(handle, is, center_norms); - } - } - deserialize_mdspan(handle, is, index_.list_sizes()); - - list_spec list_device_spec{index_.dim(), cma}; - list_spec list_store_spec{index_.dim(), true}; - for (uint32_t label = 0; label < index_.n_lists(); label++) { - ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); - } - resource::sync_stream(handle); - - index_.recompute_internal_state(handle); - - return index_; -} - -template -auto deserialize(raft::resources 
const& handle, const std::string& filename) -> index -{ - std::ifstream is(filename, std::ios::in | std::ios::binary); - - if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto index = detail::deserialize(handle, is); - - is.close(); - - return index; -} -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_build.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_build.cuh deleted file mode 100644 index c3d3152e5..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_build.cuh +++ /dev/null @@ -1,1931 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include -#include - -#include -#include - -namespace cuvs::neighbors::ivf_pq::detail { - -using namespace cuvs::spatial::knn::detail; // NOLINT - -template -__launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel( - T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) -{ - using warp = raft::Pow2; - size_t row_ix = warp::div(size_t(threadIdx.x) + size_t(BlockDim) * size_t(blockIdx.x)); - uint32_t i = warp::mod(threadIdx.x); - if (row_ix >= n_rows) return; - out += row_ix * ld_out; - in += row_ix * ld_in; - auto f = utils::mapping{}; - for (uint32_t col_ix = i; col_ix < n_cols; col_ix += warp::Value) { - auto x = f(in[col_ix]); - __syncwarp(); - out[col_ix] = x; - } -} - -/** - * Copy the data one warp-per-row: - * - * 1. load the data per-warp - * 2. apply the `utils::mapping{}` - * 3. sync within warp - * 4. store the data. - * - * Assuming sizeof(T) >= sizeof(S) and the data is properly aligned (see the usage in `build`), this - * allows to re-structure the data within rows in-place. - */ -template -void copy_warped(T* out, - uint32_t ld_out, - const S* in, - uint32_t ld_in, - uint32_t n_cols, - size_t n_rows, - rmm::cuda_stream_view stream) -{ - constexpr uint32_t kBlockDim = 128; - dim3 threads(kBlockDim, 1, 1); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockDim / raft::WarpSize), 1, 1); - copy_warped_kernel - <<>>(out, ld_out, in, ld_in, n_cols, n_rows); -} - -/** - * @brief Fill-in a random orthogonal transformation matrix. 
- * - * @param handle - * @param force_random_rotation - * @param n_rows - * @param n_cols - * @param[out] rotation_matrix device pointer to a row-major matrix of size [n_rows, n_cols]. - * @param rng random number generator state - */ -inline void make_rotation_matrix(raft::resources const& handle, - bool force_random_rotation, - uint32_t n_rows, - uint32_t n_cols, - float* rotation_matrix, - raft::random::RngState rng = raft::random::RngState(7ULL)) -{ - raft::common::nvtx::range fun_scope( - "ivf_pq::make_rotation_matrix(%u * %u)", n_rows, n_cols); - auto stream = resource::get_cuda_stream(handle); - bool inplace = n_rows == n_cols; - uint32_t n = std::max(n_rows, n_cols); - if (force_random_rotation || !inplace) { - rmm::device_uvector buf(inplace ? 0 : n * n, stream); - float* mat = inplace ? rotation_matrix : buf.data(); - raft::random::normal(handle, rng, mat, n * n, 0.0f, 1.0f); - linalg::detail::qrGetQ_inplace(handle, mat, n, n, stream); - if (!inplace) { - RAFT_CUDA_TRY(cudaMemcpy2DAsync(rotation_matrix, - sizeof(float) * n_cols, - mat, - sizeof(float) * n, - sizeof(float) * n_cols, - n_rows, - cudaMemcpyDefault, - stream)); - } - } else { - uint32_t stride = n + 1; - auto rotation_matrix_view = - raft::make_device_vector_view(rotation_matrix, n * n); - linalg::map_offset(handle, rotation_matrix_view, [stride] __device__(uint32_t i) { - return static_cast(i % stride == 0u); - }); - } -} - -/** - * @brief Compute residual vectors from the source dataset given by selected indices. 
- * - * The residual has the form `rotation_matrix %* (dataset[row_ids, :] - center)` - * - */ -template -void select_residuals(raft::resources const& handle, - float* residuals, - IdxT n_rows, - uint32_t dim, - uint32_t rot_dim, - const float* rotation_matrix, // [rot_dim, dim] - const float* center, // [dim] - const T* dataset, // [.., dim] - const IdxT* row_ids, // [n_rows] - rmm::mr::device_memory_resource* device_memory - -) -{ - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector tmp(size_t(n_rows) * size_t(dim), stream, device_memory); - // Note: the number of rows of the input dataset isn't actually n_rows, but raft::matrix::gather - // doesn't need to know it, any strictly positive number would work. - cub::TransformInputIterator, const T*> mapping_itr( - dataset, utils::mapping{}); - raft::matrix::gather(mapping_itr, (IdxT)dim, n_rows, row_ids, n_rows, tmp.data(), stream); - - raft::matrix::linewise_op(handle, - raft::make_device_matrix_view(tmp.data(), n_rows, dim), - raft::make_device_matrix_view(tmp.data(), n_rows, dim), - true, - raft::sub_op{}, - raft::make_device_vector_view(center, dim)); - - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, - rot_dim, - n_rows, - dim, - &alpha, - rotation_matrix, - dim, - tmp.data(), - dim, - &beta, - residuals, - rot_dim, - stream); -} - -/** - * @brief Compute residual vectors from the source dataset given by selected indices. 
- * - * The residual has the form - * `rotation_matrix %* (dataset[:, :] - centers[labels[:], 0:dim])` - * - */ -template -void flat_compute_residuals( - raft::resources const& handle, - float* residuals, // [n_rows, rot_dim] - IdxT n_rows, - raft::device_matrix_view - rotation_matrix, // [rot_dim, dim] - raft::device_matrix_view centers, // [n_lists, dim_ext] - const T* dataset, // [n_rows, dim] - std::variant labels, // [n_rows] - rmm::mr::device_memory_resource* device_memory) -{ - auto stream = resource::get_cuda_stream(handle); - auto dim = rotation_matrix.extent(1); - auto rot_dim = rotation_matrix.extent(0); - rmm::device_uvector tmp(n_rows * dim, stream, device_memory); - auto tmp_view = raft::make_device_vector_view(tmp.data(), tmp.size()); - linalg::map_offset(handle, tmp_view, [centers, dataset, labels, dim] __device__(size_t i) { - auto row_ix = i / dim; - auto el_ix = i % dim; - auto label = std::holds_alternative(labels) - ? std::get(labels) - : std::get(labels)[row_ix]; - return utils::mapping{}(dataset[i]) - centers(label, el_ix); - }); - - float alpha = 1.0f; - float beta = 0.0f; - linalg::gemm(handle, - true, - false, - rot_dim, - n_rows, - dim, - &alpha, - rotation_matrix.data_handle(), - dim, - tmp.data(), - dim, - &beta, - residuals, - rot_dim, - stream); -} - -template -__launch_bounds__(BlockDim) RAFT_KERNEL - fill_indices_kernel(IdxT n_rows, IdxT* data_indices, IdxT* data_offsets, const uint32_t* labels) -{ - const auto i = IdxT(BlockDim) * IdxT(blockIdx.x) + IdxT(threadIdx.x); - if (i >= n_rows) { return; } - data_indices[atomicAdd(data_offsets + labels[i], 1)] = i; -} - -/** - * @brief Calculate cluster offsets and arrange data indices into clusters. 
- * - * @param n_rows - * @param n_lists - * @param[in] labels output of k-means prediction [n_rows] - * @param[in] cluster_sizes [n_lists] - * @param[out] cluster_offsets [n_lists+1] - * @param[out] data_indices [n_rows] - * - * @return size of the largest cluster - */ -template -auto calculate_offsets_and_indices(IdxT n_rows, - uint32_t n_lists, - const uint32_t* labels, - const uint32_t* cluster_sizes, - IdxT* cluster_offsets, - IdxT* data_indices, - rmm::cuda_stream_view stream) -> uint32_t -{ - auto exec_policy = rmm::exec_policy(stream); - // Calculate the offsets - IdxT cumsum = 0; - update_device(cluster_offsets, &cumsum, 1, stream); - thrust::inclusive_scan( - exec_policy, cluster_sizes, cluster_sizes + n_lists, cluster_offsets + 1, add_op{}); - update_host(&cumsum, cluster_offsets + n_lists, 1, stream); - uint32_t max_cluster_size = - *thrust::max_element(exec_policy, cluster_sizes, cluster_sizes + n_lists); - stream.synchronize(); - RAFT_EXPECTS(cumsum == n_rows, "cluster sizes do not add up."); - RAFT_LOG_DEBUG("Max cluster size %d", max_cluster_size); - rmm::device_uvector data_offsets_buf(n_lists, stream); - auto data_offsets = data_offsets_buf.data(); - copy(data_offsets, cluster_offsets, n_lists, stream); - constexpr uint32_t n_threads = 128; // NOLINT - const IdxT n_blocks = raft::div_rounding_up_unsafe(n_rows, n_threads); - fill_indices_kernel - <<>>(n_rows, data_indices, data_offsets, labels); - return max_cluster_size; -} - -template -void set_centers(raft::resources const& handle, index* index, const float* cluster_centers) -{ - auto stream = resource::get_cuda_stream(handle); - auto* device_memory = resource::get_workspace_resource(handle); - - // combine cluster_centers and their norms - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(), - sizeof(float) * index->dim_ext(), - cluster_centers, - sizeof(float) * index->dim(), - sizeof(float) * index->dim(), - index->n_lists(), - cudaMemcpyDefault, - stream)); - - rmm::device_uvector 
center_norms(index->n_lists(), stream, device_memory); - raft::linalg::rowNorm(center_norms.data(), - cluster_centers, - index->dim(), - index->n_lists(), - raft::linalg::L2Norm, - true, - stream); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(), - sizeof(float) * index->dim_ext(), - center_norms.data(), - sizeof(float), - sizeof(float), - index->n_lists(), - cudaMemcpyDefault, - stream)); - - // Rotate cluster_centers - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, - index->rot_dim(), - index->n_lists(), - index->dim(), - &alpha, - index->rotation_matrix().data_handle(), - index->dim(), - cluster_centers, - index->dim(), - &beta, - index->centers_rot().data_handle(), - index->rot_dim(), - resource::get_cuda_stream(handle)); -} - -template -void transpose_pq_centers(const resources& handle, - index& index, - const float* pq_centers_source) -{ - auto stream = resource::get_cuda_stream(handle); - auto extents = index.pq_centers().extents(); - static_assert(extents.rank() == 3); - auto extents_source = - make_extents(extents.extent(0), extents.extent(2), extents.extent(1)); - auto span_source = make_mdspan( - pq_centers_source, extents_source); - auto pq_centers_view = raft::make_device_vector_view( - index.pq_centers().data_handle(), index.pq_centers().size()); - linalg::map_offset(handle, pq_centers_view, [span_source, extents] __device__(size_t i) { - uint32_t ii[3]; - for (int r = 2; r > 0; r--) { - ii[r] = i % extents.extent(r); - i /= extents.extent(r); - } - ii[0] = i; - return span_source(ii[0], ii[2], ii[1]); - }); -} - -template -void train_per_subset(raft::resources const& handle, - index& index, - size_t n_rows, - const float* trainset, // [n_rows, dim] - const uint32_t* labels, // [n_rows] - uint32_t kmeans_n_iters, - rmm::mr::device_memory_resource* managed_memory) -{ - auto stream = resource::get_cuda_stream(handle); - auto device_memory = resource::get_workspace_resource(handle); - - 
rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); - rmm::device_uvector sub_trainset(n_rows * size_t(index.pq_len()), stream, device_memory); - rmm::device_uvector sub_labels(n_rows, stream, device_memory); - - rmm::device_uvector pq_cluster_sizes(index.pq_book_size(), stream, device_memory); - - for (uint32_t j = 0; j < index.pq_dim(); j++) { - raft::common::nvtx::range pq_per_subspace_scope( - "ivf_pq::build::per_subspace[%u]", j); - - // Get the rotated cluster centers for each training vector. - // This will be subtracted from the input vectors afterwards. - utils::copy_selected( - n_rows, - index.pq_len(), - index.centers_rot().data_handle() + index.pq_len() * j, - labels, - index.rot_dim(), - sub_trainset.data(), - index.pq_len(), - stream); - - // sub_trainset is the slice of: rotate(trainset) - centers_rot - float alpha = 1.0; - float beta = -1.0; - linalg::gemm(handle, - true, - false, - index.pq_len(), - n_rows, - index.dim(), - &alpha, - index.rotation_matrix().data_handle() + index.dim() * index.pq_len() * j, - index.dim(), - trainset, - index.dim(), - &beta, - sub_trainset.data(), - index.pq_len(), - stream); - - // train PQ codebook for this subspace - auto sub_trainset_view = - raft::make_device_matrix_view(sub_trainset.data(), n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( - pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j, - index.pq_book_size(), - index.pq_len()); - auto sub_labels_view = raft::make_device_vector_view(sub_labels.data(), n_rows); - auto cluster_sizes_view = - raft::make_device_vector_view(pq_cluster_sizes.data(), index.pq_book_size()); - cuvs::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; - cuvs::cluster::kmeans_balanced::helpers::build_clusters(handle, - kmeans_params, - sub_trainset_view, - centers_tmp_view, - sub_labels_view, - 
cluster_sizes_view, - utils::mapping{}); - } - transpose_pq_centers(handle, index, pq_centers_tmp.data()); -} - -template -void train_per_cluster(raft::resources const& handle, - index& index, - size_t n_rows, - const float* trainset, // [n_rows, dim] - const uint32_t* labels, // [n_rows] - uint32_t kmeans_n_iters, - rmm::mr::device_memory_resource* managed_memory) -{ - auto stream = resource::get_cuda_stream(handle); - auto device_memory = resource::get_workspace_resource(handle); - - rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); - rmm::device_uvector cluster_sizes(index.n_lists(), stream, managed_memory); - rmm::device_uvector indices_buf(n_rows, stream, device_memory); - rmm::device_uvector offsets_buf(index.n_lists() + 1, stream, managed_memory); - - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(cluster_sizes.data()), - index.n_lists(), - labels, - n_rows, - 1, - stream); - - auto cluster_offsets = offsets_buf.data(); - auto indices = indices_buf.data(); - uint32_t max_cluster_size = calculate_offsets_and_indices( - IdxT(n_rows), index.n_lists(), labels, cluster_sizes.data(), cluster_offsets, indices, stream); - - rmm::device_uvector pq_labels( - size_t(max_cluster_size) * size_t(index.pq_dim()), stream, device_memory); - rmm::device_uvector pq_cluster_sizes(index.pq_book_size(), stream, device_memory); - rmm::device_uvector rot_vectors( - size_t(max_cluster_size) * size_t(index.rot_dim()), stream, device_memory); - - resource::sync_stream(handle); // make sure cluster offsets are up-to-date - for (uint32_t l = 0; l < index.n_lists(); l++) { - auto cluster_size = cluster_sizes.data()[l]; - if (cluster_size == 0) continue; - raft::common::nvtx::range pq_per_cluster_scope( - "ivf_pq::build::per_cluster[%u](size = %u)", l, cluster_size); - - select_residuals(handle, - rot_vectors.data(), - IdxT(cluster_size), - index.dim(), - index.rot_dim(), - index.rotation_matrix().data_handle(), - 
index.centers().data_handle() + size_t(l) * size_t(index.dim_ext()), - trainset, - indices + cluster_offsets[l], - device_memory); - - // limit the cluster size to bound the training time. - // [sic] we interpret the data as pq_len-dimensional - size_t big_enough = 256ul * std::max(index.pq_book_size(), index.pq_dim()); - size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim()); - auto pq_n_rows = uint32_t(std::min(big_enough, available_rows)); - // train PQ codebook for this cluster - auto rot_vectors_view = raft::make_device_matrix_view( - rot_vectors.data(), pq_n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( - pq_centers_tmp.data() + static_cast(index.pq_book_size()) * - static_cast(index.pq_len()) * static_cast(l), - index.pq_book_size(), - index.pq_len()); - auto pq_labels_view = - raft::make_device_vector_view(pq_labels.data(), pq_n_rows); - auto pq_cluster_sizes_view = - raft::make_device_vector_view(pq_cluster_sizes.data(), index.pq_book_size()); - cuvs::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = kmeans_n_iters; - kmeans_params.metric = cuvs::distance::DistanceType::L2Expanded; - cuvs::cluster::kmeans_balanced::helpers::build_clusters(handle, - kmeans_params, - rot_vectors_view, - centers_tmp_view, - pq_labels_view, - pq_cluster_sizes_view, - utils::mapping{}); - } - transpose_pq_centers(handle, index, pq_centers_tmp.data()); -} - -/** - * A helper function: given the dataset in the rotated space - * [n_rows, rot_dim] = [n_rows, pq_dim * pq_len], - * reinterpret the last dimension as two: [n_rows, pq_dim, pq_len] - * - * @tparam T - * @tparam IdxT - * - * @param vectors input data [n_rows, rot_dim] - * @param pq_centers codebook (used to infer the structure - pq_len) - * @return reinterpreted vectors [n_rows, pq_dim, pq_len] - */ -template -static __device__ auto reinterpret_vectors( - raft::device_matrix_view vectors, - raft::device_mdspan, raft::row_major> pq_centers) - -> 
raft::device_mdspan, raft::row_major> -{ - const uint32_t pq_len = pq_centers.extent(1); - const uint32_t pq_dim = vectors.extent(1) / pq_len; - using layout_t = typename decltype(vectors)::layout_type; - using accessor_t = typename decltype(vectors)::accessor_type; - return raft::mdspan, layout_t, accessor_t>( - vectors.data_handle(), extent_3d{vectors.extent(0), pq_dim, pq_len}); -} - -/** - * A consumer for the `run_on_list` and `run_on_vector` that just flattens PQ codes - * one-per-byte. That is, independent of the code width (pq_bits), one code uses - * the whole byte, hence one vectors uses pq_dim bytes. - */ -struct unpack_codes { - raft::device_matrix_view out_codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_codes the destination for the read codes. - */ - __device__ inline unpack_codes(device_matrix_view out_codes) - : out_codes{out_codes} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - out_codes(i, j) = code; - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL unpack_list_data_kernel( - raft::device_matrix_view out_codes, - raft::device_mdspan::list_extents, raft::row_major> - in_list_data, - std::variant offset_or_indices) -{ - const uint32_t pq_dim = out_codes.extent(1); - auto unpack_action = unpack_codes{out_codes}; - run_on_list(in_list_data, offset_or_indices, out_codes.extent(0), pq_dim, unpack_action); -} - -/** - * Unpack flat PQ codes from an existing list by the given offset. - * - * @param[out] codes flat PQ codes, one code per byte [n_rows, pq_dim] - * @param[in] list_data the packed ivf::list data. - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. 
- * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void unpack_list_data( - raft::device_matrix_view codes, - raft::device_mdspan::list_extents, raft::row_major> - list_data, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - auto n_rows = codes.extent(0); - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return unpack_list_data_kernel; - case 5: return unpack_list_data_kernel; - case 6: return unpack_list_data_kernel; - case 7: return unpack_list_data_kernel; - case 8: return unpack_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(codes, list_data, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** Unpack the list data; see the public interface for the api and usage. */ -template -void unpack_list_data(raft::resources const& res, - const index& index, - raft::device_matrix_view out_codes, - uint32_t label, - std::variant offset_or_indices) -{ - unpack_list_data(out_codes, - index.lists()[label]->data.view(), - offset_or_indices, - index.pq_bits(), - resource::get_cuda_stream(res)); -} - -/** - * A consumer for the `run_on_vector` that just flattens PQ codes - * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte. - */ -template -struct unpack_contiguous { - uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `run_on_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. 
*/ - __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - bitfield_view_t code_view{codes + i * code_size}; - code_view[j] = code; - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel( - uint8_t* out_codes, - raft::device_mdspan::list_extents, raft::row_major> - in_list_data, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices) -{ - run_on_list( - in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous(out_codes, pq_dim)); -} - -/** - * Unpack flat PQ codes from an existing list by the given offset. - * - * @param[out] codes flat compressed PQ codes [n_rows, raft::ceildiv(pq_dim * pq_bits, 8)] - * @param[in] list_data the packed ivf::list data. - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void unpack_contiguous_list_data( - uint8_t* codes, - raft::device_mdspan::list_extents, raft::row_major> - list_data, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return unpack_contiguous_list_data_kernel; - case 5: return unpack_contiguous_list_data_kernel; - case 6: return unpack_contiguous_list_data_kernel; - case 7: return unpack_contiguous_list_data_kernel; - case 8: return unpack_contiguous_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(codes, list_data, n_rows, pq_dim, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** Unpack the list data; see the public interface for the api and usage. 
*/ -template -void unpack_contiguous_list_data(raft::resources const& res, - const index& index, - uint8_t* out_codes, - uint32_t n_rows, - uint32_t label, - std::variant offset_or_indices) -{ - unpack_contiguous_list_data(out_codes, - index.lists()[label]->data.view(), - n_rows, - index.pq_dim(), - offset_or_indices, - index.pq_bits(), - resource::get_cuda_stream(res)); -} - -/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. - */ -struct reconstruct_vectors { - codebook_gen codebook_kind; - uint32_t cluster_ix; - uint32_t pq_len; - raft::device_mdspan, raft::row_major> pq_centers; - raft::device_mdspan, raft::row_major> centers_rot; - raft::device_mdspan, raft::row_major> out_vectors; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_vectors the destination for the decoded vectors. - * @param[in] pq_centers the codebook - * @param[in] centers_rot - * @param[in] codebook_kind - * @param[in] cluster_ix label/id of the cluster. - */ - __device__ inline reconstruct_vectors( - raft::device_matrix_view out_vectors, - raft::device_mdspan, raft::row_major> pq_centers, - raft::device_matrix_view centers_rot, - codebook_gen codebook_kind, - uint32_t cluster_ix) - : codebook_kind{codebook_kind}, - cluster_ix{cluster_ix}, - pq_len{pq_centers.extent(1)}, - pq_centers{pq_centers}, - centers_rot{reinterpret_vectors(centers_rot, pq_centers)}, - out_vectors{reinterpret_vectors(out_vectors, pq_centers)} - { - } - - /** - * Decode j-th component of the i-th vector by its code and write it into a chunk of the output - * vectors (pq_len elements). 
- */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - uint32_t partition_ix; - switch (codebook_kind) { - case codebook_gen::PER_CLUSTER: { - partition_ix = cluster_ix; - } break; - case codebook_gen::PER_SUBSPACE: { - partition_ix = j; - } break; - default: __builtin_unreachable(); - } - for (uint32_t k = 0; k < pq_len; k++) { - out_vectors(i, j, k) = pq_centers(partition_ix, k, code) + centers_rot(cluster_ix, j, k); - } - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL reconstruct_list_data_kernel( - raft::device_matrix_view out_vectors, - raft::device_mdspan::list_extents, raft::row_major> - in_list_data, - raft::device_mdspan, raft::row_major> pq_centers, - raft::device_matrix_view centers_rot, - codebook_gen codebook_kind, - uint32_t cluster_ix, - std::variant offset_or_indices) -{ - const uint32_t pq_dim = out_vectors.extent(1) / pq_centers.extent(1); - auto reconstruct_action = - reconstruct_vectors{out_vectors, pq_centers, centers_rot, codebook_kind, cluster_ix}; - run_on_list( - in_list_data, offset_or_indices, out_vectors.extent(0), pq_dim, reconstruct_action); -} - -/** Decode the list data; see the public interface for the api and usage. */ -template -void reconstruct_list_data(raft::resources const& res, - const index& index, - raft::device_matrix_view out_vectors, - uint32_t label, - std::variant offset_or_indices) -{ - auto n_rows = out_vectors.extent(0); - if (n_rows == 0) { return; } - auto& list = index.lists()[label]; - if (std::holds_alternative(offset_or_indices)) { - auto n_skip = std::get(offset_or_indices); - // sic! I'm using the upper bound `list.size` instead of exact `list_sizes(label)` - // to avoid an extra device-host data copy and the stream sync. 
- RAFT_EXPECTS(n_skip + n_rows <= list->size.load(), - "offset + output size must be not bigger than the cluster size."); - } - - auto tmp = raft::make_device_mdarray( - res, resource::get_workspace_resource(res), make_extents(n_rows, index.rot_dim())); - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: return reconstruct_list_data_kernel; - case 5: return reconstruct_list_data_kernel; - case 6: return reconstruct_list_data_kernel; - case 7: return reconstruct_list_data_kernel; - case 8: return reconstruct_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(index.pq_bits()); - kernel<<>>(tmp.view(), - list->data.view(), - index.pq_centers(), - index.centers_rot(), - index.codebook_kind(), - label, - offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - float* out_float_ptr = nullptr; - rmm::device_uvector out_float_buf( - 0, resource::get_cuda_stream(res), resource::get_workspace_resource(res)); - if constexpr (std::is_same_v) { - out_float_ptr = out_vectors.data_handle(); - } else { - out_float_buf.resize(size_t{n_rows} * size_t{index.dim()}, resource::get_cuda_stream(res)); - out_float_ptr = out_float_buf.data(); - } - // Rotate the results back to the original space - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(res, - false, - false, - index.dim(), - n_rows, - index.rot_dim(), - &alpha, - index.rotation_matrix().data_handle(), - index.dim(), - tmp.data_handle(), - index.rot_dim(), - &beta, - out_float_ptr, - index.dim(), - resource::get_cuda_stream(res)); - // Transform the data to the original type, if necessary - if constexpr (!std::is_same_v) { - linalg::map(res, - out_vectors, - utils::mapping{}, - raft::make_device_matrix_view(out_float_ptr, n_rows, index.dim())); - } -} - -/** - * A producer for the `write_list` and 
`write_vector` reads the codes byte-by-byte. That is, - * independent of the code width (pq_bits), one code uses the whole byte, hence one vectors uses - * pq_dim bytes. - */ -struct pass_codes { - raft::device_matrix_view codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[in] codes the source codes. - */ - __device__ inline pass_codes(device_matrix_view codes) - : codes{codes} - { - } - - /** Read j-th component (code) of the i-th vector from the source. */ - __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t { return codes(i, j); } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL pack_list_data_kernel( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - raft::device_matrix_view codes, - std::variant offset_or_indices) -{ - write_list( - list_data, offset_or_indices, codes.extent(0), codes.extent(1), pass_codes{codes}); -} - -/** - * Write flat PQ codes into an existing list by the given offset. - * - * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). - * - * @param[out] list_data the packed ivf::list data. - * @param[in] codes flat PQ codes, one code per byte [n_rows, pq_dim] - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. 
- * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void pack_list_data( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - raft::device_matrix_view codes, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - auto n_rows = codes.extent(0); - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return pack_list_data_kernel; - case 5: return pack_list_data_kernel; - case 6: return pack_list_data_kernel; - case 7: return pack_list_data_kernel; - case 8: return pack_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(list_data, codes, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void pack_list_data(raft::resources const& res, - index* index, - raft::device_matrix_view new_codes, - uint32_t label, - std::variant offset_or_indices) -{ - pack_list_data(index->lists()[label]->data.view(), - new_codes, - offset_or_indices, - index->pq_bits(), - resource::get_cuda_stream(res)); -} - -/** - * A producer for the `write_vector` reads tightly packed flat codes. That is, - * the codes are not expanded to one code-per-byte. - */ -template -struct pack_contiguous { - const uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `write_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { - } - - /** Read j-th component (code) of the i-th vector from the source. 
*/ - __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t - { - bitfield_view_t code_view{const_cast(codes + i * code_size)}; - return uint8_t(code_view[j]); - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - const uint8_t* codes, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices) -{ - write_list( - list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); -} - -/** - * Write flat PQ codes into an existing list by the given offset. - * - * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). - * - * @param[out] list_data the packed ivf::list data. - * @param[in] codes flat compressed PQ codes [n_rows, raft::ceildiv(pq_dim * pq_bits, 8)] - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void pack_contiguous_list_data( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - const uint8_t* codes, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return pack_contiguous_list_data_kernel; - case 5: return pack_contiguous_list_data_kernel; - case 6: return pack_contiguous_list_data_kernel; - case 7: return pack_contiguous_list_data_kernel; - case 8: return pack_contiguous_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(list_data, codes, n_rows, pq_dim, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void 
pack_contiguous_list_data(raft::resources const& res, - index* index, - const uint8_t* new_codes, - uint32_t n_rows, - uint32_t label, - std::variant offset_or_indices) -{ - pack_contiguous_list_data(index->lists()[label]->data.view(), - new_codes, - n_rows, - index->pq_dim(), - offset_or_indices, - index->pq_bits(), - resource::get_cuda_stream(res)); -} - -/** - * - * A producer for the `write_list` and `write_vector` that encodes level-1 input vector residuals - * into lvl-2 PQ codes. - * Computing a PQ code means finding the closest cluster in a pq_dim-subspace. - * - * @tparam SubWarpSize - * how many threads work on a single vector; - * bounded by either raft::WarpSize or pq_book_size. - * - * @param pq_centers - * - codebook_gen::PER_SUBSPACE: [pq_dim , pq_len, pq_book_size] - * - codebook_gen::PER_CLUSTER: [n_lists, pq_len, pq_book_size] - * @param new_vector a single input of length rot_dim, reinterpreted as [pq_dim, pq_len]. - * the input must be already transformed to floats, rotated, and the level 1 cluster - * center must be already substructed (i.e. this is the residual of a single input vector). - * @param codebook_kind - * @param j index along pq_dim "dimension" - * @param cluster_ix is used for PER_CLUSTER codebooks. - */ -/** - */ -template -struct encode_vectors { - codebook_gen codebook_kind; - uint32_t cluster_ix; - raft::device_mdspan, raft::row_major> pq_centers; - raft::device_mdspan, raft::row_major> in_vectors; - - __device__ inline encode_vectors( - raft::device_mdspan, raft::row_major> pq_centers, - raft::device_matrix_view in_vectors, - codebook_gen codebook_kind, - uint32_t cluster_ix) - : codebook_kind{codebook_kind}, - cluster_ix{cluster_ix}, - pq_centers{pq_centers}, - in_vectors{reinterpret_vectors(in_vectors, pq_centers)} - { - } - - /** - * Decode j-th component of the i-th vector by its code and write it into a chunk of the output - * vectors (pq_len elements). 
- */ - __device__ inline auto operator()(IdxT i, uint32_t j) -> uint8_t - { - uint32_t lane_id = raft::Pow2::mod(laneId()); - uint32_t partition_ix; - switch (codebook_kind) { - case codebook_gen::PER_CLUSTER: { - partition_ix = cluster_ix; - } break; - case codebook_gen::PER_SUBSPACE: { - partition_ix = j; - } break; - default: __builtin_unreachable(); - } - - const uint32_t pq_book_size = pq_centers.extent(2); - const uint32_t pq_len = pq_centers.extent(1); - float min_dist = std::numeric_limits::infinity(); - uint8_t code = 0; - // calculate the distance for each PQ cluster, find the minimum for each thread - for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { - // NB: the L2 quantifiers on residuals are always trained on L2 metric. - float d = 0.0f; - for (uint32_t k = 0; k < pq_len; k++) { - auto t = in_vectors(i, j, k) - pq_centers(partition_ix, k, l); - d += t * t; - } - if (d < min_dist) { - min_dist = d; - code = uint8_t(l); - } - } - // reduce among threads -#pragma unroll - for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { - const auto other_dist = raft::shfl_xor(min_dist, stride, SubWarpSize); - const auto other_code = raft::shfl_xor(code, stride, SubWarpSize); - if (other_dist < min_dist) { - min_dist = other_dist; - code = other_code; - } - } - return code; - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_kernel( - raft::device_matrix_view new_vectors, - std::variant src_offset_or_indices, - const uint32_t* new_labels, - raft::device_vector_view list_sizes, - raft::device_vector_view inds_ptrs, - raft::device_vector_view data_ptrs, - raft::device_mdspan, raft::row_major> pq_centers, - codebook_gen codebook_kind) -{ - constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); - using subwarp_align = raft::Pow2; - const uint32_t lane_id = subwarp_align::mod(threadIdx.x); - const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); - if 
(row_ix >= new_vectors.extent(0)) { return; } - - const uint32_t cluster_ix = new_labels[row_ix]; - uint32_t out_ix; - if (lane_id == 0) { out_ix = atomicAdd(&list_sizes(cluster_ix), 1); } - out_ix = raft::shfl(out_ix, 0, kSubWarpSize); - - // write the label (one record per subwarp) - auto pq_indices = inds_ptrs(cluster_ix); - if (lane_id == 0) { - if (std::holds_alternative(src_offset_or_indices)) { - pq_indices[out_ix] = std::get(src_offset_or_indices) + row_ix; - } else { - pq_indices[out_ix] = std::get(src_offset_or_indices)[row_ix]; - } - } - - // write the codes (one record per subwarp): - const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); - auto pq_extents = list_spec{PqBits, pq_dim, true}.make_list_extents(out_ix + 1); - auto pq_dataset = - make_mdspan(data_ptrs[cluster_ix], pq_extents); - write_vector( - pq_dataset, - out_ix, - row_ix, - pq_dim, - encode_vectors{pq_centers, new_vectors, codebook_kind, cluster_ix}); -} - -template -__launch_bounds__(BlockSize) RAFT_KERNEL encode_list_data_kernel( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - raft::device_matrix_view new_vectors, - raft::device_mdspan, raft::row_major> pq_centers, - codebook_gen codebook_kind, - uint32_t cluster_ix, - std::variant offset_or_indices) -{ - constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); - const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); - auto encode_action = - encode_vectors{pq_centers, new_vectors, codebook_kind, cluster_ix}; - write_list( - list_data, offset_or_indices, new_vectors.extent(0), pq_dim, encode_action); -} - -template -void encode_list_data(raft::resources const& res, - index* index, - raft::device_matrix_view new_vectors, - uint32_t label, - std::variant offset_or_indices) -{ - auto n_rows = new_vectors.extent(0); - if (n_rows == 0) { return; } - - auto mr = resource::get_workspace_resource(res); - - auto new_vectors_residual = - raft::make_device_mdarray(res, mr, 
make_extents(n_rows, index->rot_dim())); - - flat_compute_residuals(res, - new_vectors_residual.data_handle(), - n_rows, - index->rotation_matrix(), - index->centers(), - new_vectors.data_handle(), - label, - mr); - - constexpr uint32_t kBlockSize = 256; - const uint32_t threads_per_vec = std::min(WarpSize, index->pq_book_size()); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: return encode_list_data_kernel; - case 5: return encode_list_data_kernel; - case 6: return encode_list_data_kernel; - case 7: return encode_list_data_kernel; - case 8: return encode_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(index->pq_bits()); - kernel<<>>(index->lists()[label]->data.view(), - new_vectors_residual.view(), - index->pq_centers(), - index->codebook_kind(), - label, - offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * Assuming the index already has some data and allocated the space for more, write more data in it. - * There must be enough free space in `pq_dataset()` and `indices()`, as computed using - * `list_offsets()` and `list_sizes()`. - * - * NB: Since the pq_dataset is stored in the interleaved blocked format (see ivf_pq_types.hpp), one - * cannot just concatenate the old and the new codes; the positions for the codes are determined the - * same way as in the ivfpq_compute_similarity_kernel (see ivf_pq_search.cuh). 
- * - * @tparam T - * @tparam IdxT - * - * @param handle - * @param index - * @param[in] new_vectors - * a pointer to a row-major device array [index.dim(), n_rows]; - * @param[in] src_offset_or_indices - * references for the new data: - * either a starting index for the auto-indexing - * or a pointer to a device array of explicit indices [n_rows]; - * @param[in] new_labels - * cluster ids (first-level quantization) - a device array [n_rows]; - * @param n_rows - * the number of records to write in. - * @param mr - * a memory resource to use for device allocations - */ -template -void process_and_fill_codes(raft::resources const& handle, - index& index, - const T* new_vectors, - std::variant src_offset_or_indices, - const uint32_t* new_labels, - IdxT n_rows, - rmm::mr::device_memory_resource* mr) -{ - auto new_vectors_residual = - raft::make_device_mdarray(handle, mr, make_extents(n_rows, index.rot_dim())); - - flat_compute_residuals(handle, - new_vectors_residual.data_handle(), - n_rows, - index.rotation_matrix(), - index.centers(), - new_vectors, - new_labels, - mr); - - constexpr uint32_t kBlockSize = 256; - const uint32_t threads_per_vec = std::min(WarpSize, index.pq_book_size()); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: return process_and_fill_codes_kernel; - case 5: return process_and_fill_codes_kernel; - case 6: return process_and_fill_codes_kernel; - case 7: return process_and_fill_codes_kernel; - case 8: return process_and_fill_codes_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(index.pq_bits()); - kernel<<>>(new_vectors_residual.view(), - src_offset_or_indices, - new_labels, - index.list_sizes(), - index.inds_ptrs(), - index.data_ptrs(), - index.pq_centers(), - index.codebook_kind()); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** Update the state of the 
dependent index members. */ -template -void recompute_internal_state(const raft::resources& res, index& index) -{ - auto stream = resource::get_cuda_stream(res); - auto tmp_res = resource::get_workspace_resource(res); - rmm::device_uvector sorted_sizes(index.n_lists(), stream, tmp_res); - - // Actualize the list pointers - auto data_ptrs = index.data_ptrs(); - auto inds_ptrs = index.inds_ptrs(); - for (uint32_t label = 0; label < index.n_lists(); label++) { - auto& list = index.lists()[label]; - const auto data_ptr = list ? list->data.data_handle() : nullptr; - const auto inds_ptr = list ? list->indices.data_handle() : nullptr; - copy(&data_ptrs(label), &data_ptr, 1, stream); - copy(&inds_ptrs(label), &inds_ptr, 1, stream); - } - - // Sort the cluster sizes in the descending order. - int begin_bit = 0; - int end_bit = sizeof(uint32_t) * 8; - size_t cub_workspace_size = 0; - cub::DeviceRadixSort::SortKeysDescending(nullptr, - cub_workspace_size, - index.list_sizes().data_handle(), - sorted_sizes.data(), - index.n_lists(), - begin_bit, - end_bit, - stream); - rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res); - cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(), - cub_workspace_size, - index.list_sizes().data_handle(), - sorted_sizes.data(), - index.n_lists(), - begin_bit, - end_bit, - stream); - // copy the results to CPU - std::vector sorted_sizes_host(index.n_lists()); - copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream); - resource::sync_stream(res); - - // accumulate the sorted cluster sizes - auto accum_sorted_sizes = index.accum_sorted_sizes(); - accum_sorted_sizes(0) = 0; - for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) { - accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label]; - } -} - -/** - * Helper function: allocate enough space in the list, compute the offset, at which to start - * writing, and fill-in indices. 
- * - * @return offset for writing the data - */ -template -auto extend_list_prepare( - raft::resources const& res, - index* index, - raft::device_vector_view new_indices, - uint32_t label) -> uint32_t -{ - uint32_t n_rows = new_indices.extent(0); - uint32_t offset; - // Allocate the lists to fit the new data - copy(&offset, index->list_sizes().data_handle() + label, 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); - uint32_t new_size = offset + n_rows; - copy(index->list_sizes().data_handle() + label, &new_size, 1, resource::get_cuda_stream(res)); - auto spec = list_spec{ - index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; - auto& list = index->lists()[label]; - ivf::resize_list(res, list, spec, new_size, offset); - copy(list->indices.data_handle() + offset, - new_indices.data_handle(), - n_rows, - resource::get_cuda_stream(res)); - return offset; -} - -/** - * Extend one list of the index in-place, by the list label, skipping the classification and - * encoding steps. - * See the public interface for the api and usage. - */ -template -void extend_list_with_codes( - raft::resources const& res, - index* index, - raft::device_matrix_view new_codes, - raft::device_vector_view new_indices, - uint32_t label) -{ - // Allocate memory and write indices - auto offset = extend_list_prepare(res, index, new_indices, label); - // Pack the data - pack_list_data(res, index, new_codes, label, offset); - // Update the pointers and the sizes - recompute_internal_state(res, *index); -} - -/** - * Extend one list of the index in-place, by the list label, skipping the classification step. - * See the public interface for the api and usage. 
- */ -template -void extend_list(raft::resources const& res, - index* index, - raft::device_matrix_view new_vectors, - raft::device_vector_view new_indices, - uint32_t label) -{ - // Allocate memory and write indices - auto offset = extend_list_prepare(res, index, new_indices, label); - // Encode the data - encode_list_data(res, index, new_vectors, label, offset); - // Update the pointers and the sizes - recompute_internal_state(res, *index); -} - -/** - * Remove all data from a single list. - * See the public interface for the api and usage. - */ -template -void erase_list(raft::resources const& res, index* index, uint32_t label) -{ - uint32_t zero = 0; - copy(index->list_sizes().data_handle() + label, &zero, 1, resource::get_cuda_stream(res)); - index->lists()[label].reset(); - recompute_internal_state(res, *index); -} - -/** Copy the state of an index into a new index, but share the list data among the two. */ -template -auto clone(const raft::resources& res, const index& source) -> index -{ - auto stream = resource::get_cuda_stream(res); - - // Allocate the new index - index target(res, - source.metric(), - source.codebook_kind(), - source.n_lists(), - source.dim(), - source.pq_bits(), - source.pq_dim()); - - // Copy the independent parts - copy(target.list_sizes().data_handle(), - source.list_sizes().data_handle(), - source.list_sizes().size(), - stream); - copy(target.rotation_matrix().data_handle(), - source.rotation_matrix().data_handle(), - source.rotation_matrix().size(), - stream); - copy(target.pq_centers().data_handle(), - source.pq_centers().data_handle(), - source.pq_centers().size(), - stream); - copy(target.centers().data_handle(), - source.centers().data_handle(), - source.centers().size(), - stream); - copy(target.centers_rot().data_handle(), - source.centers_rot().data_handle(), - source.centers_rot().size(), - stream); - - // Copy shared pointers - target.lists() = source.lists(); - - // Make sure the device pointers point to the new lists - 
recompute_internal_state(res, target); - - return target; -} - -/** - * Extend the index in-place. - * See cuvs::spatial::knn::ivf_pq::extend docs. - */ -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - raft::common::nvtx::range fun_scope( - "ivf_pq::extend(%zu, %u)", size_t(n_rows), index->dim()); - - resource::detail::warn_non_pool_workspace(handle, "raft::ivf_pq::extend"); - auto stream = resource::get_cuda_stream(handle); - const auto n_clusters = index->n_lists(); - - RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, - "You must pass data indices when the index is non-empty."); - - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "Unsupported data type"); - - rmm::mr::device_memory_resource* device_memory = nullptr; - auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024); - if (pool_guard) { RAFT_LOG_DEBUG("ivf_pq::extend: using pool memory resource"); } - - rmm::mr::managed_memory_resource managed_memory_upstream; - rmm::mr::pool_memory_resource managed_memory( - &managed_memory_upstream, 1024 * 1024); - - // The spec defines how the clusters look like - auto spec = list_spec{ - index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; - // Try to allocate an index with the same parameters and the projected new size - // (which can be slightly larger than index->size() + n_rows, due to padding). - // If this fails, the index would be too big to fit in the device anyway. 
- std::optional> placeholder_list( - std::in_place_t{}, - handle, - list_spec{spec}, - n_rows + (kIndexGroupSize - 1) * std::min(n_clusters, n_rows)); - - // Available device memory - size_t free_mem, total_mem; - RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem)); - - // Decide on an approximate threshold when we'd better start saving device memory by using - // managed allocations for large device buffers - rmm::mr::device_memory_resource* labels_mr = device_memory; - rmm::mr::device_memory_resource* batches_mr = device_memory; - if (n_rows * (index->dim() * sizeof(T) + index->pq_dim() + sizeof(IdxT) + sizeof(uint32_t)) > - free_mem) { - labels_mr = &managed_memory; - } - // Allocate a buffer for the new labels (classifying the new data) - rmm::device_uvector new_data_labels(n_rows, stream, labels_mr); - if (labels_mr == device_memory) { free_mem -= sizeof(uint32_t) * n_rows; } - - // Calculate the batch size for the input data if it's not accessible directly from the device - constexpr size_t kReasonableMaxBatchSize = 65536; - size_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - { - size_t size_factor = 0; - // we'll use two temporary buffers for converted inputs when computing the codes. - size_factor += (index->dim() + index->rot_dim()) * sizeof(float); - // ...and another buffer for indices - size_factor += sizeof(IdxT); - // if the input data is not accessible on device, we'd need a buffer for it. 
- switch (utils::check_pointer_residency(new_vectors)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: size_factor += index->dim() * sizeof(T); - } - // the same with indices - if (new_indices != nullptr) { - switch (utils::check_pointer_residency(new_indices)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: size_factor += sizeof(IdxT); - } - } - // make the batch size fit into the remaining memory - while (size_factor * max_batch_size > free_mem && max_batch_size > 128) { - max_batch_size >>= 1; - } - if (size_factor * max_batch_size > free_mem) { - // if that still doesn't fit, resort to the UVM - batches_mr = &managed_memory; - max_batch_size = kReasonableMaxBatchSize; - } else { - // If we're keeping the batches in device memory, update the available mem tracker. - free_mem -= size_factor * max_batch_size; - } - } - - // Predict the cluster labels for the new data, in batches if necessary - utils::batch_load_iterator vec_batches( - new_vectors, n_rows, index->dim(), max_batch_size, stream, batches_mr); - // Release the placeholder memory, because we don't intend to allocate any more long-living - // temporary buffers before we allocate the index data. - // This memory could potentially speed up UVM accesses, if any. - placeholder_list.reset(); - { - // The cluster centers in the index are stored padded, which is not acceptable by - // the kmeans_balanced::predict. Thus, we need the restructuring copy. 
- rmm::device_uvector cluster_centers( - size_t(n_clusters) * size_t(index->dim()), stream, device_memory); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data(), - sizeof(float) * index->dim(), - index->centers().data_handle(), - sizeof(float) * index->dim_ext(), - sizeof(float) * index->dim(), - n_clusters, - cudaMemcpyDefault, - stream)); - for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_data_labels.data() + batch.offset(), batch.size()); - auto centers_view = raft::make_device_matrix_view( - cluster_centers.data(), n_clusters, index->dim()); - cuvs::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = index->metric(); - cuvs::cluster::kmeans_balanced::predict(handle, - kmeans_params, - batch_data_view, - centers_view, - batch_labels_view, - utils::mapping{}); - } - } - - auto list_sizes = index->list_sizes().data_handle(); - // store the current cluster sizes, because we'll need them later - rmm::device_uvector orig_list_sizes(n_clusters, stream, device_memory); - copy(orig_list_sizes.data(), list_sizes, n_clusters, stream); - - // Get the combined cluster sizes - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(list_sizes), - IdxT(n_clusters), - new_data_labels.data(), - n_rows, - 1, - stream); - linalg::add(list_sizes, list_sizes, orig_list_sizes.data(), n_clusters, stream); - - // Allocate the lists to fit the new data - { - std::vector new_cluster_sizes(n_clusters); - std::vector old_cluster_sizes(n_clusters); - copy(new_cluster_sizes.data(), list_sizes, n_clusters, stream); - copy(old_cluster_sizes.data(), orig_list_sizes.data(), n_clusters, stream); - resource::sync_stream(handle); - for (uint32_t label = 0; label < n_clusters; label++) { - ivf::resize_list( - handle, index->lists()[label], spec, new_cluster_sizes[label], old_cluster_sizes[label]); - } - } - - 
// Update the pointers and the sizes - recompute_internal_state(handle, *index); - - // Recover old cluster sizes: they are used as counters in the fill-codes kernel - copy(list_sizes, orig_list_sizes.data(), n_clusters, stream); - - // By this point, the index state is updated and valid except it doesn't contain the new data - // Fill the extended index with the new data (possibly, in batches) - utils::batch_load_iterator idx_batches( - new_indices, n_rows, 1, max_batch_size, stream, batches_mr); - for (const auto& vec_batch : vec_batches) { - const auto& idx_batch = *idx_batches++; - process_and_fill_codes(handle, - *index, - vec_batch.data(), - new_indices != nullptr - ? std::variant(idx_batch.data()) - : std::variant(IdxT(idx_batch.offset())), - new_data_labels.data() + vec_batch.offset(), - IdxT(vec_batch.size()), - batches_mr); - } -} - -/** - * Create a new index that contains more data. - * See cuvs::spatial::knn::ivf_pq::extend docs. - */ -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - auto ext_index = clone(handle, orig_index); - detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); - return ext_index; -} - -/** See cuvs::spatial::knn::ivf_pq::build docs */ -template -auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - raft::common::nvtx::range fun_scope( - "ivf_pq::build(%zu, %u)", size_t(n_rows), dim); - resource::detail::warn_non_pool_workspace(handle, "raft::ivf_pq::build"); - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "Unsupported data type"); - - RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); - RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); - - auto stream = resource::get_cuda_stream(handle); - - index index(handle, params, dim); - utils::memzero( - 
index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); - utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); - utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); - - { - auto trainset_ratio = std::max( - 1, - size_t(n_rows) / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); - size_t n_rows_train = n_rows / trainset_ratio; - - auto* device_memory = resource::get_workspace_resource(handle); - rmm::mr::managed_memory_resource managed_memory_upstream; - rmm::mr::pool_memory_resource managed_memory( - &managed_memory_upstream, 1024 * 1024); - - // If the trainset is small enough to comfortably fit into device memory, put it there. - // Otherwise, use the managed memory. - constexpr size_t kTolerableRatio = 4; - rmm::mr::device_memory_resource* big_memory_resource = &managed_memory; - if (sizeof(float) * n_rows_train * index.dim() * kTolerableRatio < - resource::get_workspace_free_bytes(handle)) { - big_memory_resource = device_memory; - } - - // Besides just sampling, we transform the input dataset into floats to make it easier - // to use gemm operations from cublas. 
- rmm::device_uvector trainset(n_rows_train * index.dim(), stream, big_memory_resource); - // TODO: a proper sampling - if constexpr (std::is_same_v) { - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - } else { - size_t dim = index.dim(); - cudaPointerAttributes dataset_attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&dataset_attr, dataset)); - if (dataset_attr.devicePointer != nullptr) { - // data is available on device: just run the kernel to copy and map the data - auto p = reinterpret_cast(dataset_attr.devicePointer); - auto trainset_view = - raft::make_device_vector_view(trainset.data(), dim * n_rows_train); - linalg::map_offset(handle, trainset_view, [p, trainset_ratio, dim] __device__(size_t i) { - auto col = i % dim; - return utils::mapping{}(p[(i - col) * size_t(trainset_ratio) + col]); - }); - } else { - // data is not available: first copy, then map inplace - auto trainset_tmp = reinterpret_cast(reinterpret_cast(trainset.data()) + - (sizeof(float) - sizeof(T)) * index.dim()); - // We copy the data in strides, one row at a time, and place the smaller rows of type T - // at the end of float rows. - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset_tmp, - sizeof(float) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - // Transform the input `{T -> float}`, one row per warp. - // The threads in each warp copy the data synchronously; this and the layout of the data - // (content is aligned to the end of the rows) together allow doing the transform in-place. - copy_warped(trainset.data(), - index.dim(), - trainset_tmp, - index.dim() * sizeof(float) / sizeof(T), - index.dim(), - n_rows_train, - stream); - } - } - - // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, - // dim_ext]! 
- rmm::device_uvector cluster_centers_buf( - index.n_lists() * index.dim(), stream, device_memory); - auto cluster_centers = cluster_centers_buf.data(); - - // Train balanced hierarchical kmeans clustering - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); - auto centers_view = - raft::make_device_matrix_view(cluster_centers, index.n_lists(), index.dim()); - cuvs::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = index.metric(); - cuvs::cluster::kmeans_balanced::fit( - handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); - - // Trainset labels are needed for training PQ codebooks - rmm::device_uvector labels(n_rows_train, stream, big_memory_resource); - auto centers_const_view = raft::make_device_matrix_view( - cluster_centers, index.n_lists(), index.dim()); - auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); - cuvs::cluster::kmeans_balanced::predict(handle, - kmeans_params, - trainset_const_view, - centers_const_view, - labels_view, - utils::mapping()); - - // Make rotation matrix - make_rotation_matrix(handle, - params.force_random_rotation, - index.rot_dim(), - index.dim(), - index.rotation_matrix().data_handle()); - - set_centers(handle, &index, cluster_centers); - - // Train PQ codebooks - switch (index.codebook_kind()) { - case codebook_gen::PER_SUBSPACE: - train_per_subset(handle, - index, - n_rows_train, - trainset.data(), - labels.data(), - params.kmeans_n_iters, - &managed_memory); - break; - case codebook_gen::PER_CLUSTER: - train_per_cluster(handle, - index, - n_rows_train, - trainset.data(), - labels.data(), - params.kmeans_n_iters, - &managed_memory); - break; - default: RAFT_FAIL("Unreachable code"); - } - } - - // add the data if necessary - if (params.add_data_on_build) { - detail::extend(handle, &index, dataset, nullptr, n_rows); - } - return index; -} -} // namespace 
cuvs::neighbors::ivf_pq::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_codepacking.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_codepacking.cuh deleted file mode 100644 index bbd47baa0..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_codepacking.cuh +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include - -#include - -namespace cuvs::neighbors::ivf_pq::detail { - -/** A chunk of PQ-encoded vector managed by one CUDA thread. */ -using pq_vec_t = raft::TxN_t::io_t; - -/** - * This type mimics the `uint8_t&` for the indexing operator of `bitfield_view_t`. - * - * @tparam Bits number of bits comprising the value. 
- */ -template -struct bitfield_ref_t { - static_assert(Bits <= 8 && Bits > 0, "Bit code must fit one byte"); - constexpr static uint8_t kMask = static_cast((1u << Bits) - 1u); - uint8_t* ptr; - uint32_t offset; - - constexpr operator uint8_t() // NOLINT - { - auto pair = static_cast(ptr[0]); - if (offset + Bits > 8) { pair |= static_cast(ptr[1]) << 8; } - return static_cast((pair >> offset) & kMask); - } - - constexpr auto operator=(uint8_t code) -> bitfield_ref_t& - { - if (offset + Bits > 8) { - auto pair = static_cast(ptr[0]); - pair |= static_cast(ptr[1]) << 8; - pair &= ~(static_cast(kMask) << offset); - pair |= static_cast(code) << offset; - ptr[0] = static_cast(Pow2<256>::mod(pair)); - ptr[1] = static_cast(Pow2<256>::div(pair)); - } else { - ptr[0] = (ptr[0] & ~(kMask << offset)) | (code << offset); - } - return *this; - } -}; - -/** - * View a byte array as an array of unsigned integers of custom small bit size. - * - * @tparam Bits number of bits comprising a single element of the array. - */ -template -struct bitfield_view_t { - static_assert(Bits <= 8 && Bits > 0, "Bit code must fit one byte"); - uint8_t* raw; - - constexpr auto operator[](uint32_t i) -> bitfield_ref_t - { - uint32_t bit_offset = i * Bits; - return bitfield_ref_t{raw + raft::Pow2<8>::div(bit_offset), - raft::Pow2<8>::mod(bit_offset)}; - } -}; - -/** - * Process a single vector in a list. - * - * @tparam PqBits - * @tparam Action tells how to process a single vector (e.g. reconstruct or just unpack) - * - * @param[in] in_list_data the encoded cluster data. - * @param[in] in_ix in-cluster index of the vector to be decoded (one-per-thread). - * @param[in] out_ix the output index passed to the action - * @param[in] pq_dim - * @param action a callable action to be invoked on each PQ code (component of the encoding) - * type: void (uint8_t code, uint32_t out_ix, uint32_t j), where j = [0..pq_dim). 
- */ -template -__device__ void run_on_vector( - raft::device_mdspan::list_extents, raft::row_major> - in_list_data, - uint32_t in_ix, - uint32_t out_ix, - uint32_t pq_dim, - Action action) -{ - using group_align = raft::Pow2; - const uint32_t group_ix = group_align::div(in_ix); - const uint32_t ingroup_ix = group_align::mod(in_ix); - - pq_vec_t code_chunk; - bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; - constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; - for (uint32_t j = 0, i = 0; j < pq_dim; i++) { - // read the chunk - code_chunk = *reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)); - // read the codes, one/pq_dim at a time -#pragma unroll - for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { - // read a piece of the reconstructed vector - action(code_view[k], out_ix, j); - } - } -} - -/** - * Process a single vector in a list. - * - * @tparam PqBits - * @tparam SubWarpSize how many threads work on the same ix (only the first thread writes data). - * @tparam IdxT type of the index passed to the action - * @tparam Action tells how to process a single vector (e.g. encode or just pack) - * - * @param[in] out_list_data the encoded cluster data. - * @param[in] out_ix in-cluster index of the vector to be processed (one-per-SubWarpSize threads). - * @param[in] in_ix the input index passed to the action (one-per-SubWarpSize threads). - * @param[in] pq_dim - * @param action a callable action to be invoked on each PQ code (component of the encoding) - * type: (uint32_t in_ix, uint32_t j) -> uint8_t, where j = [0..pq_dim). 
- */ -template -__device__ void write_vector( - raft::device_mdspan::list_extents, raft::row_major> - out_list_data, - uint32_t out_ix, - IdxT in_ix, - uint32_t pq_dim, - Action action) -{ - const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); - - using group_align = raft::Pow2; - const uint32_t group_ix = group_align::div(out_ix); - const uint32_t ingroup_ix = group_align::mod(out_ix); - - pq_vec_t code_chunk; - bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; - constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; - for (uint32_t j = 0, i = 0; j < pq_dim; i++) { - // clear the chunk - if (lane_id == 0) { code_chunk = pq_vec_t{}; } - // write the codes, one/pq_dim at a time -#pragma unroll - for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { - // write a single code - uint8_t code = action(in_ix, j); - if (lane_id == 0) { code_view[k] = code; } - } - // write the chunk to the list - if (lane_id == 0) { - *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; - } - } -} - -/** Process the given indices or a block of a single list (cluster). */ -template -__device__ void run_on_list( - raft::device_mdspan::list_extents, raft::row_major> - in_list_data, - std::variant offset_or_indices, - uint32_t len, - uint32_t pq_dim, - Action action) -{ - for (uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; ix < len; ix += blockDim.x) { - const uint32_t src_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + ix - : std::get(offset_or_indices)[ix]; - run_on_vector(in_list_data, src_ix, ix, pq_dim, action); - } -} - -/** Process the given indices or a block of a single list (cluster). 
*/ -template -__device__ void write_list( - raft::device_mdspan::list_extents, raft::row_major> - out_list_data, - std::variant offset_or_indices, - uint32_t len, - uint32_t pq_dim, - Action action) -{ - using subwarp_align = raft::Pow2; - uint32_t stride = subwarp_align::div(blockDim.x); - uint32_t ix = subwarp_align::div(threadIdx.x + blockDim.x * blockIdx.x); - for (; ix < len; ix += stride) { - const uint32_t dst_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + ix - : std::get(offset_or_indices)[ix]; - write_vector(out_list_data, dst_ix, ix, pq_dim, action); - } -} - -} // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity-ext.cuh deleted file mode 100644 index 26fd7e493..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // __half -#include // cuvs::distance::DistanceType -#include // cuvs::neighbors::ivf_pq::detail::fp_8bit -#include // cuvs::neighbors::ivf_pq::codebook_gen -#include // none_ivf_sample_filter -#include // RAFT_WEAK_FUNCTION -#include // RAFT_EXPLICIT -#include // rmm::cuda_stream_view - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::neighbors::ivf_pq::detail { - -// is_local_topk_feasible is not inline here, because we would have to define it -// here as well. That would run the risk of the definitions here and in the -// -inl.cuh header diverging. -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) - -> bool; - -template -RAFT_KERNEL compute_similarity_kernel(uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) RAFT_EXPLICIT; - -// The signature of the kernel defined by a minimal set of template parameters -template -using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); - -template -struct selected { - compute_similarity_kernel_t kernel; - dim3 grid_dim; - dim3 block_dim; - size_t smem_size; - size_t device_lut_size; -}; - -template -void compute_similarity_run(selected s, - rmm::cuda_stream_view stream, - uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - 
const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) RAFT_EXPLICIT; - -/** - * Use heuristics to choose an optimal instance of the search kernel. - * It selects among a few kernel variants (with/out using shared mem for - * lookup tables / precomputed distances) and tries to choose the block size - * to maximize kernel occupancy. - * - * @param manage_local_topk - * whether use the fused calculate+select or just calculate the distances for each - * query and probed cluster. - * - * @param locality_hint - * beyond this limit do not consider increasing the number of active blocks per SM - * would improve locality anymore. - */ -template -auto compute_similarity_select(const cudaDeviceProp& dev_props, - bool manage_local_topk, - int locality_hint, - double preferred_shmem_carveout, - uint32_t pq_bits, - uint32_t pq_dim, - uint32_t precomp_data_count, - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) - -> selected RAFT_EXPLICIT; - -} // namespace cuvs::neighbors::ivf_pq::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - extern template auto \ - cuvs::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->cuvs::neighbors::ivf_pq::detail::selected; \ - \ - extern template void \ - cuvs::neighbors::ivf_pq::detail::compute_similarity_run( \ - cuvs::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, 
\ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - cuvs::distance::DistanceType metric, \ - cuvs::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - half, - cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - half, - cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - float, - cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA 
cuvs::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - cuvs::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - cuvs::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA cuvs::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity-inl.cuh deleted file mode 100644 index c5c1be45c..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ /dev/null @@ -1,940 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // cuvs::distance::DistanceType -#include // dummy_block_sort_t -#include // codebook_gen -#include // none_ivf_sample_filter -#include // raft::matrix::detail::select::warpsort::warp_sort_distributed -#include // RAFT_CUDA_TRY -#include // raft::atomicMin -#include // raft::Pow2 -#include // raft::TxN_t -#include // rmm::cuda_stream_view - -namespace cuvs::neighbors::ivf_pq::detail { - -/** - * Maximum value of k for the fused calculate & select in ivfpq. - * - * If runtime value of k is larger than this, the main search operation - * is split into two kernels (per batch, first calculate distance, then select top-k). 
- */ -static constexpr int kMaxCapacity = 128; -static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), - "kMaxCapacity must be a power of two, not smaller than the raft::WarpSize."); - -// using weak attribute here, because it may be compiled multiple times. -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) - -> bool -{ - if (k > kMaxCapacity) { return false; } // warp_sort not possible - if (n_queries * n_probes <= 16) { return false; } // overall amount of work is too small - return true; -} - -template -struct pq_block_sort { - using type = raft::matrix::detail::select::warpsort::block_sort< - raft::matrix::detail::select::warpsort::warp_sort_distributed_ext, - Capacity, - true, - T, - IdxT>; - - static auto get_mem_required(uint32_t k_max) - { - if (k_max == 0 || k_max > Capacity) { - return pq_block_sort<0, T, IdxT>::get_mem_required(k_max); - } - if constexpr (Capacity > 1) { - if (k_max * 2 <= Capacity) { - return pq_block_sort<(Capacity / 2), T, IdxT>::get_mem_required(k_max); - } - } - return type::queue_t::mem_required; - } -}; - -template -struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { - using type = dummy_block_sort_t; - static auto mem_required(uint32_t) -> size_t { return 0; } - static auto get_mem_required(uint32_t) { return mem_required; } -}; - -template -using block_sort_t = typename pq_block_sort::type; - -/** - * Estimate a carveout value as expected by `cudaFuncAttributePreferredSharedMemoryCarveout` - * (which does not take into account `reservedSharedMemPerBlock`), - * given by a desired schmem-L1 split and a per-block memory requirement in bytes. - * - * NB: As per the programming guide, the memory carveout setting is just a hint for the driver; it's - * free to choose any shmem-L1 configuration it deems appropriate. For example, if you set the - * carveout to zero, it will choose a non-zero config that will allow to run at least one active - * block per SM. 
- * - * @param shmem_fraction - * a fraction representing a desired split (shmem / (shmem + L1)) [0, 1]. - * @param shmem_per_block - * a shared memory usage per block (dynamic + static shared memory sizes), in bytes. - * @param dev_props - * device properties. - * @return - * a carveout value in percents [0, 100]. - */ -constexpr inline auto estimate_carveout(double shmem_fraction, - size_t shmem_per_block, - const cudaDeviceProp& dev_props) -> int -{ - using shmem_unit = raft::Pow2<128>; - size_t m = shmem_unit::roundUp(shmem_per_block); - size_t r = dev_props.reservedSharedMemPerBlock; - size_t s = dev_props.sharedMemPerMultiprocessor; - return (size_t(100 * s * m * shmem_fraction) - (m - 1) * r) / (s * (m + r)); -} - -/* Manually unrolled loop over a chunk of pq_dataset that fits into one VecT. */ -template -__device__ __forceinline__ void ivfpq_compute_chunk(OutT& score /* NOLINT */, - typename VecT::math_t& pq_code, - const VecT& pq_codes, - const LutT*& lut_head, - const LutT*& lut_end) -{ - if constexpr (CheckBounds) { - if (lut_head >= lut_end) { return; } - } - constexpr uint32_t kTotalBits = 8 * sizeof(typename VecT::math_t); - constexpr uint32_t kPqShift = 1u << PqBits; - constexpr uint32_t kPqMask = kPqShift - 1u; - if constexpr (BitsLeft >= PqBits) { - uint8_t code = pq_code & kPqMask; - pq_code >>= PqBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } else if constexpr (Ix < VecT::Ratio) { - uint8_t code = pq_code; - pq_code = pq_codes.val.data[Ix]; - constexpr uint32_t kRemBits = PqBits - BitsLeft; - constexpr uint32_t kRemMask = (1u << kRemBits) - 1u; - code |= (pq_code & kRemMask) << BitsLeft; - pq_code >>= kRemBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk(score, pq_code, pq_codes, lut_head, lut_end); - } -} - -/* Compute the similarity for one vector in the pq_dataset */ -template -__device__ auto 
ivfpq_compute_score(uint32_t pq_dim, - const typename VecT::io_t* pq_head, - const LutT* lut_scores, - OutT early_stop_limit) -> OutT -{ - constexpr uint32_t kChunkSize = sizeof(VecT) * 8u / PqBits; - auto lut_head = lut_scores; - auto lut_end = lut_scores + (pq_dim << PqBits); - VecT pq_codes; - OutT score{0}; - for (; pq_dim >= kChunkSize; pq_dim -= kChunkSize) { - *pq_codes.vectorized_data() = *pq_head; - pq_head += kIndexGroupSize; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - // Early stop when it makes sense (otherwise early_stop_limit is kDummy/infinity). - if (score >= early_stop_limit) { return score; } - } - if (pq_dim > 0) { - *pq_codes.vectorized_data() = *pq_head; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } - return score; -} - -/** - * The main kernel that computes similarity scores across multiple queries and probes. - * When `Capacity > 0`, it also selects top K candidates for each query and probe - * (which need to be merged across probes afterwards). - * - * Each block processes a (query, probe) pair: it calculates the distance between the single query - * vector and all the dataset vector in the cluster that we are probing. - * - * @tparam OutT - * The output type - distances. - * @tparam LutT - * The lookup table element type (lut_scores). - * @tparam PqBits - * The bit length of an encoded vector element after compression by PQ - * (NB: pq_book_size = 1 << PqBits). - * @tparam Capacity - * Power-of-two; the maximum possible `k` in top-k. Value zero disables fused top-k search. - * @tparam PrecompBaseDiff - * Defines whether we should precompute part of the distance and keep it in shared memory - * before the main part (score calculation) to increase memory usage efficiency in the latter. - * For L2, this is the distance between the query and the cluster center. 
- * @tparam EnableSMemLut - * Defines whether to use the shared memory for the lookup table (`lut_scores`). - * Setting this to `false` allows to reduce the shared memory usage (and maximum data dim) - * at the cost of reducing global memory reading throughput. - * - * @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`). - * @param n_probes the number of clusters to search for each query - * @param pq_dim - * The dimensionality of an encoded vector after compression by PQ. - * @param n_queries the number of queries. - * @param queries_offset - * An offset of the current query batch. It is used for feeding sample_filter with the - * correct query index. - * @param metric the distance type. - * @param codebook_kind Defines the way PQ codebooks have been trained. - * @param topk the `k` in the select top-k. - * @param max_samples the size of the output for a single query. - * @param cluster_centers - * The device pointer to the cluster centers in the original space (NB: after rotation) - * [n_clusters, dim]. - * @param pq_centers - * The device pointer to the cluster centers in the PQ space - * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len]. - * @param pq_dataset - * The device pointer to the PQ index (data) [n_rows, ...]. - * @param cluster_labels - * The device pointer to the labels (clusters) for each query and probe [n_queries, n_probes]. - * @param _chunk_indices - * The device pointer to the data offsets for each query and probe [n_queries, n_probes]. - * @param queries - * The device pointer to the queries (NB: after rotation) [n_queries, dim]. - * @param index_list - * An optional device pointer to the enforced order of search [n_queries, n_probes]. - * One can pass reordered indices here to try to improve data reading locality. - * @param query_kth - * query_kths keep the current state of the filtering - atomically updated distances to the - * k-th closest neighbors for each query [n_queries]. 
- * @param sample_filter - * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to - * provide a green light for every sample. - * @param lut_scores - * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. - * Ignored when `EnableSMemLut == true`. - * @param _out_scores - * The device pointer to the output scores - * [n_queries, max_samples] or [n_queries, n_probes, topk]. - * @param _out_indices - * The device pointer to the output indices [n_queries, n_probes, topk]. - * These are the indices of the records as they appear in the database view formed by the probed - * clusters / defined by the `_chunk_indices`. - * The indices can have values within the range [0, max_samples). - * Ignored when `Capacity == 0`. - */ -template -RAFT_KERNEL compute_similarity_kernel(uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) -{ - /* Shared memory: - - * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut) - * lut_end+: - * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 - * topk::warp_sort::mem_required - local topk temporary buffer (if necessary) - * topk::block_sort: some amount of shared memory, but overlaps with the rest: - block_sort only needs shared memory for `.done()` operation, which can come very last. 
- */ - extern __shared__ __align__(256) uint8_t smem_buf[]; // NOLINT - constexpr bool kManageLocalTopK = Capacity > 0; - - constexpr uint32_t PqShift = 1u << PqBits; // NOLINT - constexpr uint32_t PqMask = PqShift - 1u; // NOLINT - - const uint32_t pq_len = dim / pq_dim; - const uint32_t lut_size = pq_dim * PqShift; - - if constexpr (EnableSMemLut) { - lut_scores = reinterpret_cast(smem_buf); - } else { - lut_scores += lut_size * blockIdx.x; - } - - uint8_t* lut_end = nullptr; - if constexpr (EnableSMemLut) { - lut_end = reinterpret_cast(lut_scores + lut_size); - } else { - lut_end = smem_buf; - } - - for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { - if (ib >= gridDim.x) { - // sync shared memory accesses on the second and further iterations - __syncthreads(); - } - uint32_t query_ix; - uint32_t probe_ix; - if (index_list == nullptr) { - query_ix = ib % n_queries; - probe_ix = ib / n_queries; - } else { - auto ordered_ix = index_list[ib]; - query_ix = ordered_ix / n_probes; - probe_ix = ordered_ix % n_probes; - } - - const uint32_t* chunk_indices = _chunk_indices + (n_probes * query_ix); - const float* query = queries + (dim * query_ix); - OutT* out_scores; - uint32_t* out_indices = nullptr; - if constexpr (kManageLocalTopK) { - // Store topk calculated distances to out_scores (and its indices to out_indices) - const uint64_t out_offset = probe_ix + n_probes * query_ix; - out_scores = _out_scores + out_offset * topk; - out_indices = _out_indices + out_offset * topk; - } else { - // Store all calculated distances to out_scores - out_scores = _out_scores + uint64_t(max_samples) * query_ix; - } - uint32_t label = cluster_labels[n_probes * query_ix + probe_ix]; - const float* cluster_center = cluster_centers + dim * label; - const float* pq_center; - if (codebook_kind == codebook_gen::PER_SUBSPACE) { - pq_center = pq_centers; - } else { - pq_center = pq_centers + (pq_len << PqBits) * label; - } - - if constexpr (PrecompBaseDiff) { - // Reduce 
number of memory reads later by pre-computing parts of the score - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - reinterpret_cast(lut_end)[i] = query[i] - cluster_center[i]; - } - } break; - case distance::DistanceType::InnerProduct: { - float2 pvals; - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - pvals.x = query[i]; - pvals.y = cluster_center[i] * pvals.x; - reinterpret_cast(lut_end)[i] = pvals; - } - } break; - default: __builtin_unreachable(); - } - __syncthreads(); - } - - { - // Create a lookup table - // For each subspace, the lookup table stores the distance between the actual query vector - // (projected into the subspace) and all possible pq vectors in that subspace. - for (uint32_t i = threadIdx.x; i < lut_size; i += blockDim.x) { - const uint32_t i_pq = i >> PqBits; - uint32_t j = i_pq * pq_len; - const uint32_t j_end = pq_len + j; - auto cur_pq_center = pq_center + (i & PqMask) + - (codebook_kind == codebook_gen::PER_SUBSPACE ? 
j * PqShift : 0u); - float score = 0.0; - do { - float pq_c = *cur_pq_center; - cur_pq_center += PqShift; - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - float diff; - if constexpr (PrecompBaseDiff) { - diff = reinterpret_cast(lut_end)[j]; - } else { - diff = query[j] - cluster_center[j]; - } - diff -= pq_c; - score += diff * diff; - } break; - case distance::DistanceType::InnerProduct: { - // NB: we negate the scores as we hardcoded select-topk to always compute the minimum - float q; - if constexpr (PrecompBaseDiff) { - float2 pvals = reinterpret_cast(lut_end)[j]; - q = pvals.x; - score -= pvals.y; - } else { - q = query[j]; - score -= q * cluster_center[j]; - } - score -= q * pq_c; - } break; - default: __builtin_unreachable(); - } - } while (++j < j_end); - lut_scores[i] = LutT(score); - } - } - - // Define helper types for efficient access to the pq_dataset, which is stored in an interleaved - // format. The chunks of PQ data are stored in kIndexGroupVecLen-bytes-long chunks, interleaved - // in groups of kIndexGroupSize elems (which is normally equal to the warp size) for the fastest - // possible access by thread warps. - // - // Consider one record in the pq_dataset is `pq_dim * pq_bits`-bit-long. - // Assuming `kIndexGroupVecLen = 16`, one chunk of data read by a thread at once is 128-bits. - // Then, such a chunk contains `chunk_size = 128 / pq_bits` record elements, and the record - // consists of `ceildiv(pq_dim, chunk_size)` chunks. The chunks are interleaved in groups of 32, - // so that the warp can achieve the best coalesced read throughput. 
- using group_align = raft::Pow2; - using vec_align = raft::Pow2; - using local_topk_t = block_sort_t; - using op_t = uint32_t; - using vec_t = raft::TxN_t; - - uint32_t sample_offset = 0; - if (probe_ix > 0) { sample_offset = chunk_indices[probe_ix - 1]; } - uint32_t n_samples = chunk_indices[probe_ix] - sample_offset; - uint32_t n_samples_aligned = group_align::roundUp(n_samples); - constexpr uint32_t kChunkSize = (kIndexGroupVecLen * 8u) / PqBits; - uint32_t pq_line_width = div_rounding_up_unsafe(pq_dim, kChunkSize) * kIndexGroupVecLen; - auto pq_thread_data = pq_dataset[label] + group_align::roundDown(threadIdx.x) * pq_line_width + - group_align::mod(threadIdx.x) * vec_align::Value; - pq_line_width *= blockDim.x; - - constexpr OutT kDummy = raft::upper_bound(); - OutT query_kth = kDummy; - if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); } - OutT early_stop_limit = kDummy; - switch (metric) { - // If the metric is non-negative, we can use the query_kth approximation as an early stop - // threshold to skip some iterations when computing the score. Add such metrics here. 
- case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - early_stop_limit = query_kth; - } break; - default: break; - } - - // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score - __threadfence_block(); - __syncthreads(); - local_topk_t block_topk(topk, lut_end, query_kth); - - // Compute a distance for each sample - for (uint32_t i = threadIdx.x; i < n_samples_aligned; - i += blockDim.x, pq_thread_data += pq_line_width) { - OutT score = kDummy; - bool valid = i < n_samples; - // Check bounds and that the sample is acceptable for the query - if (valid && sample_filter(queries_offset + query_ix, label, i)) { - score = ivfpq_compute_score( - pq_dim, - reinterpret_cast(pq_thread_data), - lut_scores, - early_stop_limit); - } - if constexpr (kManageLocalTopK) { - block_topk.add(score, sample_offset + i); - } else { - if (valid) { out_scores[sample_offset + i] = score; } - } - } - if constexpr (kManageLocalTopK) { - // sync threads before the topk merging operation, because we reuse smem_buf - __syncthreads(); - block_topk.done(smem_buf); - block_topk.store(out_scores, out_indices); - if (threadIdx.x == 0) { atomicMin(query_kths + query_ix, float(out_scores[topk - 1])); } - } else { - // fill in the rest of the out_scores with dummy values - if (probe_ix + 1 == n_probes) { - for (uint32_t i = threadIdx.x + sample_offset + n_samples; i < max_samples; - i += blockDim.x) { - out_scores[i] = kDummy; - } - } - } - } -} - -// The signature of the kernel defined by a minimal set of template parameters -template -using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); - -// The config struct lifts the runtime parameters to the template parameters -template -struct compute_similarity_kernel_config { - public: - static auto get(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t - { - return kernel_choose_bits(pq_bits, k_max); - } - - private: - static auto 
kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t - { - switch (pq_bits) { - case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); - case 5: return kernel_try_capacity<5, kMaxCapacity>(k_max); - case 6: return kernel_try_capacity<6, kMaxCapacity>(k_max); - case 7: return kernel_try_capacity<7, kMaxCapacity>(k_max); - case 8: return kernel_try_capacity<8, kMaxCapacity>(k_max); - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - } - - template - static auto kernel_try_capacity(uint32_t k_max) - -> compute_similarity_kernel_t - { - if constexpr (Capacity > 0) { - if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } - } - if constexpr (Capacity > 1) { - if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } - } - return compute_similarity_kernel; - } -}; - -// A standalone accessor function was necessary to make sure template -// instantiation work correctly. This accessor function is not used anymore and -// may be removed. -template -auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t -{ - return compute_similarity_kernel_config::get(pq_bits, k_max); -} - -/** Estimate the occupancy for the given kernel on the given device. 
*/ -template -struct occupancy_t { - using shmem_unit = raft::Pow2<128>; - - int blocks_per_sm = 0; - double occupancy = 0.0; - double shmem_use = 1.0; - - inline occupancy_t() = default; - inline occupancy_t(size_t smem, - uint32_t n_threads, - compute_similarity_kernel_t kernel, - const cudaDeviceProp& dev_props) - { - RAFT_CUDA_TRY( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, n_threads, smem)); - occupancy = double(blocks_per_sm * n_threads) / double(dev_props.maxThreadsPerMultiProcessor); - shmem_use = double(shmem_unit::roundUp(smem) * blocks_per_sm) / - double(dev_props.sharedMemPerMultiprocessor); - } -}; - -template -struct selected { - compute_similarity_kernel_t kernel; - dim3 grid_dim; - dim3 block_dim; - size_t smem_size; - size_t device_lut_size; -}; - -template -void compute_similarity_run(selected s, - rmm::cuda_stream_view stream, - uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) -{ - s.kernel<<>>(dim, - n_probes, - pq_dim, - n_queries, - queries_offset, - metric, - codebook_kind, - topk, - max_samples, - cluster_centers, - pq_centers, - pq_dataset, - cluster_labels, - _chunk_indices, - queries, - index_list, - query_kths, - sample_filter, - lut_scores, - _out_scores, - _out_indices); - RAFT_CHECK_CUDA(stream); -} - -/** - * Use heuristics to choose an optimal instance of the search kernel. 
- * It selects among a few kernel variants (with/out using shared mem for - * lookup tables / precomputed distances) and tries to choose the block size - * to maximize kernel occupancy. - * - * @param manage_local_topk - * whether use the fused calculate+select or just calculate the distances for each - * query and probed cluster. - * - * @param locality_hint - * beyond this limit do not consider increasing the number of active blocks per SM - * would improve locality anymore. - */ -template -auto compute_similarity_select(const cudaDeviceProp& dev_props, - bool manage_local_topk, - int locality_hint, - double preferred_shmem_carveout, - uint32_t pq_bits, - uint32_t pq_dim, - uint32_t precomp_data_count, - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) -> selected -{ - // Shared memory for storing the lookup table - size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); - // Shared memory for storing pre-computed pieces to speedup the lookup table construction - // (e.g. the distance between a cluster center and the query for L2). - size_t bdf_mem = sizeof(float) * precomp_data_count; - - // Shared memory used by the fused top-k during cluster scanning; - // may overlap with the precomputed distance array - struct ltk_add_mem_t { - size_t (*mem_required)(uint32_t); - - ltk_add_mem_t(bool manage_local_topk, uint32_t topk) - : mem_required(pq_block_sort::get_mem_required( - manage_local_topk ? 
topk : 0)) - { - } - - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return mem_required(n_threads); - } - } ltk_add_mem{manage_local_topk, topk}; - - // Shared memory for the fused top-k component; - // may overlap with all other uses of shared memory - struct ltk_reduce_mem_t { - uint32_t subwarp_size; - uint32_t topk; - bool manage_local_topk; - ltk_reduce_mem_t(bool manage_local_topk, uint32_t topk) - : manage_local_topk(manage_local_topk), topk(topk) - { - subwarp_size = raft::WarpSize; - while (topk * 2 <= subwarp_size) { - subwarp_size /= 2; - } - } - - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return manage_local_topk ? raft::matrix::detail::select::warpsort:: - template calc_smem_size_for_block_wide( - n_threads / subwarp_size, topk) - : 0; - } - } ltk_reduce_mem{manage_local_topk, topk}; - - struct total_shared_mem_t { - ltk_add_mem_t& ltk_add_mem; - ltk_reduce_mem_t& ltk_reduce_mem; - size_t lut_mem; - size_t bdf_mem; - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return std::max(ltk_reduce_mem(n_threads), - lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads))); - } - }; - - // Total amount of work; should be enough to occupy the GPU. - uint32_t n_blocks = n_queries * n_probes; - - // The minimum block size we may want: - // 1. It's a power-of-two for efficient L1 caching of pq_centers values - // (multiples of `1 << pq_bits`). - // 2. It should be large enough to fully utilize an SM. 
- uint32_t n_threads_min = raft::WarpSize; - while (dev_props.maxBlocksPerMultiProcessor * int(n_threads_min) < - dev_props.maxThreadsPerMultiProcessor) { - n_threads_min *= 2; - } - // Further increase the minimum block size to make sure full device occupancy - // (NB: this may lead to `n_threads_min` being larger than the kernel's maximum) - while (int(n_blocks * n_threads_min) < - dev_props.multiProcessorCount * dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - // Even further, increase it to allow less blocks per SM if there not enough queries. - // With this, we reduce the chance of different clusters being processed by two blocks - // on the same SM and thus improve the data locality for L1 caching. - while (int(n_queries * n_threads_min) < dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - - // Granularity of changing the number of threads when computing the maximum block size. - // It's good to have it multiple of the PQ book width. - uint32_t n_threads_gty = raft::round_up_safe(1u << pq_bits, raft::WarpSize); - - /* - Shared memory / L1 cache balance is the main limiter of this kernel. - The more blocks per SM we launch, the more shared memory we need. Besides that, we have - three versions of the kernel varying in performance and shmem usage. - - We try the most demanding and the fastest kernel first, trying to maximize occupancy with - the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further - optimize occupancy and data locality for the L1 cache. - */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; - auto topk_or_zero = manage_local_topk ? 
topk : 0u; - std::array candidates{ - std::make_tuple(conf_fast(pq_bits, topk_or_zero), - total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, bdf_mem}, - true), - std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), - total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, 0}, - true), - std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), - total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, 0, bdf_mem}, - false)}; - - // we may allow slightly lower than 100% occupancy; - constexpr double kTargetOccupancy = 0.75; - // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; - for (auto [kernel, smem_size_f, lut_is_in_shmem] : candidates) { - if (smem_size_f(WarpSize) > dev_props.sharedMemPerBlockOptin) { - // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. - continue; - } - - // First, we set the carveout hint to the preferred value. The driver will increase this if - // needed to run at least one block per SM. At the same time, if more blocks fit into one SM, - // this carveout value will limit the calculated occupancy. When we're done selecting the best - // launch configuration, we will tighten the carveout once more, based on the final memory - // usage and occupancy. - const int max_carveout = - estimate_carveout(preferred_shmem_carveout, smem_size_f(WarpSize), dev_props); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout)); - - // Get the theoretical maximum possible number of threads per block - cudaFuncAttributes kernel_attrs; - RAFT_CUDA_TRY(cudaFuncGetAttributes(&kernel_attrs, kernel)); - uint32_t n_threads = round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty); - - // Actual required shmem depens on the number of threads - size_t smem_size = smem_size_f(n_threads); - - // Make sure the kernel can get enough shmem. 
- cudaError_t cuda_status = - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (cuda_status != cudaSuccess) { - RAFT_EXPECTS( - cuda_status == cudaGetLastError(), - "Tried to reset the expected cuda error code, but it didn't match the expectation"); - // Failed to request enough shmem for the kernel. Skip the candidate. - continue; - } - - occupancy_t cur(smem_size, n_threads, kernel, dev_props); - if (cur.blocks_per_sm <= 0) { - // For some reason, we still cannot make this kernel run. Skip the candidate. - continue; - } - - { - // Try to reduce the number of threads to increase occupancy and data locality - auto n_threads_tmp = n_threads_min; - while (n_threads_tmp * 2 < n_threads) { - n_threads_tmp *= 2; - } - if (n_threads_tmp < n_threads) { - while (n_threads_tmp >= n_threads_min) { - auto smem_size_tmp = smem_size_f(n_threads_tmp); - occupancy_t tmp( - smem_size_tmp, n_threads_tmp, kernel, dev_props); - bool select_it = false; - if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { - // Normally, the smaller the block the better for L1 cache hit rate. - // Hence, the occupancy should be "just good enough" - select_it = tmp.occupancy >= min(kTargetOccupancy, cur.occupancy); - } else if (lut_is_in_shmem) { - // If we don't have enough repeating probes (locality_hint < tmp.blocks_per_sm), - // the locality is not going to improve with increasing the number of blocks per SM. - // Hence, the only metric here is the occupancy. 
- bool improves_occupancy = tmp.occupancy > cur.occupancy; - // Otherwise, the performance still improves with a smaller block size, - // given there is enough work to do - bool improves_parallelism = - tmp.occupancy == cur.occupancy && - 7u * tmp.blocks_per_sm * dev_props.multiProcessorCount <= n_blocks; - select_it = improves_occupancy || improves_parallelism; - } else { - // If we don't use shared memory for the lookup table, increasing the number of blocks - // is very taxing on the global memory usage. - // In this case, the occupancy must increase a lot to make it worth the cost. - select_it = tmp.occupancy >= min(1.0, cur.occupancy / kTargetOccupancy); - } - if (select_it) { - n_threads = n_threads_tmp; - smem_size = smem_size_tmp; - cur = tmp; - } - n_threads_tmp /= 2; - } - } - } - - { - if (selected_perf.occupancy <= 0.0 // no candidate yet - || (selected_perf.occupancy < cur.occupancy * kTargetOccupancy && - selected_perf.shmem_use >= cur.shmem_use) // much improved occupancy - ) { - selected_perf = cur; - if (lut_is_in_shmem) { - selected_config = { - kernel, dim3(n_blocks, 1, 1), dim3(n_threads, 1, 1), smem_size, size_t(0)}; - } else { - // When the global memory is used for the lookup table, we need to minimize the grid - // size; otherwise, the kernel may quickly run out of memory. - auto n_blocks_min = - std::min(n_blocks, cur.blocks_per_sm * dev_props.multiProcessorCount); - selected_config = {kernel, - dim3(n_blocks_min, 1, 1), - dim3(n_threads, 1, 1), - smem_size, - size_t(n_blocks_min) * size_t(pq_dim << pq_bits)}; - } - // Actual shmem/L1 split wildly rounds up the specified preferred carveout, so we set here - // a rather conservative bar; most likely, the kernel gets more shared memory than this, - // and the occupancy doesn't get hurt. 
- auto carveout = std::min(max_carveout, std::ceil(100.0 * cur.shmem_use)); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, carveout)); - if (cur.occupancy >= kTargetOccupancy) { break; } - } else if (selected_perf.occupancy > 0.0) { - // If we found a reasonable candidate on a previous iteration, and this one is not better, - // then don't try any more candidates because they are much slower anyway. - break; - } - } - } - - RAFT_EXPECTS(selected_perf.occupancy > 0.0, - "Couldn't determine a working kernel launch configuration."); - - return selected_config; -} - -} // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity.cuh deleted file mode 100644 index d987c0d4e..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_compute_similarity.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) -#include "ivf_pq_compute_similarity-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_pq_compute_similarity-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_dummy_block_sort.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_dummy_block_sort.cuh deleted file mode 100644 index 8732aed3e..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_dummy_block_sort.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::matrix::detail::select::warpsort::warp_sort_distributed - -/* - * This header file is a bit of an ugly duckling. The type dummy_block_sort is - * needed by both ivf_pq_search.cuh and ivf_pq_compute_similarity.cuh. - * - * I have decided to move it to it's own header file, which is overkill. Perhaps - * there is a nicer solution. 
- * - */ - -namespace cuvs::neighbors::ivf_pq::detail { - -template -struct dummy_block_sort_t { - using queue_t = - raft::matrix::detail::select::warpsort::warp_sort_distributed; - template - __device__ dummy_block_sort_t(int k, Args...){}; -}; - -} // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_fp_8bit.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_fp_8bit.cuh deleted file mode 100644 index d574dbde3..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_fp_8bit.cuh +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -#include - -#include - -namespace cuvs::neighbors::ivf_pq::detail { - -/** 8-bit floating-point storage type. - * - * This is a custom type for the current IVF-PQ implementation. No arithmetic operations defined - * only conversion to and from fp32. This type is unrelated to the proposed FP8 specification. 
- */ -template -struct fp_8bit { - static_assert(ExpBits + uint8_t{Signed} <= 8, "The type does not fit in 8 bits."); - constexpr static uint32_t ExpMask = (1u << (ExpBits - 1u)) - 1u; // NOLINT - constexpr static uint32_t ValBits = 8u - ExpBits; // NOLINT - - public: - uint8_t bitstring; - - HDI explicit fp_8bit(uint8_t bs) : bitstring(bs) {} - HDI explicit fp_8bit(float fp) : fp_8bit(float2fp_8bit(fp).bitstring) {} - HDI auto operator=(float fp) -> fp_8bit& - { - bitstring = float2fp_8bit(fp).bitstring; - return *this; - } - HDI explicit operator float() const { return fp_8bit2float(*this); } - HDI explicit operator half() const { return fp_8bit2half(*this); } - - private: - static constexpr float kMin = 1.0f / float(1u << ExpMask); - static constexpr float kMax = float(1u << (ExpMask + 1)) * (2.0f - 1.0f / float(1u << ValBits)); - - static HDI auto float2fp_8bit(float v) -> fp_8bit - { - if constexpr (Signed) { - auto u = fp_8bit(std::abs(v)).bitstring; - u = (u & 0xfeu) | uint8_t{v < 0}; // set the sign bit - return fp_8bit(u); - } else { - // sic! all small and negative numbers are truncated to zero. 
- if (v < kMin) { return fp_8bit{static_cast(0)}; } - // protect from overflow - if (v >= kMax) { return fp_8bit{static_cast(0xffu)}; } - // the rest of possible float values should be within the normalized range - return fp_8bit{static_cast( - (*reinterpret_cast(&v) + (ExpMask << 23u) - 0x3f800000u) >> (15u + ExpBits))}; - } - } - - static HDI auto fp_8bit2float(const fp_8bit& v) -> float - { - uint32_t u = v.bitstring; - if constexpr (Signed) { - u &= ~1; // zero the sign bit - } - float r; - constexpr uint32_t kBase32 = (0x3f800000u | (0x00400000u >> ValBits)) - (ExpMask << 23); - *reinterpret_cast(&r) = kBase32 + (u << (15u + ExpBits)); - if constexpr (Signed) { // recover the sign bit - if (v.bitstring & 1) { r = -r; } - } - return r; - } - - static HDI auto fp_8bit2half(const fp_8bit& v) -> half - { - uint16_t u = v.bitstring; - if constexpr (Signed) { - u &= ~1; // zero the sign bit - } - half r; - constexpr uint16_t kBase16 = (0x3c00u | (0x0200u >> ValBits)) - (ExpMask << 10); - *reinterpret_cast(&r) = kBase16 + (u << (2u + ExpBits)); - if constexpr (Signed) { // recover the sign bit - if (v.bitstring & 1) { r = -r; } - } - return r; - } -}; - -} // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_search.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_search.cuh deleted file mode 100644 index fa6f64c7b..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_search.cuh +++ /dev/null @@ -1,860 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -#include - -#include - -namespace cuvs::neighbors::ivf_pq::detail { - -using namespace cuvs::spatial::knn::detail; // NOLINT - -/** - * Select the clusters to probe and, as a side-effect, translate the queries type `T -> float` - * - * Assuming the number of clusters is not that big (a few thousands), we do a plain GEMM - * followed by select_k to select the clusters to probe. There's no need to return the similarity - * scores here. - */ -template -void select_clusters(raft::resources const& handle, - uint32_t* clusters_to_probe, // [n_queries, n_probes] - float* float_queries, // [n_queries, dim_ext] - uint32_t n_queries, - uint32_t n_probes, - uint32_t n_lists, - uint32_t dim, - uint32_t dim_ext, - cuvs::distance::DistanceType metric, - const T* queries, // [n_queries, dim] - const float* cluster_centers, // [n_lists, dim_ext] - rmm::mr::device_memory_resource* mr) -{ - auto stream = resource::get_cuda_stream(handle); - /* NOTE[qc_distances] - - We compute query-center distances to choose the clusters to probe. - We accomplish that with just one GEMM operation thanks to some preprocessing: - - L2 distance: - cluster_centers[i, dim()] contains the squared norm of the center vector i; - we extend the dimension K of the GEMM to compute it together with all the dot products: - - `qc_distances[i, j] = |cluster_centers[j]|^2 - 2 * (queries[i], cluster_centers[j])` - - This is a monotonous mapping of the proper L2 distance. 
- - IP distance: - `qc_distances[i, j] = - (queries[i], cluster_centers[j])` - - This is a negative inner-product distance. We minimize it to find the similar clusters. - - NB: qc_distances is NOT used further in ivfpq_search. - */ - float norm_factor; - switch (metric) { - case cuvs::distance::DistanceType::L2SqrtExpanded: - case cuvs::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break; - case cuvs::distance::DistanceType::InnerProduct: norm_factor = 0.0; break; - default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); - } - auto float_queries_view = - raft::make_device_vector_view(float_queries, dim_ext * n_queries); - linalg::map_offset( - handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { - uint32_t col = ix % dim_ext; - uint32_t row = ix / dim_ext; - return col < dim ? utils::mapping{}(queries[col + dim * row]) : norm_factor; - }); - - float alpha; - float beta; - uint32_t gemm_k = dim; - switch (metric) { - case cuvs::distance::DistanceType::L2SqrtExpanded: - case cuvs::distance::DistanceType::L2Expanded: { - alpha = -2.0; - beta = 0.0; - gemm_k = dim + 1; - RAFT_EXPECTS(gemm_k <= dim_ext, "unexpected gemm_k or dim_ext"); - } break; - case cuvs::distance::DistanceType::InnerProduct: { - alpha = -1.0; - beta = 0.0; - } break; - default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); - } - rmm::device_uvector qc_distances(n_queries * n_lists, stream, mr); - linalg::gemm(handle, - true, - false, - n_lists, - n_queries, - gemm_k, - &alpha, - cluster_centers, - dim_ext, - float_queries, - dim_ext, - &beta, - qc_distances.data(), - n_lists, - stream); - - // Select neighbor clusters for each query. 
- rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); - raft::matrix::detail::select_k(handle, - qc_distances.data(), - nullptr, - n_queries, - n_lists, - n_probes, - cluster_dists.data(), - clusters_to_probe, - true, - mr); -} - -/** - * For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that - * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total - * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples. - */ -template -__launch_bounds__(BlockDim) RAFT_KERNEL - calc_chunk_indices_kernel(uint32_t n_probes, - const uint32_t* cluster_sizes, // [n_clusters] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t* n_samples // [n_queries] - ) -{ - using block_scan = cub::BlockScan; - __shared__ typename block_scan::TempStorage shm; - - // locate the query data - clusters_to_probe += n_probes * blockIdx.x; - chunk_indices += n_probes * blockIdx.x; - - // block scan - const uint32_t n_probes_aligned = raft::Pow2::roundUp(n_probes); - uint32_t total = 0; - for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) { - auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u; - auto chunk = probe_ix < n_probes ? 
cluster_sizes[label] : 0u; - if (threadIdx.x == 0) { chunk += total; } - block_scan(shm).InclusiveSum(chunk, chunk, total); - __syncthreads(); - if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; } - } - // save the total size - if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; } -} - -struct calc_chunk_indices { - public: - struct configured { - void* kernel; - dim3 block_dim; - dim3 grid_dim; - uint32_t n_probes; - - inline void operator()(const uint32_t* cluster_sizes, - const uint32_t* clusters_to_probe, - uint32_t* chunk_indices, - uint32_t* n_samples, - rmm::cuda_stream_view stream) - { - void* args[] = // NOLINT - {&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples}; - RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream)); - } - }; - - static inline auto configure(uint32_t n_probes, uint32_t n_queries) -> configured - { - return try_block_dim<1024>(n_probes, n_queries); - } - - private: - template - static auto try_block_dim(uint32_t n_probes, uint32_t n_queries) -> configured - { - if constexpr (BlockDim >= raft::WarpSize * 2) { - if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); } - } - return {reinterpret_cast(calc_chunk_indices_kernel), - dim3(BlockDim, 1, 1), - dim3(n_queries, 1, 1), - n_probes}; - } -}; - -/** - * Look up the chunk id corresponding to the sample index. - * - * Each query vector was compared to all the vectors from n_probes clusters, and sample_ix is an - * ordered number of one of such vectors. This function looks up to which chunk it belongs, - * and returns the index within the chunk (which is also an index within a cluster). - * - * @param[inout] sample_ix - * input: the offset of the sample in the batch; - * output: the offset inside the chunk (probe) / selected cluster. 
- * @param[in] n_probes number of probes - * @param[in] chunk_indices offsets of the chunks within the batch [n_probes] - * @return chunk index (== n_probes when the input index is not in the valid range, - * which can happen if there is not enough data to output in the selected clusters). - */ -__device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT - uint32_t n_probes, - const uint32_t* chunk_indices) -> uint32_t -{ - uint32_t ix_min = 0; - uint32_t ix_max = n_probes; - do { - uint32_t i = (ix_min + ix_max) / 2; - if (chunk_indices[i] <= sample_ix) { - ix_min = i + 1; - } else { - ix_max = i; - } - } while (ix_min < ix_max); - if (ix_min > 0) { sample_ix -= chunk_indices[ix_min - 1]; } - return ix_min; -} - -template -__launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] - const uint32_t* neighbors_in, // [n_queries, topk] - const IdxT* const* db_indices, // [n_clusters][..] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) -{ - const uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x); - const uint32_t query_ix = i / uint64_t(topk); - if (query_ix >= n_queries) { return; } - const uint32_t k = i % uint64_t(topk); - neighbors_in += query_ix * topk; - neighbors_out += query_ix * topk; - chunk_indices += query_ix * n_probes; - clusters_to_probe += query_ix * n_probes; - uint32_t data_ix = neighbors_in[k]; - const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); - const bool valid = chunk_ix < n_probes; - neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : ivf_pq::kOutOfBoundsRecord; -} - -/** - * Transform found sample indices into the corresponding database indices - * (as stored in index.indices()). 
- * The sample indices are the record indices as they appear in the database view formed by the - * probed clusters / defined by the `chunk_indices`. - * We assume the searched sample sizes (for a single query) fit into `uint32_t`. - */ -template -void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] - const uint32_t* neighbors_in, // [n_queries, topk] - const IdxT* const* db_indices, // [n_clusters][..] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk, - rmm::cuda_stream_view stream) -{ - constexpr int kPNThreads = 256; - const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel - <<>>(neighbors_out, - neighbors_in, - db_indices, - clusters_to_probe, - chunk_indices, - n_queries, - n_probes, - topk); -} - -/** - * Post-process the scores depending on the metric type; - * translate the element type if necessary. 
- */ -template -void postprocess_distances(float* out, // [n_queries, topk] - const ScoreT* in, // [n_queries, topk] - distance::DistanceType metric, - uint32_t n_queries, - uint32_t topk, - float scaling_factor, - rmm::cuda_stream_view stream) -{ - size_t len = size_t(n_queries) * size_t(topk); - switch (metric) { - case distance::DistanceType::L2Unexpanded: - case distance::DistanceType::L2Expanded: { - linalg::unaryOp(out, - in, - len, - raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, - raft::cast_op{}), - stream); - } break; - case distance::DistanceType::L2SqrtUnexpanded: - case distance::DistanceType::L2SqrtExpanded: { - linalg::unaryOp( - out, - in, - len, - raft::compose_op{ - raft::mul_const_op{scaling_factor}, raft::sqrt_op{}, raft::cast_op{}}, - stream); - } break; - case distance::DistanceType::InnerProduct: { - linalg::unaryOp(out, - in, - len, - raft::compose_op(raft::mul_const_op{-scaling_factor * scaling_factor}, - raft::cast_op{}), - stream); - } break; - default: RAFT_FAIL("Unexpected metric."); - } -} - -/** - * An approximation to the number of times each cluster appears in a batched sample. - * - * If the pairs (probe_ix, query_ix) are sorted by the probe_ix, there is a good chance that - * the same probe_ix (cluster) is processed by several blocks on a single SM. This greatly - * increases the L1 cache hit rate (i.e. increases the data locality). - * - * This function gives an estimate of how many times a specific cluster may appear in the - * batch. Thus, it gives a practical limit to how many blocks should be active on the same SM - * to improve the L1 cache hit rate. - */ -constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, - uint32_t n_probes, - uint32_t n_queries) -> uint32_t -{ - /* - Let say: - n = n_clusters - k = n_probes - m = n_queries - r = # of times a specific block appears in the batched sample. 
- - Then, r has the Binomial distribution (p = k / n): - P(r) = C(m,r) * k^r * (n - k)^(m - r) / n^m - E[r] = m * k / n - E[r | r > 0] = m * k / n / (1 - (1 - k/n)^m) - - The latter can be approximated by a much simpler formula, assuming (k / n) -> 0: - E[r | r > 0] = 1 + (m - 1) * k / (2 * n) + O( (k/n)^2 ) - */ - return 1 + (n_queries - 1) * n_probes / (2 * n_clusters); -} - -/** - * The "main part" of the search, which assumes that outer-level `search` has already: - * - * 1. computed the closest clusters to probe (`clusters_to_probe`); - * 2. transformed input queries into the rotated space (rot_dim); - * 3. split the query batch into smaller chunks, so that the device workspace - * is guaranteed to fit into GPU memory. - */ -template -void ivfpq_search_worker(raft::resources const& handle, - const index& index, - uint32_t max_samples, - uint32_t n_probes, - uint32_t topK, - uint32_t n_queries, - uint32_t queries_offset, // needed for filtering - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const float* query, // [n_queries, rot_dim] - IdxT* neighbors, // [n_queries, topK] - float* distances, // [n_queries, topK] - float scaling_factor, - double preferred_shmem_carveout, - IvfSampleFilterT sample_filter) -{ - auto stream = resource::get_cuda_stream(handle); - auto mr = resource::get_workspace_resource(handle); - - bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries); - auto topk_len = manage_local_topk ? 
n_probes * topK : max_samples; - std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes); - std::size_t n_queries_topk_len = std::size_t(n_queries) * std::size_t(topk_len); - if (manage_local_topk) { - RAFT_LOG_DEBUG("Fused version of the search kernel is selected (manage_local_topk == true)"); - } else { - RAFT_LOG_DEBUG( - "Non-fused version of the search kernel is selected (manage_local_topk == false)"); - } - - rmm::device_uvector index_list_sorted_buf(0, stream, mr); - uint32_t* index_list_sorted = nullptr; - rmm::device_uvector num_samples(n_queries, stream, mr); - rmm::device_uvector chunk_index(n_queries_probes, stream, mr); - // [maxBatchSize, max_samples] or [maxBatchSize, n_probes, topk] - rmm::device_uvector distances_buf(n_queries_topk_len, stream, mr); - rmm::device_uvector neighbors_buf(0, stream, mr); - uint32_t* neighbors_ptr = nullptr; - if (manage_local_topk) { - neighbors_buf.resize(n_queries_topk_len, stream); - neighbors_ptr = neighbors_buf.data(); - } - rmm::device_uvector neighbors_uint32_buf(0, stream, mr); - uint32_t* neighbors_uint32 = nullptr; - if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { - neighbors_uint32 = reinterpret_cast(neighbors); - } else { - neighbors_uint32_buf.resize(n_queries * topK, stream); - neighbors_uint32 = neighbors_uint32_buf.data(); - } - - calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), - clusters_to_probe, - chunk_index.data(), - num_samples.data(), - stream); - - auto coresidency = expected_probe_coresidency(index.n_lists(), n_probes, n_queries); - - if (coresidency > 1) { - // Sorting index by cluster number (label). - // The goal is to incrase the L2 cache hit rate to read the vectors - // of a cluster by processing the cluster at the same time as much as - // possible. 
- index_list_sorted_buf.resize(n_queries_probes, stream); - auto index_list_buf = - raft::make_device_mdarray(handle, mr, make_extents(n_queries_probes)); - rmm::device_uvector cluster_labels_out(n_queries_probes, stream, mr); - auto index_list = index_list_buf.data_handle(); - index_list_sorted = index_list_sorted_buf.data(); - - linalg::map_offset(handle, index_list_buf.view(), identity_op{}); - - int begin_bit = 0; - int end_bit = sizeof(uint32_t) * 8; - size_t cub_workspace_size = 0; - cub::DeviceRadixSort::SortPairs(nullptr, - cub_workspace_size, - clusters_to_probe, - cluster_labels_out.data(), - index_list, - index_list_sorted, - n_queries_probes, - begin_bit, - end_bit, - stream); - rmm::device_buffer cub_workspace(cub_workspace_size, stream, mr); - cub::DeviceRadixSort::SortPairs(cub_workspace.data(), - cub_workspace_size, - clusters_to_probe, - cluster_labels_out.data(), - index_list, - index_list_sorted, - n_queries_probes, - begin_bit, - end_bit, - stream); - } - - // select and run the main search kernel - uint32_t precomp_data_count = 0; - switch (index.metric()) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2SqrtUnexpanded: - case distance::DistanceType::L2Unexpanded: - case distance::DistanceType::L2Expanded: { - // stores basediff (query[i] - center[i]) - precomp_data_count = index.rot_dim(); - } break; - case distance::DistanceType::InnerProduct: { - // stores two components (query[i] * center[i], query[i] * center[i]) - precomp_data_count = index.rot_dim() * 2; - } break; - default: { - RAFT_FAIL("Unsupported metric"); - } break; - } - - auto search_instance = compute_similarity_select( - resource::get_device_properties(handle), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); - - rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); - std::optional> query_kths_buf{std::nullopt}; - 
float* query_kths = nullptr; - if (manage_local_topk) { - query_kths_buf.emplace( - raft::make_device_mdarray(handle, mr, make_extents(n_queries))); - linalg::map(handle, - query_kths_buf->view(), - raft::const_op{dummy_block_sort_t::queue_t::kDummy}); - query_kths = query_kths_buf->data_handle(); - } - compute_similarity_run(search_instance, - stream, - index.rot_dim(), - n_probes, - index.pq_dim(), - n_queries, - queries_offset, - index.metric(), - index.codebook_kind(), - topK, - max_samples, - index.centers_rot().data_handle(), - index.pq_centers().data_handle(), - index.data_ptrs().data_handle(), - clusters_to_probe, - chunk_index.data(), - query, - index_list_sorted, - query_kths, - sample_filter, - device_lut.data(), - distances_buf.data(), - neighbors_ptr); - - // Select topk vectors for each query - rmm::device_uvector topk_dists(n_queries * topK, stream, mr); - raft::matrix::detail::select_k(handle, - distances_buf.data(), - neighbors_ptr, - n_queries, - topk_len, - topK, - topk_dists.data(), - neighbors_uint32, - true, - mr); - - // Postprocessing - postprocess_distances( - distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, stream); - postprocess_neighbors(neighbors, - neighbors_uint32, - index.inds_ptrs().data_handle(), - clusters_to_probe, - chunk_index.data(), - n_queries, - n_probes, - topK, - stream); -} - -/** - * This structure helps selecting a proper instance of the worker search function, - * which contains a few template parameters. - */ -template -struct ivfpq_search { - public: - using fun_t = decltype(&ivfpq_search_worker); - - /** - * Select an instance of the ivf-pq search function based on search tuning parameters, - * such as the look-up data type or the internal score type. 
- */ - static auto fun(const search_params& params, distance::DistanceType metric) -> fun_t - { - return fun_try_score_t(params, metric); - } - - private: - template - static auto filter_reasonable_instances(const search_params& params) -> fun_t - { - if constexpr (sizeof(ScoreT) >= sizeof(LutT)) { - return ivfpq_search_worker; - } else { - RAFT_FAIL( - "Unexpected lut_dtype / internal_distance_dtype combination (%d, %d). " - "Size of the internal_distance_dtype should be not smaller than the size of the lut_dtype.", - int(params.lut_dtype), - int(params.internal_distance_dtype)); - } - } - - template - static auto fun_try_lut_t(const search_params& params, distance::DistanceType metric) -> fun_t - { - bool signed_metric = false; - switch (metric) { - case cuvs::distance::DistanceType::InnerProduct: signed_metric = true; break; - default: break; - } - - switch (params.lut_dtype) { - case CUDA_R_32F: return filter_reasonable_instances(params); - case CUDA_R_16F: return filter_reasonable_instances(params); - case CUDA_R_8U: - case CUDA_R_8I: - if (signed_metric) { - return filter_reasonable_instances>(params); - } else { - return filter_reasonable_instances>(params); - } - default: RAFT_FAIL("Unexpected lut_dtype (%d)", int(params.lut_dtype)); - } - } - - static auto fun_try_score_t(const search_params& params, distance::DistanceType metric) -> fun_t - { - switch (params.internal_distance_dtype) { - case CUDA_R_32F: return fun_try_lut_t(params, metric); - case CUDA_R_16F: return fun_try_lut_t(params, metric); - default: - RAFT_FAIL("Unexpected internal_distance_dtype (%d)", int(params.internal_distance_dtype)); - } - } -}; - -/** - * A heuristic for bounding the number of queries per batch, to improve GPU utilization. - * (based on the number of SMs and the work size). - * - * @param res is used to query the workspace size - * @param k top-k - * @param n_probes number of selected clusters per query - * @param n_queries number of queries hoped to be processed at once. 
- * (maximum value for the returned batch size) - * @param max_samples maximum possible number of samples to be processed for the given `n_probes` - * - * @return maximum recommended batch size. - */ -inline auto get_max_batch_size(raft::resources const& res, - uint32_t k, - uint32_t n_probes, - uint32_t n_queries, - uint32_t max_samples) -> uint32_t -{ - uint32_t max_batch_size = n_queries; - uint32_t n_ctas_total = getMultiProcessorCount() * 2; - uint32_t n_ctas_total_per_batch = n_ctas_total / max_batch_size; - float utilization = float(n_ctas_total_per_batch * max_batch_size) / n_ctas_total; - if (n_ctas_total_per_batch > 1 || (n_ctas_total_per_batch == 1 && utilization < 0.6)) { - uint32_t n_ctas_total_per_batch_1 = n_ctas_total_per_batch + 1; - uint32_t max_batch_size_1 = n_ctas_total / n_ctas_total_per_batch_1; - float utilization_1 = float(n_ctas_total_per_batch_1 * max_batch_size_1) / n_ctas_total; - if (utilization < utilization_1) { max_batch_size = max_batch_size_1; } - } - // Check in the tmp distance buffer is not too big - auto ws_size = [k, n_probes, max_samples](uint32_t bs) -> uint64_t { - const uint64_t buffers_fused = 12ull * k * n_probes; - const uint64_t buffers_non_fused = 4ull * max_samples; - const uint64_t other = 32ull * n_probes; - return static_cast(bs) * - (other + (is_local_topk_feasible(k, n_probes, bs) ? buffers_fused : buffers_non_fused)); - }; - auto max_ws_size = resource::get_workspace_free_bytes(res); - if (ws_size(max_batch_size) > max_ws_size) { - uint32_t smaller_batch_size = bound_by_power_of_two(max_batch_size); - // gradually reduce the batch size until we fit into the max size limit. 
- while (smaller_batch_size > 1 && ws_size(smaller_batch_size) > max_ws_size) { - smaller_batch_size >>= 1; - } - return smaller_batch_size; - } - return max_batch_size; -} - -/** See cuvs::spatial::knn::ivf_pq::search docs */ -template -inline void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) -{ - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "Unsupported element type."); - raft::common::nvtx::range fun_scope( - "ivf_pq::search(n_queries = %u, n_probes = %u, k = %u, dim = %zu)", - n_queries, - params.n_probes, - k, - index.dim()); - resource::detail::warn_non_pool_workspace(handle, "raft::ivf_pq::search"); - - RAFT_EXPECTS( - params.internal_distance_dtype == CUDA_R_16F || params.internal_distance_dtype == CUDA_R_32F, - "internal_distance_dtype must be either CUDA_R_16F or CUDA_R_32F"); - RAFT_EXPECTS(params.lut_dtype == CUDA_R_16F || params.lut_dtype == CUDA_R_32F || - params.lut_dtype == CUDA_R_8U, - "lut_dtype must be CUDA_R_16F, CUDA_R_32F or CUDA_R_8U"); - RAFT_EXPECTS(k > 0, "parameter `k` in top-k must be positive."); - RAFT_EXPECTS( - k <= index.size(), - "parameter `k` (%u) in top-k must not be larger that the total size of the index (%zu)", - k, - static_cast(index.size())); - RAFT_EXPECTS(params.n_probes > 0, - "n_probes (number of clusters to probe in the search) must be positive."); - - switch (utils::check_pointer_residency(queries, neighbors, distances)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: RAFT_FAIL("all pointers must be accessible from the device."); - } - - auto stream = resource::get_cuda_stream(handle); - - auto dim = index.dim(); - auto dim_ext = index.dim_ext(); - auto n_probes = std::min(params.n_probes, index.n_lists()); - - uint32_t max_samples = 
0; - { - IdxT ms = raft::Pow2<128>::roundUp(index.accum_sorted_sizes()(n_probes)); - RAFT_EXPECTS(ms <= IdxT(std::numeric_limits::max()), - "The maximum sample size is too big."); - max_samples = ms; - } - - auto mr = resource::get_workspace_resource(handle); - - // Maximum number of query vectors to search at the same time. - const auto max_queries = std::min(std::max(n_queries, 1), 4096); - auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples); - - rmm::device_uvector float_queries(max_queries * dim_ext, stream, mr); - rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); - rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - - auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter( - index.inds_ptrs().data_handle(), sample_filter); - auto search_instance = ivfpq_search::fun(params, index.metric()); - - for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { - uint32_t queries_batch = min(max_queries, n_queries - offset_q); - - select_clusters(handle, - clusters_to_probe.data(), - float_queries.data(), - queries_batch, - n_probes, - index.n_lists(), - dim, - dim_ext, - index.metric(), - queries + static_cast(dim) * offset_q, - index.centers().data_handle(), - mr); - - // Rotate queries - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, - index.rot_dim(), - queries_batch, - dim, - &alpha, - index.rotation_matrix().data_handle(), - dim, - float_queries.data(), - dim_ext, - &beta, - rot_queries.data(), - index.rot_dim(), - stream); - - for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) { - uint32_t batch_size = min(max_batch_size, queries_batch - offset_b); - /* The distance calculation is done in the rotated/transformed space; - as long as `index.rotation_matrix()` is orthogonal, the distances and thus results are - preserved. 
- */ - search_instance(handle, - index, - max_samples, - n_probes, - k, - batch_size, - offset_q + offset_b, - clusters_to_probe.data() + uint64_t(n_probes) * offset_b, - rot_queries.data() + uint64_t(index.rot_dim()) * offset_b, - neighbors + uint64_t(k) * (offset_q + offset_b), - distances + uint64_t(k) * (offset_q + offset_b), - utils::config::kDivisor / utils::config::kDivisor, - params.preferred_shmem_carveout, - filter_adapter); - } - } -} - -} // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/include/cuvs/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/cuvs/neighbors/detail/ivf_pq_serialize.cuh deleted file mode 100644 index 79d059c46..000000000 --- a/cpp/include/cuvs/neighbors/detail/ivf_pq_serialize.cuh +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include - -namespace cuvs::neighbors::ivf_pq::detail { - -// Serialization version -// No backward compatibility yet; that is, can't add additional fields without breaking -// backward compatibility. -// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward -// compatible fashion. 
-constexpr int kSerializationVersion = 3; - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index IVF-PQ index - * - */ -template -void serialize(raft::resources const& handle_, std::ostream& os, const index& index) -{ - RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d", - static_cast(index.size()), - static_cast(index.dim()), - static_cast(index.pq_dim()), - static_cast(index.pq_bits())); - - serialize_scalar(handle_, os, kSerializationVersion); - serialize_scalar(handle_, os, index.size()); - serialize_scalar(handle_, os, index.dim()); - serialize_scalar(handle_, os, index.pq_bits()); - serialize_scalar(handle_, os, index.pq_dim()); - serialize_scalar(handle_, os, index.conservative_memory_allocation()); - - serialize_scalar(handle_, os, index.metric()); - serialize_scalar(handle_, os, index.codebook_kind()); - serialize_scalar(handle_, os, index.n_lists()); - - serialize_mdspan(handle_, os, index.pq_centers()); - serialize_mdspan(handle_, os, index.centers()); - serialize_mdspan(handle_, os, index.centers_rot()); - serialize_mdspan(handle_, os, index.rotation_matrix()); - - auto sizes_host = - raft::make_host_mdarray(index.list_sizes().extents()); - copy(sizes_host.data_handle(), - index.list_sizes().data_handle(), - sizes_host.size(), - resource::get_cuda_stream(handle_)); - resource::sync_stream(handle_); - serialize_mdspan(handle_, os, sizes_host.view()); - auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; - for (uint32_t label = 0; label < index.n_lists(); label++) { - ivf::serialize_list(handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); - } -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index IVF-PQ index - * - */ -template -void serialize(raft::resources const& handle_, - const std::string& filename, - const index& index) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize(handle_, of, index); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } - return; -} - -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - */ -template -auto deserialize(raft::resources const& handle_, std::istream& is) -> index -{ - auto ver = deserialize_scalar(handle_, is); - if (ver != kSerializationVersion) { - RAFT_FAIL("serialization version mismatch %d vs. %d", ver, kSerializationVersion); - } - auto n_rows = deserialize_scalar(handle_, is); - auto dim = deserialize_scalar(handle_, is); - auto pq_bits = deserialize_scalar(handle_, is); - auto pq_dim = deserialize_scalar(handle_, is); - auto cma = deserialize_scalar(handle_, is); - - auto metric = deserialize_scalar(handle_, is); - auto codebook_kind = deserialize_scalar(handle_, is); - auto n_lists = deserialize_scalar(handle_, is); - - RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d", - static_cast(n_rows), - static_cast(dim), - static_cast(pq_dim), - static_cast(pq_bits), - static_cast(n_lists)); - - auto index = cuvs::neighbors::ivf_pq::index( - handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma); - - deserialize_mdspan(handle_, is, index.pq_centers()); - deserialize_mdspan(handle_, is, index.centers()); - deserialize_mdspan(handle_, is, index.centers_rot()); - deserialize_mdspan(handle_, is, index.rotation_matrix()); - deserialize_mdspan(handle_, is, index.list_sizes()); - auto 
list_device_spec = list_spec{pq_bits, pq_dim, cma}; - auto list_store_spec = list_spec{pq_bits, pq_dim, true}; - for (auto& list : index.lists()) { - ivf::deserialize_list(handle_, is, list, list_store_spec, list_device_spec); - } - - resource::sync_stream(handle_); - - recompute_internal_state(handle_, index); - - return index; -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - */ -template -auto deserialize(raft::resources const& handle_, const std::string& filename) -> index -{ - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto index = detail::deserialize(handle_, infile); - - infile.close(); - - return index; -} - -} // namespace cuvs::neighbors::ivf_pq::detail diff --git a/cpp/include/cuvs/neighbors/detail/knn_brute_force.cuh b/cpp/include/cuvs/neighbors/detail/knn_brute_force.cuh deleted file mode 100644 index 6914ea030..000000000 --- a/cpp/include/cuvs/neighbors/detail/knn_brute_force.cuh +++ /dev/null @@ -1,550 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::detail { -using namespace cuvs::spatial::knn::detail; -using namespace cuvs::spatial::knn; - -/** - * Calculates brute force knn, using a fixed memory budget - * by tiling over both the rows and columns of pairwise_distances - */ -template -void tiled_brute_force_knn(const raft::resources& handle, - const ElementType* search, // size (m ,d) - const ElementType* index, // size (n ,d) - size_t m, - size_t n, - size_t d, - size_t k, - ElementType* distances, // size (m, k) - IndexType* indices, // size (m, k) - cuvs::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - DistanceEpilogue distance_epilogue = raft::identity_op(), - const ElementType* precomputed_index_norms = nullptr, - const ElementType* precomputed_search_norms = nullptr) -{ - // Figure out the number of rows/cols to tile for - size_t tile_rows = 0; - size_t tile_cols = 0; - auto stream = raft::resource::get_cuda_stream(handle); - auto device_memory = raft::resource::get_workspace_resource(handle); - auto total_mem = device_memory->get_mem_info(stream).second; - faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); - - // for unittesting, its convenient to be able to put a max size on the tiles - // so we can test the tiling logic without having to use huge inputs. 
- if (max_row_tile_size && (tile_rows > max_row_tile_size)) { tile_rows = max_row_tile_size; } - if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } - - // tile_cols must be at least k items - tile_cols = std::max(tile_cols, k); - - // stores pairwise distances for the current tile - rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); - - // calculate norms for L2 expanded distances - this lets us avoid calculating - // norms repeatedly per-tile, and just do once for the entire input - auto pairwise_metric = metric; - rmm::device_uvector search_norms(0, stream); - rmm::device_uvector index_norms(0, stream); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded) { - if (!precomputed_search_norms) { search_norms.resize(m, stream); } - if (!precomputed_index_norms) { index_norms.resize(n, stream); } - // cosine needs the l2norm, where as l2 distances needs the squared norm - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - if (!precomputed_search_norms) { - raft::linalg::rowNorm(search_norms.data(), - search, - d, - m, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); - } - if (!precomputed_index_norms) { - raft::linalg::rowNorm(index_norms.data(), - index, - d, - n, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); - } - } else { - if (!precomputed_search_norms) { - raft::linalg::rowNorm( - search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - } - if (!precomputed_index_norms) { - raft::linalg::rowNorm( - index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); - } - } - pairwise_metric = cuvs::distance::DistanceType::InnerProduct; - } - - // if we're tiling over columns, we need additional buffers for temporary output - // distances/indices - size_t num_col_tiles = 
raft::ceildiv(n, tile_cols); - size_t temp_out_cols = k * num_col_tiles; - - // the final column tile could have less than 'k' items in it - // in which case the number of columns here is too high in the temp output. - // adjust if necessary - auto last_col_tile_size = n % tile_cols; - if (last_col_tile_size && (last_col_tile_size < k)) { temp_out_cols -= k - last_col_tile_size; } - - // if we have less than k items in the index, we should fill out the result - // to indicate that we are missing items (and match behaviour in faiss) - if (n < k) { - raft::matrix::fill(handle, - raft::make_device_matrix_view(distances, m, k), - std::numeric_limits::lowest()); - - if constexpr (std::is_signed_v) { - raft::matrix::fill(handle, raft::make_device_matrix_view(indices, m, k), IndexType{-1}); - } - } - - rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); - rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); - - bool select_min = cuvs::distance::is_min_close(metric); - - for (size_t i = 0; i < m; i += tile_rows) { - size_t current_query_size = std::min(tile_rows, m - i); - - for (size_t j = 0; j < n; j += tile_cols) { - size_t current_centroid_size = std::min(tile_cols, n - j); - size_t current_k = std::min(current_centroid_size, k); - - // calculate the top-k elements for the current tile, by calculating the - // full pairwise distance for the tile - and then selecting the top-k from that - // note: we're using a int32 IndexType here on purpose in order to - // use the pairwise_distance instantiations. 
Since the tile size will ensure - // that the total memory is < 1GB per tile, this will not cause any issues - distance::pairwise_distance(handle, - search + i * d, - index + j * d, - temp_distances.data(), - current_query_size, - current_centroid_size, - d, - pairwise_metric, - true, - metric_arg); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); - auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); - auto dist = temp_distances.data(); - bool sqrt = metric == cuvs::distance::DistanceType::L2SqrtExpanded; - - raft::linalg::map_offset( - handle, - raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType idx) { - IndexType row = i + (idx / current_centroid_size); - IndexType col = j + (idx % current_centroid_size); - - cuvs::distance::detail::ops::l2_exp_cutlass_op l2_op(sqrt); - auto val = l2_op(row_norms[row], col_norms[col], dist[idx]); - return distance_epilogue(val, row, col); - }); - } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { - auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); - auto col_norms = precomputed_index_norms ? 
precomputed_index_norms : index_norms.data(); - auto dist = temp_distances.data(); - - raft::linalg::map_offset( - handle, - raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType idx) { - IndexType row = i + (idx / current_centroid_size); - IndexType col = j + (idx % current_centroid_size); - auto val = 1.0 - dist[idx] / (row_norms[row] * col_norms[col]); - val = distance_epilogue(val, row, col); - return val; - }); - } else { - // if we're not l2 distance, and we have a distance epilogue - run it now - if constexpr (!std::is_same_v) { - auto distances_ptr = temp_distances.data(); - raft::linalg::map_offset( - handle, - raft::make_device_vector_view(temp_distances.data(), - current_query_size * current_centroid_size), - [=] __device__(size_t idx) { - IndexType row = i + (idx / current_centroid_size); - IndexType col = j + (idx % current_centroid_size); - return distance_epilogue(distances_ptr[idx], row, col); - }); - } - } - - raft::matrix::select_k( - handle, - raft::make_device_matrix_view( - temp_distances.data(), current_query_size, current_centroid_size), - std::nullopt, - raft::make_device_matrix_view( - distances + i * k, current_query_size, current_k), - raft::make_device_matrix_view( - indices + i * k, current_query_size, current_k), - select_min, - true); - - // if we're tiling over columns, we need to do a couple things to fix up - // the output of select_k - // 1. The column id's in the output are relative to the tile, so we need - // to adjust the column ids by adding the column the tile starts at (j) - // 2. select_k writes out output in a row-major format, which means we - // can't just concat the output of all the tiles and do a select_k on the - // concatenation. 
- // Fix both of these problems in a single pass here - if (tile_cols != n) { - const ElementType* in_distances = distances + i * k; - const IndexType* in_indices = indices + i * k; - ElementType* out_distances = temp_out_distances.data(); - IndexType* out_indices = temp_out_indices.data(); - - auto count = thrust::make_counting_iterator(0); - thrust::for_each(raft::resource::get_thrust_policy(handle), - count, - count + current_query_size * current_k, - [=] __device__(IndexType i) { - IndexType row = i / current_k, col = i % current_k; - IndexType out_index = row * temp_out_cols + j * k / tile_cols + col; - - out_distances[out_index] = in_distances[i]; - out_indices[out_index] = in_indices[i] + j; - }); - } - } - - if (tile_cols != n) { - // select the actual top-k items here from the temporary output - raft::matrix::select_k( - handle, - raft::make_device_matrix_view( - temp_out_distances.data(), current_query_size, temp_out_cols), - raft::make_device_matrix_view( - temp_out_indices.data(), current_query_size, temp_out_cols), - raft::make_device_matrix_view( - distances + i * k, current_query_size, k), - raft::make_device_matrix_view( - indices + i * k, current_query_size, k), - select_min, - true); - } - } -} - -/** - * Search the kNN for the k-nearest neighbors of a set of query vectors - * @param[in] input vector of device device memory array pointers to search - * @param[in] sizes vector of memory sizes for each device array pointer in input - * @param[in] D number of cols in input and search_items - * @param[in] search_items set of vectors to query for neighbors - * @param[in] n number of items in search_items - * @param[out] res_I pointer to device memory for returning k nearest indices - * @param[out] res_D pointer to device memory for returning k nearest distances - * @param[in] k number of neighbors to query - * @param[in] userStream the main cuda stream to use - * @param[in] internalStreams optional when n_params > 0, the index partitions can be - * 
queried in parallel using these streams. Note that n_int_streams also - * has to be > 0 for these to be used and their cardinality does not need - * to correspond to n_parts. - * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the - * user stream will be used. - * @param[in] rowMajorIndex are the index arrays in row-major layout? - * @param[in] rowMajorQuery are the query array in row-major layout? - * @param[in] translations translation ids for indices when index rows represent - * non-contiguous partitions - * @param[in] metric corresponds to the cuvs::distance::DistanceType enum (default is L2Expanded) - * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm - */ -template -void brute_force_knn_impl( - raft::resources const& handle, - std::vector& input, - std::vector& sizes, - IntType D, - value_t* search_items, - IntType n, - IdxType* res_I, - value_t* res_D, - IntType k, - bool rowMajorIndex = true, - bool rowMajorQuery = true, - std::vector* translations = nullptr, - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded, - float metricArg = 0, - DistanceEpilogue distance_epilogue = raft::identity_op(), - std::vector* input_norms = nullptr, - const value_t* search_norms = nullptr) -{ - auto userStream = resource::get_cuda_stream(handle); - - ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); - - std::vector* id_ranges; - if (translations == nullptr) { - // If we don't have explicit translations - // for offsets of the indices, build them - // from the local partitions - id_ranges = new std::vector(); - IdxType total_n = 0; - for (size_t i = 0; i < input.size(); i++) { - id_ranges->push_back(total_n); - total_n += sizes[i]; - } - } else { - // otherwise, use the given translations - id_ranges = translations; - } - - int device; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - - rmm::device_uvector trans(id_ranges->size(), userStream); - 
raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); - - rmm::device_uvector all_D(0, userStream); - rmm::device_uvector all_I(0, userStream); - - value_t* out_D = res_D; - IdxType* out_I = res_I; - - if (input.size() > 1) { - all_D.resize(input.size() * k * n, userStream); - all_I.resize(input.size() * k * n, userStream); - - out_D = all_D.data(); - out_I = all_I.data(); - } - - // currently we don't support col_major inside tiled_brute_force_knn, because - // of limitations of the pairwise_distance API: - // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have - // multiple options here (like rowMajorQuery/rowMajorIndex) - // 2) because of tiling, we need to be able to set a custom stride in the PW - // api, which isn't supported - // Instead, transpose the input matrices if they are passed as col-major. - auto search = search_items; - rmm::device_uvector search_row_major(0, userStream); - if (!rowMajorQuery) { - search_row_major.resize(n * D, userStream); - raft::linalg::transpose(handle, search, search_row_major.data(), n, D, userStream); - search = search_row_major.data(); - } - - // transpose into a temporary buffer if necessary - rmm::device_uvector index_row_major(0, userStream); - if (!rowMajorIndex) { - size_t total_size = 0; - for (auto size : sizes) { - total_size += size; - } - index_row_major.resize(total_size * D, userStream); - } - - // Make other streams from pool wait on main stream - resource::wait_stream_pool_on_stream(handle); - - size_t total_rows_processed = 0; - for (size_t i = 0; i < input.size(); i++) { - value_t* out_d_ptr = out_D + (i * k * n); - IdxType* out_i_ptr = out_I + (i * k * n); - - auto stream = resource::get_next_usable_stream(handle, i); - - if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && - std::is_same_v && - (metric == cuvs::distance::DistanceType::L2Unexpanded || - metric == cuvs::distance::DistanceType::L2SqrtUnexpanded || - metric == 
cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { - fusedL2Knn(D, - out_i_ptr, - out_d_ptr, - input[i], - search_items, - sizes[i], - n, - k, - rowMajorIndex, - rowMajorQuery, - stream, - metric, - input_norms ? (*input_norms)[i] : nullptr, - search_norms); - - // Perform necessary post-processing - if (metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::L2SqrtUnexpanded || - metric == cuvs::distance::DistanceType::LpUnexpanded) { - float p = 0.5; // standard l2 - if (metric == cuvs::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; - raft::linalg::unaryOp( - res_D, - res_D, - n * k, - [p] __device__(float input) { return powf(fabsf(input), p); }, - stream); - } - } else { - switch (metric) { - case cuvs::distance::DistanceType::Haversine: - ASSERT(D == 2, - "Haversine distance requires 2 dimensions " - "(latitude / longitude)."); - - haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); - break; - default: - // Create a new handle with the current stream from the stream pool - raft::resources stream_pool_handle(handle); - raft::resource::set_cuda_stream(stream_pool_handle, stream); - - auto index = input[i]; - if (!rowMajorIndex) { - index = index_row_major.data() + total_rows_processed * D; - total_rows_processed += sizes[i]; - raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream); - } - - tiled_brute_force_knn(stream_pool_handle, - search, - index, - n, - sizes[i], - D, - k, - out_d_ptr, - out_i_ptr, - metric, - metricArg, - 0, - 0, - distance_epilogue, - input_norms ? (*input_norms)[i] : nullptr, - search_norms); - break; - } - } - - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - // Sync internal streams if used. We don't need to - // sync the user stream because we'll already have - // fully serial execution. 
- resource::sync_stream_pool(handle); - - if (input.size() > 1 || translations != nullptr) { - // This is necessary for proper index translations. If there are - // no translations or partitions to combine, it can be skipped. - knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); - } - - if (translations == nullptr) delete id_ranges; -}; - -template -void brute_force_search( - raft::resources const& res, - const cuvs::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - std::optional> query_norms = std::nullopt) -{ - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), - "Number of columns in queries must match brute force index"); - - auto k = neighbors.extent(1); - auto d = idx.dataset().extent(1); - - std::vector dataset = {const_cast(idx.dataset().data_handle())}; - std::vector sizes = {idx.dataset().extent(0)}; - std::vector norms; - if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } - - brute_force_knn_impl(res, - dataset, - sizes, - d, - const_cast(queries.data_handle()), - queries.extent(0), - neighbors.data_handle(), - distances.data_handle(), - k, - true, - true, - nullptr, - idx.metric(), - idx.metric_arg(), - raft::identity_op(), - norms.size() ? &norms : nullptr, - query_norms ? query_norms->data_handle() : nullptr); -} -} // namespace cuvs::neighbors::detail diff --git a/cpp/include/cuvs/neighbors/detail/knn_brute_force_batch_k_query.cuh b/cpp/include/cuvs/neighbors/detail/knn_brute_force_batch_k_query.cuh deleted file mode 100644 index 8d6dce407..000000000 --- a/cpp/include/cuvs/neighbors/detail/knn_brute_force_batch_k_query.cuh +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -namespace cuvs::neighbors::brute_force::detail { -template -class gpu_batch_k_query : public batch_k_query { - public: - gpu_batch_k_query(const raft::resources& res, - const cuvs::neighbors::brute_force::index& index, - raft::device_matrix_view query, - int64_t batch_size) - : batch_k_query(res, index.size(), query.extent(0), batch_size), - index(index), - query(query) - { - auto metric = index.metric(); - - // precompute query norms, and re-use across batches - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded) { - query_norms = raft::make_device_vector(res, query.extent(0)); - - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm(res, - query, - query_norms->view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS, - raft::sqrt_op{}); - } else { - raft::linalg::norm(res, - query, - query_norms->view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - } - } - } - - protected: - void load_batch(int64_t offset, int64_t next_batch_size, batch* output) const override - { - if (offset >= index.size()) { return; } - - // we're aiming to load multiple batches here - since we don't know the max iteration - // grow the size we're loading exponentially - 
int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), this->index_size); - output->resize(this->res, this->query_size, batch_size); - - std::optional> query_norms_view; - if (query_norms) { query_norms_view = query_norms->view(); } - - cuvs::neighbors::detail::brute_force_search( - this->res, index, query, output->indices(), output->distances(), query_norms_view); - }; - - void slice_batch(const batch& input, - int64_t offset, - int64_t batch_size, - batch* output) const override - { - auto num_queries = input.indices().extent(0); - batch_size = std::min(batch_size, index.size() - offset); - - output->resize(this->res, num_queries, batch_size); - - if (!num_queries || !batch_size) { return; } - - raft::matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; - raft::matrix::slice(this->res, input.indices(), output->indices(), coords); - raft::matrix::slice(this->res, input.distances(), output->distances(), coords); - } - - const cuvs::neighbors::brute_force::index& index; - raft::device_matrix_view query; - std::optional> query_norms; -}; -} // namespace cuvs::neighbors::brute_force::detail diff --git a/cpp/include/cuvs/neighbors/detail/knn_merge_parts.cuh b/cpp/include/cuvs/neighbors/detail/knn_merge_parts.cuh deleted file mode 100644 index 00610c45e..000000000 --- a/cpp/include/cuvs/neighbors/detail/knn_merge_parts.cuh +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include - -namespace cuvs::neighbors::detail { - -template -RAFT_KERNEL knn_merge_parts_kernel(const value_t* inK, - const value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - value_t initK, - value_idx initV, - int k, - value_idx* translations) -{ - constexpr int kNumWarps = tpb / raft::WarpSize; - - __shared__ value_t smemK[kNumWarps * warp_q]; - __shared__ value_idx smemV[kNumWarps * warp_q]; - - /** - * Uses shared memory - */ - faiss_select:: - BlockSelect, warp_q, thread_q, tpb> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - int row = blockIdx.x; - int total_k = k * n_parts; - - int i = threadIdx.x; - - // Get starting pointers for cols in current thread - int part = i / k; - size_t row_idx = (row * k) + (part * n_samples * k); - - int col = i % k; - - const value_t* inKStart = inK + (row_idx + col); - const value_idx* inVStart = inV + (row_idx + col); - - int limit = raft::Pow2::roundDown(total_k); - value_idx translation = 0; - - for (; i < limit; i += tpb) { - translation = translations[part]; - heap.add(*inKStart, (*inVStart) + translation); - - part = (i + tpb) / k; - row_idx = (row * k) + (part * n_samples * k); - - col = (i + tpb) % k; - - inKStart = inK + (row_idx + col); - inVStart = inV + (row_idx + col); - } - - // Handle last remainder fraction of a warp of elements - if (i < total_k) { - translation = translations[part]; - heap.addThreadQ(*inKStart, (*inVStart) + translation); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - outK[row * k + i] = smemK[i]; - outV[row * k + i] = smemV[i]; - } -} - -template -inline void knn_merge_parts_impl(const value_t* inK, - const value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t 
stream, - value_idx* translations) -{ - auto grid = dim3(n_samples); - - constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; - auto block = dim3(n_threads); - - auto kInit = std::numeric_limits::max(); - auto vInit = -1; - knn_merge_parts_kernel - <<>>( - inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * @brief Merge knn distances and index matrix, which have been partitioned - * by row, into a single matrix with only the k-nearest neighbors. - * - * @param inK partitioned knn distance matrix - * @param inV partitioned knn index matrix - * @param outK merged knn distance matrix - * @param outV merged knn index matrix - * @param n_samples number of samples per partition - * @param n_parts number of partitions - * @param k number of neighbors per partition (also number of merged neighbors) - * @param stream CUDA stream to use - * @param translations mapping of index offsets for each partition - */ -template -inline void knn_merge_parts(const value_t* inK, - const value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t stream, - value_idx* translations) -{ - if (k == 1) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 32) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 64) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 128) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 256) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 512) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 1024) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, 
translations); -} -} // namespace cuvs::neighbors::detail diff --git a/cpp/include/cuvs/neighbors/detail/nn_descent.cuh b/cpp/include/cuvs/neighbors/detail/nn_descent.cuh deleted file mode 100644 index cd2208bfa..000000000 --- a/cpp/include/cuvs/neighbors/detail/nn_descent.cuh +++ /dev/null @@ -1,1456 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include -#include - -#include -#include - -#include -#include -#include -#include -#include - -#include "../nn_descent_types.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include // raft::util::arch::SM_* -#include -#include -#include -#include - -namespace cuvs::neighbors::experimental::nn_descent::detail { - -using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; -template -using pinned_memory_allocator = thrust::mr::stateless_resource_allocator; - -using DistData_t = float; -constexpr int DEGREE_ON_DEVICE{32}; -constexpr int SEGMENT_SIZE{32}; -constexpr int counter_interval{100}; -template -struct InternalID_t; - -// InternalID_t uses 1 bit for marking (new or old). 
-template <> -class InternalID_t { - private: - using Index_t = int; - Index_t id_{std::numeric_limits::max()}; - - public: - __host__ __device__ bool is_new() const { return id_ >= 0; } - __host__ __device__ Index_t& id_with_flag() { return id_; } - __host__ __device__ Index_t id() const - { - if (is_new()) return id_; - return -id_ - 1; - } - __host__ __device__ void mark_old() - { - if (id_ >= 0) id_ = -id_ - 1; - } - __host__ __device__ bool operator==(const InternalID_t& other) const - { - return id() == other.id(); - } -}; - -template -struct ResultItem; - -template <> -class ResultItem { - private: - using Index_t = int; - Index_t id_; - DistData_t dist_; - - public: - __host__ __device__ ResultItem() - : id_(std::numeric_limits::max()), dist_(std::numeric_limits::max()){}; - __host__ __device__ ResultItem(const Index_t id_with_flag, const DistData_t dist) - : id_(id_with_flag), dist_(dist){}; - __host__ __device__ bool is_new() const { return id_ >= 0; } - __host__ __device__ Index_t& id_with_flag() { return id_; } - __host__ __device__ Index_t id() const - { - if (is_new()) return id_; - return -id_ - 1; - } - __host__ __device__ DistData_t& dist() { return dist_; } - - __host__ __device__ void mark_old() - { - if (id_ >= 0) id_ = -id_ - 1; - } - - __host__ __device__ bool operator<(const ResultItem& other) const - { - if (dist_ == other.dist_) return id() < other.id(); - return dist_ < other.dist_; - } - __host__ __device__ bool operator==(const ResultItem& other) const - { - return id() == other.id(); - } - __host__ __device__ bool operator>=(const ResultItem& other) const - { - return !(*this < other); - } - __host__ __device__ bool operator<=(const ResultItem& other) const - { - return (*this == other) || (*this < other); - } - __host__ __device__ bool operator>(const ResultItem& other) const - { - return !(*this <= other); - } - __host__ __device__ bool operator!=(const ResultItem& other) const - { - return !(*this == other); - } -}; - -using align32 
= raft::Pow2<32>; - -template -int get_batch_size(const int it_now, const T nrow, const int batch_size) -{ - int it_total = raft::ceildiv(nrow, batch_size); - return (it_now == it_total - 1) ? nrow - it_now * batch_size : batch_size; -} - -// for avoiding bank conflict -template -constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) -{ - // all "4"s are for alignment - if constexpr (std::is_same::value) { - ndim = raft::ceildiv(ndim, 4) * 4; - return ndim + (ndim % 32 == 0) * 4; - } -} - -template -__device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) -{ - ResultItem y; - y.dist() = __shfl_xor_sync(raft::warp_full_mask(), x.dist(), mask, raft::warp_size()); - y.id_with_flag() = - __shfl_xor_sync(raft::warp_full_mask(), x.id_with_flag(), mask, raft::warp_size()); - return x < y == dir ? y : x; -} - -__device__ __forceinline__ int xor_swap(int x, int mask, int dir) -{ - int y = __shfl_xor_sync(raft::warp_full_mask(), x, mask, raft::warp_size()); - return x < y == dir ? 
y : x; -} - -// TODO: Move to RAFT utils https://github.com/rapidsai/raft/issues/1827 -__device__ __forceinline__ uint bfe(uint lane_id, uint pos) -{ - uint res; - asm("bfe.u32 %0,%1,%2,%3;" : "=r"(res) : "r"(lane_id), "r"(pos), "r"(1)); - return res; -} - -template -__device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) -{ - static_assert(raft::warp_size() == 32); - auto& element = *element_ptr; - element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 2) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x04, bfe(lane_id, 3) ^ bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 3) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 3) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x08, bfe(lane_id, 4) ^ bfe(lane_id, 3)); - element = xor_swap(element, 0x04, bfe(lane_id, 4) ^ bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 4) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 4) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x10, bfe(lane_id, 4)); - element = xor_swap(element, 0x08, bfe(lane_id, 3)); - element = xor_swap(element, 0x04, bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 0)); - return; -} - -struct BuildConfig { - size_t max_dataset_size; - size_t dataset_dim; - size_t node_degree{64}; - size_t internal_node_degree{0}; - // If internal_node_degree == 0, the value of node_degree will be assigned to it - size_t max_iterations{50}; - float termination_threshold{0.0001}; -}; - -template -class BloomFilter { - public: - BloomFilter(size_t nrow, size_t num_sets_per_list, size_t num_hashs) - : nrow_(nrow), - num_sets_per_list_(num_sets_per_list), - num_hashs_(num_hashs), - bitsets_(nrow * num_bits_per_set_ * num_sets_per_list) - { - } - - void 
add(size_t list_id, Index_t key) - { - if (is_cleared) { is_cleared = false; } - uint32_t hash = hash_0(key); - size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + - key % num_sets_per_list_ * num_bits_per_set_; - bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; - for (size_t i = 1; i < num_hashs_; i++) { - hash = hash + hash_1(key); - bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; - } - } - - bool check(size_t list_id, Index_t key) - { - bool is_present = true; - uint32_t hash = hash_0(key); - size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + - key % num_sets_per_list_ * num_bits_per_set_; - is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; - - if (!is_present) return false; - for (size_t i = 1; i < num_hashs_; i++) { - hash = hash + hash_1(key); - is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; - if (!is_present) return false; - } - return true; - } - - void clear() - { - if (is_cleared) return; -#pragma omp parallel for - for (size_t i = 0; i < nrow_ * num_bits_per_set_ * num_sets_per_list_; i++) { - bitsets_[i] = 0; - } - is_cleared = true; - } - - private: - uint32_t hash_0(uint32_t value) - { - value *= 1103515245; - value += 12345; - value ^= value << 13; - value ^= value >> 17; - value ^= value << 5; - return value; - } - - uint32_t hash_1(uint32_t value) - { - value *= 1664525; - value += 1013904223; - value ^= value << 13; - value ^= value >> 17; - value ^= value << 5; - return value; - } - - static constexpr int num_bits_per_set_ = 512; - bool is_cleared{true}; - std::vector bitsets_; - size_t nrow_; - size_t num_sets_per_list_; - size_t num_hashs_; -}; - -template -struct GnndGraph { - static constexpr int segment_size = 32; - InternalID_t* h_graph; - - size_t nrow; - size_t node_degree; - int num_samples; - int num_segments; - - raft::host_matrix h_dists; - - thrust::host_vector> h_graph_new; - thrust::host_vector> h_list_sizes_new; - - 
thrust::host_vector> h_graph_old; - thrust::host_vector> h_list_sizes_old; - BloomFilter bloom_filter; - - GnndGraph(const GnndGraph&) = delete; - GnndGraph& operator=(const GnndGraph&) = delete; - GnndGraph(const size_t nrow, - const size_t node_degree, - const size_t internal_node_degree, - const size_t num_samples); - void init_random_graph(); - // TODO: Create a generic bloom filter utility https://github.com/rapidsai/raft/issues/1827 - // Use Bloom filter to sample "new" neighbors for local joining - void sample_graph_new(InternalID_t* new_neighbors, const size_t width); - void sample_graph(bool sample_new); - void update_graph(const InternalID_t* new_neighbors, - const DistData_t* new_dists, - const size_t width, - std::atomic& update_counter); - void sort_lists(); - void clear(); - ~GnndGraph(); -}; - -template -class GNND { - public: - GNND(raft::resources const& res, const BuildConfig& build_config); - GNND(const GNND&) = delete; - GNND& operator=(const GNND&) = delete; - - void build(Data_t* data, const Index_t nrow, Index_t* output_graph); - ~GNND() = default; - using ID_t = InternalID_t; - - private: - void add_reverse_edges(Index_t* graph_ptr, - Index_t* h_rev_graph_ptr, - Index_t* d_rev_graph_ptr, - int2* list_sizes, - cudaStream_t stream = 0); - void local_join(cudaStream_t stream = 0); - - raft::resources const& res; - - BuildConfig build_config_; - GnndGraph graph_; - std::atomic update_counter_; - - size_t nrow_; - size_t ndim_; - - raft::device_matrix<__half, size_t, raft::row_major> d_data_; - raft::device_vector l2_norms_; - - raft::device_matrix graph_buffer_; - raft::device_matrix dists_buffer_; - - // TODO: Investigate using RMM/RAFT types https://github.com/rapidsai/raft/issues/1827 - thrust::host_vector> graph_host_buffer_; - thrust::host_vector> dists_host_buffer_; - - raft::device_vector d_locks_; - - thrust::host_vector> h_rev_graph_new_; - thrust::host_vector> h_graph_old_; - thrust::host_vector> h_rev_graph_old_; - // int2.x is the 
number of forward edges, int2.y is the number of reverse edges - - raft::device_vector d_list_sizes_new_; - raft::device_vector d_list_sizes_old_; -}; - -constexpr int TILE_ROW_WIDTH = 64; -constexpr int TILE_COL_WIDTH = 128; - -constexpr int NUM_SAMPLES = 32; -// For now, the max. number of samples is 32, so the sample cache size is fixed -// to 64 (32 * 2). -constexpr int MAX_NUM_BI_SAMPLES = 64; -constexpr int SKEWED_MAX_NUM_BI_SAMPLES = skew_dim(MAX_NUM_BI_SAMPLES); -constexpr int BLOCK_SIZE = 512; -constexpr int WMMA_M = 16; -constexpr int WMMA_N = 16; -constexpr int WMMA_K = 16; - -template -__device__ __forceinline__ void load_vec(Data_t* vec_buffer, - const Data_t* d_vec, - const int load_dims, - const int padding_dims, - const int lane_id) -{ - if constexpr (std::is_same_v or std::is_same_v or - std::is_same_v) { - constexpr int num_load_elems_per_warp = raft::warp_size(); - for (int step = 0; step < raft::ceildiv(padding_dims, num_load_elems_per_warp); step++) { - int idx = step * num_load_elems_per_warp + lane_id; - if (idx < load_dims) { - vec_buffer[idx] = d_vec[idx]; - } else if (idx < padding_dims) { - vec_buffer[idx] = 0.0f; - } - } - } - if constexpr (std::is_same_v) { - if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && - load_dims % 4 == 0 && padding_dims % 4 == 0) { - constexpr int num_load_elems_per_warp = raft::warp_size() * 4; -#pragma unroll - for (int step = 0; step < raft::ceildiv(padding_dims, num_load_elems_per_warp); step++) { - int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; - if (idx_in_vec + 4 <= load_dims) { - *(float2*)(vec_buffer + idx_in_vec) = *(float2*)(d_vec + idx_in_vec); - } else if (idx_in_vec + 4 <= padding_dims) { - *(float2*)(vec_buffer + idx_in_vec) = float2({0.0f, 0.0f}); - } - } - } else { - constexpr int num_load_elems_per_warp = raft::warp_size(); - for (int step = 0; step < raft::ceildiv(padding_dims, num_load_elems_per_warp); step++) { - int idx = step * 
num_load_elems_per_warp + lane_id; - if (idx < load_dims) { - vec_buffer[idx] = d_vec[idx]; - } else if (idx < padding_dims) { - vec_buffer[idx] = 0.0f; - } - } - } - } -} - -// TODO: Replace with RAFT utilities https://github.com/rapidsai/raft/issues/1827 -/** Calculate L2 norm, and cast data to __half */ -template -RAFT_KERNEL preprocess_data_kernel(const Data_t* input_data, - __half* output_data, - int dim, - DistData_t* l2_norms, - size_t list_offset = 0) -{ - extern __shared__ char buffer[]; - __shared__ float l2_norm; - Data_t* s_vec = (Data_t*)buffer; - size_t list_id = list_offset + blockIdx.x; - - load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % raft::warp_size()); - if (threadIdx.x == 0) { l2_norm = 0; } - __syncthreads(); - int lane_id = threadIdx.x % raft::warp_size(); - for (int step = 0; step < raft::ceildiv(dim, raft::warp_size()); step++) { - int idx = step * raft::warp_size() + lane_id; - float part_dist = 0; - if (idx < dim) { - part_dist = s_vec[idx]; - part_dist = part_dist * part_dist; - } - __syncwarp(); - for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { - part_dist += __shfl_down_sync(raft::warp_full_mask(), part_dist, offset); - } - if (lane_id == 0) { l2_norm += part_dist; } - __syncwarp(); - } - - for (int step = 0; step < raft::ceildiv(dim, raft::warp_size()); step++) { - int idx = step * raft::warp_size() + threadIdx.x; - if (idx < dim) { - if (l2_norms == nullptr) { - output_data[list_id * dim + idx] = - (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); - } else { - output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; - if (idx == 0) { l2_norms[list_id] = l2_norm; } - } - } - } -} - -template -RAFT_KERNEL add_rev_edges_kernel(const Index_t* graph, - Index_t* rev_graph, - int num_samples, - int2* list_sizes) -{ - size_t list_id = blockIdx.x; - int2 list_size = list_sizes[list_id]; - - for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) { 
- // each node has same number (num_samples) of forward and reverse edges - size_t rev_list_id = graph[list_id * num_samples + idx]; - // there are already num_samples forward edges - int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1); - if (idx_in_rev_list >= num_samples) { - atomicExch(&list_sizes[rev_list_id].y, num_samples); - } else { - rev_graph[rev_list_id * num_samples + idx_in_rev_list] = list_id; - } - } -} - -template > -__device__ void insert_to_global_graph(ResultItem elem, - size_t list_id, - ID_t* graph, - DistData_t* dists, - int node_degree, - int* locks) -{ - int tx = threadIdx.x; - int lane_id = tx % raft::warp_size(); - size_t global_idx_base = list_id * node_degree; - if (elem.id() == list_id) return; - - const int num_segments = raft::ceildiv(node_degree, raft::warp_size()); - - int loop_flag = 0; - do { - int segment_id = elem.id() % num_segments; - if (lane_id == 0) { - loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; - } - - loop_flag = __shfl_sync(raft::warp_full_mask(), loop_flag, 0); - - if (loop_flag == 1) { - ResultItem knn_list_frag; - int local_idx = segment_id * raft::warp_size() + lane_id; - size_t global_idx = global_idx_base + local_idx; - if (local_idx < node_degree) { - knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); - knn_list_frag.dist() = dists[global_idx]; - } - - int pos_to_insert = -1; - ResultItem prev_elem; - - prev_elem.id_with_flag() = - __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.id_with_flag(), 1); - prev_elem.dist() = __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.dist(), 1); - - if (lane_id == 0) { - prev_elem = ResultItem{std::numeric_limits::min(), - std::numeric_limits::lowest()}; - } - if (elem > prev_elem && elem < knn_list_frag) { - pos_to_insert = segment_id * raft::warp_size() + lane_id; - } else if (elem == prev_elem || elem == knn_list_frag) { - pos_to_insert = -2; - } - uint mask = __ballot_sync(raft::warp_full_mask(), 
pos_to_insert >= 0); - if (mask) { - uint set_lane_id = __fns(mask, 0, 1); - pos_to_insert = __shfl_sync(raft::warp_full_mask(), pos_to_insert, set_lane_id); - } - - if (pos_to_insert >= 0) { - int local_idx = segment_id * raft::warp_size() + lane_id; - if (local_idx > pos_to_insert) { - local_idx++; - } else if (local_idx == pos_to_insert) { - graph[global_idx_base + local_idx].id_with_flag() = elem.id_with_flag(); - dists[global_idx_base + local_idx] = elem.dist(); - local_idx++; - } - size_t global_pos = global_idx_base + local_idx; - if (local_idx < (segment_id + 1) * raft::warp_size() && local_idx < node_degree) { - graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); - dists[global_pos] = knn_list_frag.dist(); - } - } - __threadfence(); - if (loop_flag && lane_id == 0) { atomicExch(&locks[list_id * num_segments + segment_id], 0); } - } - } while (!loop_flag); -} - -template -__device__ ResultItem get_min_item(const Index_t id, - const int idx_in_list, - const Index_t* neighbs, - const DistData_t* distances, - const bool find_in_row = true) -{ - int lane_id = threadIdx.x % raft::warp_size(); - - static_assert(MAX_NUM_BI_SAMPLES == 64); - int idx[MAX_NUM_BI_SAMPLES / raft::warp_size()]; - float dist[MAX_NUM_BI_SAMPLES / raft::warp_size()] = {std::numeric_limits::max(), - std::numeric_limits::max()}; - idx[0] = lane_id; - idx[1] = raft::warp_size() + lane_id; - - if (neighbs[idx[0]] != id) { - dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] - : distances[idx_in_list + lane_id * SKEWED_MAX_NUM_BI_SAMPLES]; - } - - if (neighbs[idx[1]] != id) { - dist[1] = - find_in_row - ? 
distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + raft::warp_size() + lane_id] - : distances[idx_in_list + (raft::warp_size() + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; - } - - if (dist[1] < dist[0]) { - dist[0] = dist[1]; - idx[0] = idx[1]; - } - __syncwarp(); - for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { - float other_idx = __shfl_down_sync(raft::warp_full_mask(), idx[0], offset); - float other_dist = __shfl_down_sync(raft::warp_full_mask(), dist[0], offset); - if (other_dist < dist[0]) { - dist[0] = other_dist; - idx[0] = other_idx; - } - } - - ResultItem result; - result.dist() = __shfl_sync(raft::warp_full_mask(), dist[0], 0); - result.id_with_flag() = neighbs[__shfl_sync(raft::warp_full_mask(), idx[0], 0)]; - return result; -} - -template -__device__ __forceinline__ void remove_duplicates( - T* list_a, int list_a_size, T* list_b, int list_b_size, int& unique_counter, int execute_warp_id) -{ - static_assert(raft::warp_size() == 32); - if (!(threadIdx.x >= execute_warp_id * raft::warp_size() && - threadIdx.x < execute_warp_id * raft::warp_size() + raft::warp_size())) { - return; - } - int lane_id = threadIdx.x % raft::warp_size(); - T elem = std::numeric_limits::max(); - if (lane_id < list_a_size) { elem = list_a[lane_id]; } - warp_bitonic_sort(&elem, lane_id); - - if (elem != std::numeric_limits::max()) { list_a[lane_id] = elem; } - - T elem_b = std::numeric_limits::max(); - - if (lane_id < list_b_size) { elem_b = list_b[lane_id]; } - __syncwarp(); - - int idx_l = 0; - int idx_r = list_a_size; - bool existed = false; - while (idx_l < idx_r) { - int idx = (idx_l + idx_r) / 2; - int elem = list_a[idx]; - if (elem == elem_b) { - existed = true; - break; - } - if (elem_b > elem) { - idx_l = idx + 1; - } else { - idx_r = idx; - } - } - if (!existed && elem_b != std::numeric_limits::max()) { - int idx = atomicAdd(&unique_counter, 1); - list_a[list_a_size + idx] = elem_b; - } -} - -// launch_bounds here denote BLOCK_SIZE = 512 and 
MIN_BLOCKS_PER_SM = 4 -// Per -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications, -// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048 -// For architectures 750 and 860, the values for MAX_RESIDENT_THREAD_PER_SM -// is 1024 and 1536 respectively, which means the bounds don't work anymore -template > -RAFT_KERNEL -#ifdef __CUDA_ARCH__ -#if (__CUDA_ARCH__) == 750 || (__CUDA_ARCH__) == 860 -__launch_bounds__(BLOCK_SIZE) -#else -__launch_bounds__(BLOCK_SIZE, 4) -#endif -#endif - local_join_kernel(const Index_t* graph_new, - const Index_t* rev_graph_new, - const int2* sizes_new, - const Index_t* graph_old, - const Index_t* rev_graph_old, - const int2* sizes_old, - const int width, - const __half* data, - const int data_dim, - ID_t* graph, - DistData_t* dists, - int graph_width, - int* locks, - DistData_t* l2_norms) -{ -#if (__CUDA_ARCH__ >= 700) - using namespace nvcuda; - __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2]; - - constexpr int APAD = 8; - constexpr int BPAD = 8; - __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD]; // New vectors - __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD]; // Old vectors - static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <= - sizeof(__half) * MAX_NUM_BI_SAMPLES * (TILE_COL_WIDTH + BPAD)); - // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES, reuse the space of s_ov - float* s_distances = (float*)&s_ov[0][0]; - int* s_unique_counter = (int*)&s_ov[0][0]; - - if (threadIdx.x == 0) { - s_unique_counter[0] = 0; - s_unique_counter[1] = 0; - } - - Index_t* new_neighbors = s_list; - Index_t* old_neighbors = s_list + MAX_NUM_BI_SAMPLES; - - size_t list_id = blockIdx.x; - int2 list_new_size2 = sizes_new[list_id]; - int list_new_size = list_new_size2.x + list_new_size2.y; - int2 list_old_size2 = sizes_old[list_id]; - int list_old_size = list_old_size2.x + list_old_size2.y; - - if (!list_new_size) return; - 
int tx = threadIdx.x; - - if (tx < list_new_size2.x) { - new_neighbors[tx] = graph_new[list_id * width + tx]; - } else if (tx >= list_new_size2.x && tx < list_new_size) { - new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x]; - } - - if (tx < list_old_size2.x) { - old_neighbors[tx] = graph_old[list_id * width + tx]; - } else if (tx >= list_old_size2.x && tx < list_old_size) { - old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x]; - } - - __syncthreads(); - - remove_duplicates(new_neighbors, - list_new_size2.x, - new_neighbors + list_new_size2.x, - list_new_size2.y, - s_unique_counter[0], - 0); - - remove_duplicates(old_neighbors, - list_old_size2.x, - old_neighbors + list_old_size2.x, - list_old_size2.y, - s_unique_counter[1], - 1); - __syncthreads(); - list_new_size = list_new_size2.x + s_unique_counter[0]; - list_old_size = list_old_size2.x + s_unique_counter[1]; - - int warp_id = threadIdx.x / raft::warp_size(); - int lane_id = threadIdx.x % raft::warp_size(); - constexpr int num_warps = BLOCK_SIZE / raft::warp_size(); - - int warp_id_y = warp_id / 4; - int warp_id_x = warp_id % 4; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0); - for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1) - ? 
data_dim - step * TILE_COL_WIDTH - : TILE_COL_WIDTH; -#pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_new_size) { - size_t neighbor_id = new_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_nv[idx], - data + idx_in_data + step * TILE_COL_WIDTH, - num_load_elems, - TILE_COL_WIDTH, - lane_id); - } - } - __syncthreads(); - - for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { - wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); - wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - __syncthreads(); - } - } - - wmma::store_matrix_sync( - s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, - c_frag, - SKEWED_MAX_NUM_BI_SAMPLES, - wmma::mem_row_major); - __syncthreads(); - - for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size && - i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { - if (l2_norms == nullptr) { - s_distances[i] = -s_distances[i]; - } else { - s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + - l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - - 2.0 * s_distances[i]; - } - } else { - s_distances[i] = std::numeric_limits::max(); - } - } - __syncthreads(); - - for (int step = 0; step < raft::ceildiv(list_new_size, num_warps); step++) { - int idx_in_list = step * num_warps + tx / raft::warp_size(); - if (idx_in_list >= list_new_size) continue; - auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); - if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); - } - } - - if (!list_old_size) return; - - __syncthreads(); - - wmma::fill_fragment(c_frag, 0.0); - for 
(int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1) - ? data_dim - step * TILE_COL_WIDTH - : TILE_COL_WIDTH; - if (TILE_COL_WIDTH < data_dim) { -#pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_new_size) { - size_t neighbor_id = new_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_nv[idx], - data + idx_in_data + step * TILE_COL_WIDTH, - num_load_elems, - TILE_COL_WIDTH, - lane_id); - } - } - } -#pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_old_size) { - size_t neighbor_id = old_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_ov[idx], - data + idx_in_data + step * TILE_COL_WIDTH, - num_load_elems, - TILE_COL_WIDTH, - lane_id); - } - } - __syncthreads(); - - for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { - wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); - wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - __syncthreads(); - } - } - - wmma::store_matrix_sync( - s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, - c_frag, - SKEWED_MAX_NUM_BI_SAMPLES, - wmma::mem_row_major); - __syncthreads(); - - for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size && - i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { - if (l2_norms == nullptr) { - s_distances[i] = -s_distances[i]; - } else { - s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + - l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - - 2.0 * s_distances[i]; - } - } else { - s_distances[i] = 
std::numeric_limits::max(); - } - } - __syncthreads(); - - for (int step = 0; step < raft::ceildiv(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { - int idx_in_list = step * num_warps + tx / raft::warp_size(); - if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; - if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && idx_in_list < MAX_NUM_BI_SAMPLES * 2) - continue; - ResultItem min_elem{std::numeric_limits::max(), - std::numeric_limits::max()}; - if (idx_in_list < MAX_NUM_BI_SAMPLES) { - auto temp_min_item = - get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances); - if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } - } else { - auto temp_min_item = get_min_item( - s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, new_neighbors, s_distances, false); - if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } - } - - if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); - } - } -#endif -} - -namespace { -template -int insert_to_ordered_list(InternalID_t* list, - DistData_t* dist_list, - const int width, - const InternalID_t neighb_id, - const DistData_t dist) -{ - if (dist > dist_list[width - 1]) { return width; } - - int idx_insert = width; - bool position_found = false; - for (int i = 0; i < width; i++) { - if (list[i].id() == neighb_id.id()) { return width; } - if (!position_found && dist_list[i] > dist) { - idx_insert = i; - position_found = true; - } - } - if (idx_insert == width) return idx_insert; - - memmove(list + idx_insert + 1, list + idx_insert, sizeof(*list) * (width - idx_insert - 1)); - memmove(dist_list + idx_insert + 1, - dist_list + idx_insert, - sizeof(*dist_list) * (width - idx_insert - 1)); - - list[idx_insert] = neighb_id; - dist_list[idx_insert] = dist; - return idx_insert; -}; - -} // namespace - -template -GnndGraph::GnndGraph(const size_t nrow, - const size_t node_degree, - 
const size_t internal_node_degree, - const size_t num_samples) - : nrow(nrow), - node_degree(node_degree), - num_samples(num_samples), - bloom_filter(nrow, internal_node_degree / segment_size, 3), - h_dists{raft::make_host_matrix(nrow, node_degree)}, - h_graph_new(nrow * num_samples), - h_list_sizes_new(nrow), - h_graph_old(nrow * num_samples), - h_list_sizes_old{nrow} -{ - // node_degree must be a multiple of segment_size; - assert(node_degree % segment_size == 0); - assert(internal_node_degree % segment_size == 0); - - num_segments = node_degree / segment_size; - // To save the CPU memory, graph should be allocated by external function - h_graph = nullptr; -} - -// This is the only operation on the CPU that cannot be overlapped. -// So it should be as fast as possible. -template -void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, const size_t width) -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - auto list_new = h_graph_new.data() + i * num_samples; - h_list_sizes_new[i].x = 0; - h_list_sizes_new[i].y = 0; - - for (size_t j = 0; j < width; j++) { - auto new_neighb_id = new_neighbors[i * width + j].id(); - if ((size_t)new_neighb_id >= nrow) break; - if (bloom_filter.check(i, new_neighb_id)) { continue; } - bloom_filter.add(i, new_neighb_id); - new_neighbors[i * width + j].mark_old(); - list_new[h_list_sizes_new[i].x++] = new_neighb_id; - if (h_list_sizes_new[i].x == num_samples) break; - } - } -} - -template -void GnndGraph::init_random_graph() -{ - for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { - // random sequence (range: 0~nrow) - // segment_x stores neighbors which id % num_segments == x - std::vector rand_seq(nrow / num_segments); - std::iota(rand_seq.begin(), rand_seq.end(), 0); - auto gen = std::default_random_engine{seg_idx}; - std::shuffle(rand_seq.begin(), rand_seq.end(), gen); - -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - size_t base_idx = i * node_degree + seg_idx * 
segment_size; - auto h_neighbor_list = h_graph + base_idx; - auto h_dist_list = h_dists.data_handle() + base_idx; - for (size_t j = 0; j < static_cast(segment_size); j++) { - size_t idx = base_idx + j; - Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; - if ((size_t)id == i) { - id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; - } - h_neighbor_list[j].id_with_flag() = id; - h_dist_list[j] = std::numeric_limits::max(); - } - } - } -} - -template -void GnndGraph::sample_graph(bool sample_new) -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - h_list_sizes_old[i].x = 0; - h_list_sizes_old[i].y = 0; - h_list_sizes_new[i].x = 0; - h_list_sizes_new[i].y = 0; - - auto list = h_graph + i * node_degree; - auto list_old = h_graph_old.data() + i * num_samples; - auto list_new = h_graph_new.data() + i * num_samples; - for (int j = 0; j < segment_size; j++) { - for (int k = 0; k < num_segments; k++) { - auto neighbor = list[k * segment_size + j]; - if ((size_t)neighbor.id() >= nrow) continue; - if (!neighbor.is_new()) { - if (h_list_sizes_old[i].x < num_samples) { - list_old[h_list_sizes_old[i].x++] = neighbor.id(); - } - } else if (sample_new) { - if (h_list_sizes_new[i].x < num_samples) { - list[k * segment_size + j].mark_old(); - list_new[h_list_sizes_new[i].x++] = neighbor.id(); - } - } - if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } - } - if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } - } - } -} - -template -void GnndGraph::update_graph(const InternalID_t* new_neighbors, - const DistData_t* new_dists, - const size_t width, - std::atomic& update_counter) -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - for (size_t j = 0; j < width; j++) { - auto new_neighb_id = new_neighbors[i * width + j]; - auto new_dist = new_dists[i * width + j]; - if (new_dist == std::numeric_limits::max()) break; - if 
((size_t)new_neighb_id.id() == i) continue; - int seg_idx = new_neighb_id.id() % num_segments; - auto list = h_graph + i * node_degree + seg_idx * segment_size; - auto dist_list = h_dists.data_handle() + i * node_degree + seg_idx * segment_size; - int insert_pos = - insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); - if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; } - } - } -} - -template -void GnndGraph::sort_lists() -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - std::vector> new_list; - for (size_t j = 0; j < node_degree; j++) { - new_list.emplace_back(h_dists.data_handle()[i * node_degree + j], - h_graph[i * node_degree + j].id()); - } - std::sort(new_list.begin(), new_list.end()); - for (size_t j = 0; j < node_degree; j++) { - h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; - h_dists.data_handle()[i * node_degree + j] = new_list[j].first; - } - } -} - -template -void GnndGraph::clear() -{ - bloom_filter.clear(); -} - -template -GnndGraph::~GnndGraph() -{ - assert(h_graph == nullptr); -} - -template -GNND::GNND(raft::resources const& res, const BuildConfig& build_config) - : res(res), - build_config_(build_config), - graph_(build_config.max_dataset_size, - align32::roundUp(build_config.node_degree), - align32::roundUp(build_config.internal_node_degree ? 
build_config.internal_node_degree - : build_config.node_degree), - NUM_SAMPLES), - nrow_(build_config.max_dataset_size), - ndim_(build_config.dataset_dim), - d_data_{raft::make_device_matrix<__half, size_t, raft::row_major>( - res, nrow_, build_config.dataset_dim)}, - l2_norms_{raft::make_device_vector(res, nrow_)}, - graph_buffer_{ - raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, - dists_buffer_{ - raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, - graph_host_buffer_(nrow_ * DEGREE_ON_DEVICE), - dists_host_buffer_(nrow_ * DEGREE_ON_DEVICE), - d_locks_{raft::make_device_vector(res, nrow_)}, - h_rev_graph_new_(nrow_ * NUM_SAMPLES), - h_graph_old_(nrow_ * NUM_SAMPLES), - h_rev_graph_old_(nrow_ * NUM_SAMPLES), - d_list_sizes_new_{raft::make_device_vector(res, nrow_)}, - d_list_sizes_old_{raft::make_device_vector(res, nrow_)} -{ - static_assert(NUM_SAMPLES <= 32); - - thrust::fill(thrust::device, - dists_buffer_.data_handle(), - dists_buffer_.data_handle() + dists_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, - reinterpret_cast(graph_buffer_.data_handle()), - reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); -}; - -template -void GNND::add_reverse_edges(Index_t* graph_ptr, - Index_t* h_rev_graph_ptr, - Index_t* d_rev_graph_ptr, - int2* list_sizes, - cudaStream_t stream) -{ - add_rev_edges_kernel<<>>( - graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); - raft::copy( - h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res)); -} - -template -void GNND::local_join(cudaStream_t stream) -{ - thrust::fill(thrust::device.on(stream), - dists_buffer_.data_handle(), - dists_buffer_.data_handle() + dists_buffer_.size(), - std::numeric_limits::max()); - local_join_kernel<<>>( - thrust::raw_pointer_cast(graph_.h_graph_new.data()), - 
thrust::raw_pointer_cast(h_rev_graph_new_.data()), - d_list_sizes_new_.data_handle(), - thrust::raw_pointer_cast(h_graph_old_.data()), - thrust::raw_pointer_cast(h_rev_graph_old_.data()), - d_list_sizes_old_.data_handle(), - NUM_SAMPLES, - d_data_.data_handle(), - ndim_, - graph_buffer_.data_handle(), - dists_buffer_.data_handle(), - DEGREE_ON_DEVICE, - d_locks_.data_handle(), - l2_norms_.data_handle()); -} - -template -void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) -{ - using input_t = typename std::remove_const::type; - - cudaStream_t stream = raft::resource::get_cuda_stream(res); - nrow_ = nrow; - graph_.h_graph = (InternalID_t*)output_graph; - - cudaPointerAttributes data_ptr_attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); - size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 100000 : nrow_; - - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches{ - data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; - for (auto const& batch : vec_batches) { - preprocess_data_kernel<<(raft::warp_size())) * - raft::warp_size(), - stream>>>(batch.data(), - d_data_.data_handle(), - build_config_.dataset_dim, - l2_norms_.data_handle(), - batch.offset()); - } - - thrust::fill(thrust::device.on(stream), - (Index_t*)graph_buffer_.data_handle(), - (Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(), - std::numeric_limits::max()); - - graph_.clear(); - graph_.init_random_graph(); - graph_.sample_graph(true); - - auto update_and_sample = [&](bool update_graph) { - if (update_graph) { - update_counter_ = 0; - graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), - thrust::raw_pointer_cast(dists_host_buffer_.data()), - DEGREE_ON_DEVICE, - update_counter_); - if (update_counter_ < build_config_.termination_threshold * nrow_ * - build_config_.dataset_dim / counter_interval) { - update_counter_ = -1; - } - } - graph_.sample_graph(false); - }; - - for (size_t it = 0; it 
< build_config_.max_iterations; it++) { - raft::copy(d_list_sizes_new_.data_handle(), - thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), - nrow_, - raft::resource::get_cuda_stream(res)); - raft::copy(thrust::raw_pointer_cast(h_graph_old_.data()), - thrust::raw_pointer_cast(graph_.h_graph_old.data()), - nrow_ * NUM_SAMPLES, - raft::resource::get_cuda_stream(res)); - raft::copy(d_list_sizes_old_.data_handle(), - thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), - nrow_, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - - std::thread update_and_sample_thread(update_and_sample, it); - - RAFT_LOG_DEBUG("# GNND iteraton: %lu / %lu", it + 1, build_config_.max_iterations); - - // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it - // contains some information for local_join. - static_assert(DEGREE_ON_DEVICE * sizeof(*(dists_buffer_.data_handle())) >= - NUM_SAMPLES * sizeof(*(graph_buffer_.data_handle()))); - add_reverse_edges(thrust::raw_pointer_cast(graph_.h_graph_new.data()), - thrust::raw_pointer_cast(h_rev_graph_new_.data()), - (Index_t*)dists_buffer_.data_handle(), - d_list_sizes_new_.data_handle(), - stream); - add_reverse_edges(thrust::raw_pointer_cast(h_graph_old_.data()), - thrust::raw_pointer_cast(h_rev_graph_old_.data()), - (Index_t*)dists_buffer_.data_handle(), - d_list_sizes_old_.data_handle(), - stream); - - // Tensor operations from `mma.h` are guarded with archicteture - // __CUDA_ARCH__ >= 700. 
Since RAFT supports compilation for ARCH 600, - // we need to ensure that `local_join_kernel` (which uses tensor) operations - // is not only not compiled, but also a runtime error is presented to the user - auto kernel = preprocess_data_kernel; - void* kernel_ptr = reinterpret_cast(kernel); - auto runtime_arch = raft::util::arch::kernel_virtual_arch(kernel_ptr); - auto wmma_range = - raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future()); - - if (wmma_range.contains(runtime_arch)) { - local_join(stream); - } else { - THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700"); - } - - update_and_sample_thread.join(); - - if (update_counter_ == -1) { break; } - raft::copy(thrust::raw_pointer_cast(graph_host_buffer_.data()), - graph_buffer_.data_handle(), - nrow_ * DEGREE_ON_DEVICE, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - raft::copy(thrust::raw_pointer_cast(dists_host_buffer_.data()), - dists_buffer_.data_handle(), - nrow_ * DEGREE_ON_DEVICE, - raft::resource::get_cuda_stream(res)); - - graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE); - } - - graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), - thrust::raw_pointer_cast(dists_host_buffer_.data()), - DEGREE_ON_DEVICE, - update_counter_); - raft::resource::sync_stream(res); - graph_.sort_lists(); - - // Reuse graph_.h_dists as the buffer for shrink the lists in graph - static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); - Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); - -#pragma omp parallel for - for (size_t i = 0; i < (size_t)nrow_; i++) { - for (size_t j = 0; j < build_config_.node_degree; j++) { - size_t idx = i * graph_.node_degree + j; - int id = graph_.h_graph[idx].id(); - if (id < static_cast(nrow_)) { - graph_shrink_buffer[i * build_config_.node_degree + j] = id; - } else { - graph_shrink_buffer[i * 
build_config_.node_degree + j] = - cuvs::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; - } - } - } - graph_.h_graph = nullptr; - -#pragma omp parallel for - for (size_t i = 0; i < (size_t)nrow_; i++) { - for (size_t j = 0; j < build_config_.node_degree; j++) { - output_graph[i * build_config_.node_degree + j] = - graph_shrink_buffer[i * build_config_.node_degree + j]; - } - } -} - -template , memory_type::host>> -void build(raft::resources const& res, - const index_params& params, - raft::mdspan, raft::row_major, Accessor> dataset, - index& idx) -{ - RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, - "The dataset size for GNND should be less than %d", - std::numeric_limits::max() - 1); - size_t intermediate_degree = params.intermediate_graph_degree; - size_t graph_degree = params.graph_degree; - - if (intermediate_degree >= static_cast(dataset.extent(0))) { - RAFT_LOG_WARN( - "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", - dataset.extent(0)); - intermediate_degree = dataset.extent(0) - 1; - } - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - // The elements in each knn-list are partitioned into different buckets, and we need more buckets - // to mitigate bucket collisions. `intermediate_degree` is OK to larger than - // extended_graph_degree. - size_t extended_graph_degree = - align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); - size_t extended_intermediate_degree = align32::roundUp( - static_cast(intermediate_degree * (intermediate_degree <= 32 ? 
1.0 : 1.3))); - - auto int_graph = raft::make_host_matrix( - dataset.extent(0), static_cast(extended_graph_degree)); - - BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), - .dataset_dim = static_cast(dataset.extent(1)), - .node_degree = extended_graph_degree, - .internal_node_degree = extended_intermediate_degree, - .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold}; - - GNND nnd(res, build_config); - nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); - -#pragma omp parallel for - for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { - for (size_t j = 0; j < graph_degree; j++) { - auto graph = idx.graph().data_handle(); - graph[i * graph_degree + j] = int_graph.data_handle()[i * extended_graph_degree + j]; - } - } -} - -template , memory_type::host>> -index build( - raft::resources const& res, - const index_params& params, - raft::mdspan, raft::row_major, Accessor> dataset) -{ - size_t intermediate_degree = params.intermediate_graph_degree; - size_t graph_degree = params.graph_degree; - - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - index idx{res, dataset.extent(0), static_cast(graph_degree)}; - - build(res, params, dataset, idx); - - return idx; -} - -} // namespace cuvs::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/cuvs/neighbors/detail/refine.cuh b/cpp/include/cuvs/neighbors/detail/refine.cuh deleted file mode 100644 index 170f97398..000000000 --- a/cpp/include/cuvs/neighbors/detail/refine.cuh +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "refine_device.cuh" -#include "refine_host.hpp" diff --git a/cpp/include/cuvs/neighbors/detail/refine_common.hpp b/cpp/include/cuvs/neighbors/detail/refine_common.hpp deleted file mode 100644 index 3def36a39..000000000 --- a/cpp/include/cuvs/neighbors/detail/refine_common.hpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace cuvs::neighbors::detail { - -/** Checks whether the input data extents are compatible. 
*/ -template -void refine_check_input(ExtentsT dataset, - ExtentsT queries, - ExtentsT candidates, - ExtentsT indices, - ExtentsT distances, - distance::DistanceType metric) -{ - auto n_queries = queries.extent(0); - auto k = distances.extent(1); - - RAFT_EXPECTS(indices.extent(0) == n_queries && distances.extent(0) == n_queries && - candidates.extent(0) == n_queries, - "Number of rows in output indices, distances and candidates matrices must be equal" - " with the number of rows in search matrix. Expected %d, got %d, %d, and %d", - static_cast(n_queries), - static_cast(indices.extent(0)), - static_cast(distances.extent(0)), - static_cast(candidates.extent(0))); - - RAFT_EXPECTS(indices.extent(1) == k, - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), - "Number of columns must be equal for dataset and queries"); - - RAFT_EXPECTS(candidates.extent(1) >= k, - "Number of neighbor candidates must not be smaller than k (%d vs %d)", - static_cast(candidates.extent(1)), - static_cast(k)); -} - -} // namespace cuvs::neighbors::detail diff --git a/cpp/include/cuvs/neighbors/detail/refine_device.cuh b/cpp/include/cuvs/neighbors/detail/refine_device.cuh deleted file mode 100644 index 5bc478702..000000000 --- a/cpp/include/cuvs/neighbors/detail/refine_device.cuh +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace cuvs::neighbors::detail { - -/** - * See cuvs::neighbors::refine for docs. - */ -template -void refine_device( - raft::resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - matrix_idx n_candidates = neighbor_candidates.extent(1); - matrix_idx n_queries = queries.extent(0); - matrix_idx dim = dataset.extent(1); - uint32_t k = static_cast(indices.extent(1)); - - RAFT_EXPECTS(k <= raft::matrix::detail::select::warpsort::kMaxCapacity, - "k must be lest than topk::kMaxCapacity (%d).", - raft::matrix::detail::select::warpsort::kMaxCapacity); - - raft::common::nvtx::range fun_scope( - "neighbors::refine(%zu, %u)", size_t(n_queries), uint32_t(n_candidates)); - - refine_check_input(dataset.extents(), - queries.extents(), - neighbor_candidates.extents(), - indices.extents(), - distances.extents(), - metric); - - // The refinement search can be mapped to an IVF flat search: - // - We consider that the candidate vectors form a cluster, separately for each query. - // - In other words, the n_queries * n_candidates vectors form n_queries clusters, each with - // n_candidates elements. - // - We consider that the coarse level search is already performed and assigned a single cluster - // to search for each query (the cluster formed from the corresponding candidates). - // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. 
- rmm::device_uvector fake_coarse_idx(n_queries, resource::get_cuda_stream(handle)); - - thrust::sequence(raft::resource::get_thrust_policy(handle), - fake_coarse_idx.data(), - fake_coarse_idx.data() + n_queries); - - cuvs::neighbors::ivf_flat::index refinement_index( - handle, metric, n_queries, false, true, dim); - - cuvs::neighbors::ivf_flat::detail::fill_refinement_index(handle, - &refinement_index, - dataset.data_handle(), - neighbor_candidates.data_handle(), - n_queries, - n_candidates); - uint32_t grid_dim_x = 1; - cuvs::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< - data_t, - typename cuvs::spatial::knn::detail::utils::config::value_t, - idx_t>(refinement_index, - queries.data_handle(), - fake_coarse_idx.data(), - static_cast(n_queries), - 0, - refinement_index.metric(), - 1, - k, - cuvs::distance::is_min_close(metric), - cuvs::neighbors::filtering::none_ivf_sample_filter(), - indices.data_handle(), - distances.data_handle(), - grid_dim_x, - resource::get_cuda_stream(handle)); -} - -} // namespace cuvs::neighbors::detail diff --git a/cpp/include/cuvs/neighbors/detail/refine_host-ext.hpp b/cpp/include/cuvs/neighbors/detail/refine_host-ext.hpp deleted file mode 100644 index c2dcdd91f..000000000 --- a/cpp/include/cuvs/neighbors/detail/refine_host-ext.hpp +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // int64_t - -#include // cuvs::distance::DistanceType -#include // raft::host_matrix_view -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::neighbors::detail { - -template -[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host( - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) RAFT_EXPLICIT; - -} - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_refine(IdxT, DataT, DistanceT, ExtentsT) \ - extern template void cuvs::neighbors::detail::refine_host( \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric); - -instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); -instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); -instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); - -#undef instantiate_raft_neighbors_refine diff --git a/cpp/include/cuvs/neighbors/detail/refine_host-inl.hpp b/cpp/include/cuvs/neighbors/detail/refine_host-inl.hpp deleted file mode 100644 index c753e56f7..000000000 --- a/cpp/include/cuvs/neighbors/detail/refine_host-inl.hpp +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include - -namespace cuvs::neighbors::detail { - -template -[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host_impl( - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances) -{ - size_t n_queries = queries.extent(0); - size_t n_rows = dataset.extent(0); - size_t dim = dataset.extent(1); - size_t orig_k = neighbor_candidates.extent(1); - size_t refined_k = indices.extent(1); - - raft::common::nvtx::range fun_scope( - "neighbors::refine_host(%zu, %zu -> %zu)", n_queries, orig_k, refined_k); - - auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads())); - if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; } - -#pragma omp parallel num_threads(suggested_n_threads) - { - std::vector> refined_pairs(orig_k); - for (size_t i = omp_get_thread_num(); i < n_queries; i += omp_get_num_threads()) { - // Compute the refined distance using original dataset vectors - const DataT* query = queries.data_handle() + dim * i; - for (size_t j = 0; j < orig_k; j++) { - IdxT id = neighbor_candidates(i, j); - DistanceT distance = 0.0; - if (static_cast(id) >= n_rows) { - distance = std::numeric_limits::max(); - } else { - const DataT* row = dataset.data_handle() + dim * id; - for (size_t k = 0; k < dim; k++) { - distance += DC::template eval(query[k], row[k]); - } - } - refined_pairs[j] = 
std::make_tuple(distance, id); - } - // Sort the query neighbors by their refined distances - std::sort(refined_pairs.begin(), refined_pairs.end()); - // Store first refined_k neighbors - for (size_t j = 0; j < refined_k; j++) { - indices(i, j) = std::get<1>(refined_pairs[j]); - if (distances.data_handle() != nullptr) { - distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[j])); - } - } - } - } -} - -struct distance_comp_l2 { - template - static inline auto eval(const DistanceT& a, const DistanceT& b) -> DistanceT - { - auto d = a - b; - return d * d; - } - template - static inline auto postprocess(const DistanceT& a) -> DistanceT - { - return a; - } -}; - -struct distance_comp_inner { - template - static inline auto eval(const DistanceT& a, const DistanceT& b) -> DistanceT - { - return -a * b; - } - template - static inline auto postprocess(const DistanceT& a) -> DistanceT - { - return -a; - } -}; - -/** - * Naive CPU implementation of refine operation - * - * All pointers are expected to be accessible on the host. 
- */ -template -[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host( - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - refine_check_input(dataset.extents(), - queries.extents(), - neighbor_candidates.extents(), - indices.extents(), - distances.extents(), - metric); - - switch (metric) { - case cuvs::distance::DistanceType::L2Expanded: - return refine_host_impl( - dataset, queries, neighbor_candidates, indices, distances); - case cuvs::distance::DistanceType::InnerProduct: - return refine_host_impl( - dataset, queries, neighbor_candidates, indices, distances); - default: throw raft::logic_error("Unsupported metric"); - } -} - -} // namespace cuvs::neighbors::detail diff --git a/cpp/include/cuvs/neighbors/detail/refine_host.hpp b/cpp/include/cuvs/neighbors/detail/refine_host.hpp deleted file mode 100644 index ff0de7566..000000000 --- a/cpp/include/cuvs/neighbors/detail/refine_host.hpp +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "refine_host-inl.hpp" -#endif - -#ifdef RAFT_COMPILED -#include "refine_host-ext.hpp" -#endif diff --git a/cpp/include/cuvs/neighbors/detail/selection_faiss-ext.cuh b/cpp/include/cuvs/neighbors/detail/selection_faiss-ext.cuh deleted file mode 100644 index e123f81e7..000000000 --- a/cpp/include/cuvs/neighbors/detail/selection_faiss-ext.cuh +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include // size_t -#include // uint32_t -#include // __half -#include // kFaissMaxK -#include // RAFT_EXPLICIT - -#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) - -namespace cuvs::neighbors::detail { - -template -void select_k(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - bool select_min, - int k, - cudaStream_t stream) RAFT_EXPLICIT; -}; // namespace cuvs::neighbors::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ - extern template void cuvs::neighbors::detail::select_k(const key_t* inK, \ - const payload_t* inV, \ - size_t n_rows, \ - size_t n_cols, \ - key_t* outK, \ - payload_t* outV, \ - bool select_min, \ - int k, \ - cudaStream_t stream) - -instantiate_raft_neighbors_detail_select_k(uint32_t, float); -instantiate_raft_neighbors_detail_select_k(int32_t, float); -instantiate_raft_neighbors_detail_select_k(long, float); -instantiate_raft_neighbors_detail_select_k(size_t, double); -// test/neighbors/selection.cu -instantiate_raft_neighbors_detail_select_k(int, double); -instantiate_raft_neighbors_detail_select_k(size_t, float); - -instantiate_raft_neighbors_detail_select_k(uint32_t, double); -instantiate_raft_neighbors_detail_select_k(int64_t, double); -instantiate_raft_neighbors_detail_select_k(uint32_t, __half); -instantiate_raft_neighbors_detail_select_k(int64_t, __half); - -#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/include/cuvs/neighbors/detail/selection_faiss-inl.cuh b/cpp/include/cuvs/neighbors/detail/selection_faiss-inl.cuh deleted file mode 100644 index f10339485..000000000 --- a/cpp/include/cuvs/neighbors/detail/selection_faiss-inl.cuh +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include // kFaissMaxK - -namespace cuvs::neighbors::detail { - -template -RAFT_KERNEL select_k_kernel(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - key_t initK, - payload_t initV, - int k) -{ - using align_warp = raft::Pow2; - constexpr int kNumWarps = align_warp::div(tpb); - - __shared__ key_t smemK[kNumWarps * warp_q]; - __shared__ payload_t smemV[kNumWarps * warp_q]; - - faiss_select::BlockSelect, - warp_q, - thread_q, - tpb> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - int row = blockIdx.x; - { - size_t i = size_t(threadIdx.x); - - inK += row * n_cols; - if (inV != nullptr) { inV += row * n_cols; } - - // Whole warps must participate in the selection - size_t limit = align_warp::roundDown(n_cols); - - for (; i < limit; i += tpb) { - heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); - } - - // Handle last remainder fraction of a warp of elements - if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); } - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - outK[row * k + i] = smemK[i]; - outV[row * k + i] = smemV[i]; - } -} - -template -inline void select_k_impl(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - bool select_min, - int k, - cudaStream_t stream) -{ - auto grid = dim3(n_rows); - - constexpr int n_threads = (warp_q <= 1024) ? 
128 : 64; - auto block = dim3(n_threads); - - auto kInit = select_min ? raft::upper_bound() : lower_bound(); - auto vInit = -1; - if (select_min) { - select_k_kernel - <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); - } else { - select_k_kernel - <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); - } - RAFT_CUDA_TRY(cudaGetLastError()); -} - -/** - * @brief Select the k-nearest neighbors from dense - * distance and index matrices. - * - * @param[in] inK partitioned knn distance matrix - * @param[in] inV partitioned knn index matrix - * @param[in] n_rows number of rows in distance and index matrices - * @param[in] n_cols number of columns in distance and index matrices - * @param[out] outK merged knn distance matrix - * @param[out] outV merged knn index matrix - * @param[in] select_min whether to select the min or the max distances - * @param[in] k number of neighbors per partition (also number of merged neighbors) - * @param[in] stream CUDA stream to use - */ -template -inline void select_k(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - bool select_min, - int k, - cudaStream_t stream) -{ - constexpr int max_k = kFaissMaxK(); - if (k == 1) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 32) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 64) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 128) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 256) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 512) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 1024 && k <= max_k) - // note: have to use constexpr std::min here to avoid instantiating templates - // for parameters we don't support - 
select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 2048 && k <= max_k) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else - ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); -} -}; // namespace cuvs::neighbors::detail diff --git a/cpp/include/cuvs/neighbors/detail/selection_faiss.cuh b/cpp/include/cuvs/neighbors/detail/selection_faiss.cuh deleted file mode 100644 index dd229b37e..000000000 --- a/cpp/include/cuvs/neighbors/detail/selection_faiss.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "selection_faiss-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "selection_faiss-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/detail/selection_faiss_helpers.cuh b/cpp/include/cuvs/neighbors/detail/selection_faiss_helpers.cuh deleted file mode 100644 index bbe4752d2..000000000 --- a/cpp/include/cuvs/neighbors/detail/selection_faiss_helpers.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace cuvs::neighbors::detail { - -// This function is used in cpp/test/neighbors/select.cu. We want to make it -// available through both the selection_faiss-inl.cuh and -// selection_faiss-ext.cuh headers. -template -constexpr int kFaissMaxK() -{ - if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } - return 2048; -} - -} // namespace cuvs::neighbors::detail diff --git a/cpp/include/cuvs/neighbors/epsilon_neighborhood.cuh b/cpp/include/cuvs/neighbors/epsilon_neighborhood.cuh deleted file mode 100644 index dfa300c22..000000000 --- a/cpp/include/cuvs/neighbors/epsilon_neighborhood.cuh +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __EPSILON_NEIGH_H -#define __EPSILON_NEIGH_H - -#pragma once - -#include -#include -#include -#include - -namespace cuvs::neighbors::epsilon_neighborhood { - -/** - * @brief Computes epsilon neighborhood for the L2-Squared distance metric - * - * @tparam value_t IO and math type - * @tparam idx_t Index type - * - * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n] - * @param[out] vd vertex degree array [on device] [len = m + 1] - * `vd + m` stores the total number of edges in the adjacency - * matrix. Pass a nullptr if you don't need this info. - * @param[in] x first matrix [row-major] [on device] [dim = m x k] - * @param[in] y second matrix [row-major] [on device] [dim = n x k] - * @param[in] m number of rows in x - * @param[in] n number of rows in y - * @param[in] k number of columns in x and k - * @param[in] eps defines epsilon neighborhood radius (should be passed as - * squared as we compute L2-squared distance in this method) - * @param[in] stream cuda stream - */ -template -void epsUnexpL2SqNeighborhood(bool* adj, - idx_t* vd, - const value_t* x, - const value_t* y, - idx_t m, - idx_t n, - idx_t k, - value_t eps, - cudaStream_t stream) -{ - spatial::knn::detail::epsUnexpL2SqNeighborhood( - adj, vd, x, y, m, n, k, eps, stream); -} - -/** - * @defgroup epsilon_neighbors Epislon Neighborhood Operations - * @{ - */ - -/** - * @brief Computes epsilon neighborhood for the L2-Squared distance metric and given ball size. - * The epsilon neighbors is represented by a dense boolean adjacency matrix of size m * n and - * an array of degrees for each vertex, which can be used as a compressed sparse row (CSR) - * indptr array. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::neighbors; - * raft::raft::resources handle; - * ... 
- * auto adj = raft::make_device_matrix(handle, m * n); - * auto vd = raft::make_device_vector(handle, m+1); - * epsilon_neighborhood::eps_neighbors_l2sq(handle, x, y, adj.view(), vd.view(), eps); - * @endcode - * - * @tparam value_t IO and math type - * @tparam idx_t Index type - * @tparam matrix_idx_t matrix indexing type - * - * @param[in] handle raft handle to manage library resources - * @param[in] x first matrix [row-major] [on device] [dim = m x k] - * @param[in] y second matrix [row-major] [on device] [dim = n x k] - * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n] - * @param[out] vd vertex degree array [on device] [len = m + 1] - * `vd + m` stores the total number of edges in the adjacency - * matrix. Pass a nullptr if you don't need this info. - * @param[in] eps defines epsilon neighborhood radius (should be passed as - * squared as we compute L2-squared distance in this method) - */ -template -void eps_neighbors_l2sq(raft::resources const& handle, - raft::device_matrix_view x, - raft::device_matrix_view y, - raft::device_matrix_view adj, - raft::device_vector_view vd, - value_t eps) -{ - epsUnexpL2SqNeighborhood(adj.data_handle(), - vd.data_handle(), - x.data_handle(), - y.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - eps, - resource::get_cuda_stream(handle)); -} - -/** @} */ // end group epsilon_neighbors - -} // namespace cuvs::neighbors::epsilon_neighborhood - -#endif \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/ivf_flat-ext.cuh b/cpp/include/cuvs/neighbors/ivf_flat-ext.cuh deleted file mode 100644 index 3b66a589b..000000000 --- a/cpp/include/cuvs/neighbors/ivf_flat-ext.cuh +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // int64_t - -#include -#include // cuvs::neighbors::ivf_flat::index -#include // raft::device_matrix_view -#include // raft::resources -#include // RAFT_EXPLICIT -#include // rmm::mr::device_memory_resource - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace cuvs::neighbors::ivf_flat { - -template -auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index RAFT_EXPLICIT; - -template -auto build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) - -> index RAFT_EXPLICIT; - -template -void build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset, - cuvs::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; - -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index RAFT_EXPLICIT; - -template -auto extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index RAFT_EXPLICIT; - -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) RAFT_EXPLICIT; - -template -void extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* index) RAFT_EXPLICIT; - -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - 
const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; - -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; - -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; - -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; - -} // namespace cuvs::neighbors::ivf_flat - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ - extern template auto cuvs::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->cuvs::neighbors::ivf_flat::index; \ - \ - extern template auto cuvs::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index; \ - \ - extern template void cuvs::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx); - -instantiate_raft_neighbors_ivf_flat_build(float, int64_t); 
-instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); -#undef instantiate_raft_neighbors_ivf_flat_build - -#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ - extern template auto cuvs::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->cuvs::neighbors::ivf_flat::index; \ - \ - extern template auto cuvs::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index; \ - \ - extern template void cuvs::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - cuvs::neighbors::ivf_flat::index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); \ - \ - extern template void cuvs::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* index); - -instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); -instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); - -#undef instantiate_raft_neighbors_ivf_flat_extend - -#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ - extern template void cuvs::neighbors::ivf_flat::search( \ - raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::search_params& params, \ - const cuvs::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr); \ - \ - extern template void cuvs::neighbors::ivf_flat::search( \ - raft::resources const& handle, \ - const 
cuvs::neighbors::ivf_flat::search_params& params, \ - const cuvs::neighbors::ivf_flat::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -instantiate_raft_neighbors_ivf_flat_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); - -#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/include/cuvs/neighbors/ivf_flat-inl.cuh b/cpp/include/cuvs/neighbors/ivf_flat-inl.cuh deleted file mode 100644 index d956f060c..000000000 --- a/cpp/include/cuvs/neighbors/ivf_flat-inl.cuh +++ /dev/null @@ -1,602 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace cuvs::neighbors::ivf_flat { - -/** - * @brief Build the index from the dataset for efficient search. 
- * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - return cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); -} - -/** - * @defgroup ivf_flat IVF Flat Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. 
- * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, dataset, index_params); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -> index -{ - return cuvs::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** - * @brief Build the index from the dataset for efficient search. 
- * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * ivf_flat::index index; - * ivf_flat::build(handle, dataset, index_params, index); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] - * @param[out] idx reference to ivf_flat::index - * - */ -template -void build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset, - cuvs::neighbors::ivf_flat::index& idx) -{ - idx = cuvs::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** @} */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. 
- * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows number of rows in `new_vectors` - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - return cuvs::neighbors::ivf_flat::detail::extend( - handle, orig_index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. 
- * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * auto index = ivf_flat::extend(handle, index_empty, no_op, dataset); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] orig_index original index - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index -{ - return extend(handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); -} - -/** @} */ - -/** - * @brief Extend the index in-place with the new data. 
- * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param[inout] index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - */ -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - cuvs::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Extend the index in-place with the new data. 
- * - * Usage example: - * @code{.cpp} - * using namespace cuvs::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * ivf_flat::extend(handle, dataset, no_opt, &index_empty); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] index pointer to index, to be overwritten in-place - */ -template -void extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* index) -{ - extend(handle, - index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); -} - -/** @} */ - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... 
- * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_flat::search_params search_params; - * filtering::none_ivf_sample_filter filter; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or - * `(uint32_t query_ix, uint32 sample_ix) -> bool` - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. 
- * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). - * @param[in] sample_filter a device filter function that greenlights samples for a given query - */ -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) -{ - cuvs::neighbors::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); -} - -/** - * @brief Search ANN using the constructed index using the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // Create a pooling memory resource with a pre-defined initial size. 
- * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). 
- */ -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - cuvs::neighbors::ivf_flat::detail::search(handle, - params, - index, - queries, - n_queries, - k, - neighbors, - distances, - mr, - cuvs::neighbors::filtering::none_ivf_sample_filter()); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * filtering::none_ivf_sample_filter filter; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries1, out_inds1, out_dists1, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries2, out_inds2, out_dists2, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries3, out_inds3, out_dists3, filter); - * ... 
- * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or - * `(uint32_t query_ix, uint32 sample_ix) -> bool` - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a device filter function that greenlights samples for a given query - */ -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must be equal"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - search_with_filtering(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - static_cast(neighbors.extent(1)), - neighbors.data_handle(), - distances.data_handle(), - resource::get_workspace_resource(handle), - sample_filter); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. 
- * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); - * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); - * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); - * ... - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - */ -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - search_with_filtering(handle, - params, - index, - queries, - neighbors, - distances, - cuvs::neighbors::filtering::none_ivf_sample_filter()); -} - -/** @} */ - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/include/cuvs/neighbors/ivf_flat.cuh b/cpp/include/cuvs/neighbors/ivf_flat.cuh deleted file mode 100644 index 8fd9628a4..000000000 --- a/cpp/include/cuvs/neighbors/ivf_flat.cuh +++ /dev/null @@ -1,24 +0,0 
@@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "ivf_flat-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat-ext.cuh" -#endif diff --git a/cpp/include/cuvs/neighbors/ivf_flat_codepacker.hpp b/cpp/include/cuvs/neighbors/ivf_flat_codepacker.hpp deleted file mode 100644 index 9f1b43380..000000000 --- a/cpp/include/cuvs/neighbors/ivf_flat_codepacker.hpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace cuvs::neighbors::ivf_flat::codepacker { - -/** - * Write one flat code into a block by the given offset. The offset indicates the id of the record - * in the list. 
This function interleaves the code and is intended to later copy the interleaved - * codes over to the IVF list on device. NB: no memory allocation happens here; the block must fit - * the record (offset + 1). - * - * @tparam T - * - * @param[in] flat_code input flat code - * @param[out] block block of memory to write interleaved codes to - * @param[in] dim dimension of the flat code - * @param[in] veclen size of interleaved data chunks - * @param[in] offset how many records to skip before writing the data into the list - */ -template -_RAFT_HOST_DEVICE void pack_1( - const T* flat_code, T* block, uint32_t dim, uint32_t veclen, uint32_t offset) -{ - // The data is written in interleaved groups of `index::kGroupSize` vectors - using interleaved_group = neighbors::detail::div_utils; - - // Interleave dimensions of the source vector while recording it. - // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = interleaved_group::roundDown(offset); - auto ingroup_id = interleaved_group::mod(offset) * veclen; - - for (uint32_t l = 0; l < dim; l += veclen) { - for (uint32_t j = 0; j < veclen; j++) { - block[group_offset * dim + l * kIndexGroupSize + ingroup_id + j] = flat_code[l + j]; - } - } -} - -/** - * Unpack 1 record of a single list (cluster) in the index to fetch the flat code. The offset - * indicates the id of the record. This function fetches one flat code from an interleaved code. - * - * @tparam T - * - * @param[in] block interleaved block. The block can be thought of as the whole inverted list in - * interleaved format. 
- * @param[out] flat_code output flat code - * @param[in] dim dimension of the flat code - * @param[in] veclen size of interleaved data chunks - * @param[in] offset fetch the flat code by the given offset - */ -template -_RAFT_HOST_DEVICE void unpack_1( - const T* block, T* flat_code, uint32_t dim, uint32_t veclen, uint32_t offset) -{ - // The data is written in interleaved groups of `index::kGroupSize` vectors - using interleaved_group = neighbors::detail::div_utils; - - // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = interleaved_group::roundDown(offset); - auto ingroup_id = interleaved_group::mod(offset) * veclen; - - for (uint32_t l = 0; l < dim; l += veclen) { - for (uint32_t j = 0; j < veclen; j++) { - flat_code[l + j] = block[group_offset * dim + l * kIndexGroupSize + ingroup_id + j]; - } - } -} -} // namespace cuvs::neighbors::ivf_flat::codepacker \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/ivf_flat_helpers.cuh b/cpp/include/cuvs/neighbors/ivf_flat_helpers.cuh deleted file mode 100644 index cca83cea0..000000000 --- a/cpp/include/cuvs/neighbors/ivf_flat_helpers.cuh +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include - -#include -#include - -#include - -namespace cuvs::neighbors::ivf_flat::helpers { -using namespace cuvs::spatial::knn::detail; // NOLINT -/** - * @defgroup ivf_flat_helpers Helper functions for manipulationg IVF Flat Index - * @{ - */ - -namespace codepacker { - -/** - * Write flat codes into an existing list by the given offset. - * - * NB: no memory allocation happens here; the list must fit the data (offset + n_vec). - * - * Usage example: - * @code{.cpp} - * auto list_data = index.lists()[label]->data.view(); - * // allocate the buffer for the input codes - * auto codes = raft::make_device_matrix(res, n_vec, index.dim()); - * ... prepare n_vecs to pack into the list in codes ... - * // write codes into the list starting from the 42nd position - * ivf_pq::helpers::codepacker::pack( - * res, make_const_mdspan(codes.view()), index.veclen(), 42, list_data); - * @endcode - * - * @tparam T - * @tparam IdxT - * - * @param[in] res - * @param[in] codes flat codes [n_vec, dim] - * @param[in] veclen size of interleaved data chunks - * @param[in] offset how many records to skip before writing the data into the list - * @param[inout] list_data block to write into - */ -template -void pack( - raft::resources const& res, - raft::device_matrix_view codes, - uint32_t veclen, - uint32_t offset, - raft::device_mdspan::list_extents, raft::row_major> - list_data) -{ - cuvs::neighbors::ivf_flat::detail::pack_list_data(res, codes, veclen, offset, list_data); -} - -/** - * @brief Unpack `n_take` consecutive records of a single list (cluster) in the compressed index - * starting at given `offset`. 
- * - * Usage example: - * @code{.cpp} - * auto list_data = index.lists()[label]->data.view(); - * // allocate the buffer for the output - * uint32_t n_take = 4; - * auto codes = raft::make_device_matrix(res, n_take, index.dim()); - * uint32_t offset = 0; - * // unpack n_take elements from the list - * ivf_pq::helpers::codepacker::unpack(res, list_data, index.veclen(), offset, codes.view()); - * @endcode - * - * @tparam T - * @tparam IdxT - * - * @param[in] res raft resource - * @param[in] list_data block to read from - * @param[in] veclen size of interleaved data chunks - * @param[in] offset - * How many records in the list to skip. - * @param[inout] codes - * the destination buffer [n_take, index.dim()]. - * The length `n_take` defines how many records to unpack, - * it must be <= the list size. - */ -template -void unpack( - raft::resources const& res, - raft::device_mdspan::list_extents, raft::row_major> - list_data, - uint32_t veclen, - uint32_t offset, - raft::device_matrix_view codes) -{ - cuvs::neighbors::ivf_flat::detail::unpack_list_data( - res, list_data, veclen, offset, codes); -} -} // namespace codepacker - -/** - * @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for - * externally modifying the index without going through the build stage. The data and indices of the - * IVF lists will be lost. 
- * - * Usage example: - * @code{.cpp} - * raft::resources res; - * using namespace cuvs::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // initialize an empty index - * ivf_flat::index index(res, index_params, D); - * // reset the index's state and list sizes - * ivf_flat::helpers::reset_index(res, &index); - * @endcode - * - * @tparam IdxT - * - * @param[in] res raft resource - * @param[inout] index pointer to IVF-PQ index - */ -template -void reset_index(const raft::resources& res, index* index) -{ - auto stream = resource::get_cuda_stream(res); - - utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); - utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); - utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); -} -/** @} */ -} // namespace cuvs::neighbors::ivf_flat::helpers diff --git a/cpp/include/cuvs/neighbors/ivf_flat_serialize.cuh b/cpp/include/cuvs/neighbors/ivf_flat_serialize.cuh deleted file mode 100644 index 37062ea68..000000000 --- a/cpp/include/cuvs/neighbors/ivf_flat_serialize.cuh +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "detail/ivf_flat_serialize.cuh" - -namespace cuvs::neighbors::ivf_flat { - -/** - * \defgroup ivf_flat_serialize IVF-Flat Serialize - * @{ - */ - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = ivf_flat::build(...);` - * raft::serialize(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index IVF-Flat index - * - */ -template -void serialize(raft::resources const& handle, std::ostream& os, const index& index) -{ - detail::serialize(handle, os, index); -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = ivf_flat::build(...);` - * raft::serialize(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index IVF-Flat index - * - */ -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index) -{ - detail::serialize(handle, filename, index); -} - -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. 
- * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an input stream - * std::istream is(std::cin.rdbuf()); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, is); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - * @return cuvs::neighbors::ivf_flat::index - */ -template -index deserialize(raft::resources const& handle, std::istream& is) -{ - return detail::deserialize(handle, is); -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, filename); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - * @return cuvs::neighbors::ivf_flat::index - */ -template -index deserialize(raft::resources const& handle, const std::string& filename) -{ - return detail::deserialize(handle, filename); -} - -/**@}*/ - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/include/cuvs/neighbors/ivf_flat_types.hpp b/cpp/include/cuvs/neighbors/ivf_flat_types.hpp deleted file mode 100644 index 28023f474..000000000 --- a/cpp/include/cuvs/neighbors/ivf_flat_types.hpp +++ /dev/null @@ -1,406 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "ann_types.hpp" -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include // std::max -#include -#include -#include - -namespace cuvs::neighbors::ivf_flat { -/** - * @addtogroup ivf_flat - * @{ - */ - -/** Size of the interleaved group (see `index::data` description). */ -constexpr static uint32_t kIndexGroupSize = 32; - -struct index_params : ann::index_params { - /** The number of inverted lists (clusters) */ - uint32_t n_lists = 1024; - /** The number of iterations searching for kmeans centers (index building). */ - uint32_t kmeans_n_iters = 20; - /** The fraction of data to use during iterative kmeans building. */ - double kmeans_trainset_fraction = 0.5; - /** - * By default (adaptive_centers = false), the cluster centers are trained in `ivf_flat::build`, - * and never modified in `ivf_flat::extend`. As a result, you may need to retrain the index - * from scratch after invoking (`ivf_flat::extend`) a few times with new data, the distribution of - * which is no longer representative of the original training set. - * - * The alternative behavior (adaptive_centers = true) is to update the cluster centers for new - * data when it is added. In this case, `index.centers()` are always exactly the centroids of the - * data in the corresponding clusters. 
The drawback of this behavior is that the centroids depend - * on the order of adding new data (through the classification of the added data); that is, - * `index.centers()` "drift" together with the changing distribution of the newly added data. - */ - bool adaptive_centers = false; - /** - * By default, the algorithm allocates more space than necessary for individual clusters - * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of - * data copies during repeated calls to `extend` (extending the database). - * - * The alternative is the conservative allocation behavior; when enabled, the algorithm always - * allocates the minimum amount of memory required to store the given number of records. Set this - * flag to `true` if you prefer to use as little GPU memory for the database as possible. - */ - bool conservative_memory_allocation = false; -}; - -struct search_params : ann::search_params { - /** The number of clusters to search. */ - uint32_t n_probes = 20; -}; - -static_assert(std::is_aggregate_v); -static_assert(std::is_aggregate_v); - -template -struct list_spec { - using value_type = ValueT; - using list_extents = raft::matrix_extent; - using index_type = IdxT; - - SizeT align_max; - SizeT align_min; - uint32_t dim; - - constexpr list_spec(uint32_t dim, bool conservative_memory_allocation) - : dim(dim), - align_min(kIndexGroupSize), - align_max(conservative_memory_allocation ? kIndexGroupSize : 1024) - { - } - - // Allow casting between different size-types (for safer size and offset calculations) - template - constexpr explicit list_spec(const list_spec& other_spec) - : dim{other_spec.dim}, align_min{other_spec.align_min}, align_max{other_spec.align_max} - { - } - - /** Determine the extents of an array enough to hold a given amount of data. 
*/ - constexpr auto make_list_extents(SizeT n_rows) const -> list_extents - { - return make_extents(n_rows, dim); - } -}; - -template -using list_data = ivf::list; - -/** - * @brief IVF-flat index. - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - */ -template -struct index : ann::index { - static_assert(!raft::is_narrowing_v, - "IdxT must be able to represent all values of uint32_t"); - - public: - /** - * Vectorized load/store size in elements, determines the size of interleaved data chunks. - * - * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum - * possible value by padding the `dim` of the data https://github.com/rapidsai/raft/issues/711 - */ - [[nodiscard]] constexpr inline auto veclen() const noexcept -> uint32_t { return veclen_; } - /** Distance metric used for clustering. */ - [[nodiscard]] constexpr inline auto metric() const noexcept -> cuvs::distance::DistanceType - { - return metric_; - } - /** Whether `centers()` change upon extending the index (ivf_pq::extend). */ - [[nodiscard]] constexpr inline auto adaptive_centers() const noexcept -> bool - { - return adaptive_centers_; - } - /** - * Inverted list data [size, dim]. - * - * The data consists of the dataset rows, grouped by their labels (into clusters/lists). - * Within each list (cluster), the data is grouped into blocks of `kIndexGroupSize` interleaved - * vectors. Note, the total index length is slightly larger than the source dataset length, - * because each cluster is padded by `kIndexGroupSize` elements. - * - * Interleaving pattern: - * within groups of `kIndexGroupSize` rows, the data is interleaved with the block size equal to - * `veclen * sizeof(T)`. That is, a chunk of `veclen` consecutive components of one row is - * followed by a chunk of the same size of the next row, and so on. 
- * - * __Example__: veclen = 2, dim = 6, kIndexGroupSize = 32, list_size = 31 - * - * x[ 0, 0], x[ 0, 1], x[ 1, 0], x[ 1, 1], ... x[14, 0], x[14, 1], x[15, 0], x[15, 1], - * x[16, 0], x[16, 1], x[17, 0], x[17, 1], ... x[30, 0], x[30, 1], - , - , - * x[ 0, 2], x[ 0, 3], x[ 1, 2], x[ 1, 3], ... x[14, 2], x[14, 3], x[15, 2], x[15, 3], - * x[16, 2], x[16, 3], x[17, 2], x[17, 3], ... x[30, 2], x[30, 3], - , - , - * x[ 0, 4], x[ 0, 5], x[ 1, 4], x[ 1, 5], ... x[14, 4], x[14, 5], x[15, 4], x[15, 5], - * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - , - * - */ - /** Sizes of the lists (clusters) [n_lists] - * NB: This may differ from the actual list size if the shared lists have been extended by another - * index - */ - inline auto list_sizes() noexcept -> raft::device_vector_view - { - return list_sizes_.view(); - } - [[nodiscard]] inline auto list_sizes() const noexcept - -> raft::device_vector_view - { - return list_sizes_.view(); - } - - /** k-means cluster centers corresponding to the lists [n_lists, dim] */ - inline auto centers() noexcept -> raft::device_matrix_view - { - return centers_.view(); - } - [[nodiscard]] inline auto centers() const noexcept - -> raft::device_matrix_view - { - return centers_.view(); - } - - /** - * (Optional) Precomputed norms of the `centers` w.r.t. the chosen distance metric [n_lists]. - * - * NB: this may be empty if the index is empty or if the metric does not require the center norms - * calculation. - */ - inline auto center_norms() noexcept -> std::optional> - { - if (center_norms_.has_value()) { - return std::make_optional>(center_norms_->view()); - } else { - return std::nullopt; - } - } - [[nodiscard]] inline auto center_norms() const noexcept - -> std::optional> - { - if (center_norms_.has_value()) { - return std::make_optional>( - center_norms_->view()); - } else { - return std::nullopt; - } - } - - /** Total length of the index. 
*/ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return total_size_; } - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return centers_.extent(1); - } - /** Number of clusters/inverted lists. */ - [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t { return lists_.size(); } - - // Don't allow copying the index for performance reasons (try avoiding copying data) - index(const index&) = delete; - index(index&&) = default; - auto operator=(const index&) -> index& = delete; - auto operator=(index&&) -> index& = default; - ~index() = default; - - /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::resources const& res, - cuvs::distance::DistanceType metric, - uint32_t n_lists, - bool adaptive_centers, - bool conservative_memory_allocation, - uint32_t dim) - : ann::index(), - veclen_(calculate_veclen(dim)), - metric_(metric), - adaptive_centers_(adaptive_centers), - conservative_memory_allocation_{conservative_memory_allocation}, - centers_(raft::make_device_matrix(res, n_lists, dim)), - center_norms_(std::nullopt), - lists_{n_lists}, - list_sizes_{raft::make_device_vector(res, n_lists)}, - data_ptrs_{raft::make_device_vector(res, n_lists)}, - inds_ptrs_{raft::make_device_vector(res, n_lists)}, - total_size_{0} - { - check_consistency(); - } - - /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::resources const& res, const index_params& params, uint32_t dim) - : index(res, - params.metric, - params.n_lists, - params.adaptive_centers, - params.conservative_memory_allocation, - dim) - { - } - - /** Pointers to the inverted lists (clusters) data [n_lists]. 
*/ - inline auto data_ptrs() noexcept -> raft::device_vector_view - { - return data_ptrs_.view(); - } - [[nodiscard]] inline auto data_ptrs() const noexcept - -> raft::device_vector_view - { - return data_ptrs_.view(); - } - - /** Pointers to the inverted lists (clusters) indices [n_lists]. */ - inline auto inds_ptrs() noexcept -> raft::device_vector_view - { - return inds_ptrs_.view(); - } - [[nodiscard]] inline auto inds_ptrs() const noexcept - -> raft::device_vector_view - { - return inds_ptrs_.view(); - } - /** - * Whether to use convervative memory allocation when extending the list (cluster) data - * (see index_params.conservative_memory_allocation). - */ - [[nodiscard]] constexpr inline auto conservative_memory_allocation() const noexcept -> bool - { - return conservative_memory_allocation_; - } - - /** - * Update the state of the dependent index members. - */ - void recompute_internal_state(raft::resources const& res) - { - auto stream = resource::get_cuda_stream(res); - - // Actualize the list pointers - auto this_lists = lists(); - auto this_data_ptrs = data_ptrs(); - auto this_inds_ptrs = inds_ptrs(); - for (uint32_t label = 0; label < this_lists.size(); label++) { - auto& list = this_lists[label]; - const auto data_ptr = list ? list->data.data_handle() : nullptr; - const auto inds_ptr = list ? 
list->indices.data_handle() : nullptr; - copy(&this_data_ptrs(label), &data_ptr, 1, stream); - copy(&this_inds_ptrs(label), &inds_ptr, 1, stream); - } - auto this_list_sizes = list_sizes().data_handle(); - total_size_ = thrust::reduce(raft::resource::get_thrust_policy(res), - this_list_sizes, - this_list_sizes + this_lists.size(), - 0, - raft::add_op{}); - check_consistency(); - } - - void allocate_center_norms(raft::resources const& res) - { - switch (metric_) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: - case cuvs::distance::DistanceType::L2Unexpanded: - case cuvs::distance::DistanceType::L2SqrtUnexpanded: - center_norms_ = raft::make_device_vector(res, n_lists()); - break; - default: center_norms_ = std::nullopt; - } - } - - /** Lists' data and indices. */ - inline auto lists() noexcept -> std::vector>>& - { - return lists_; - } - [[nodiscard]] inline auto lists() const noexcept - -> const std::vector>>& - { - return lists_; - } - - private: - /** - * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum - * possible value by padding the `dim` of the data https://github.com/rapidsai/raft/issues/711 - */ - uint32_t veclen_; - cuvs::distance::DistanceType metric_; - bool adaptive_centers_; - bool conservative_memory_allocation_; - std::vector>> lists_; - raft::device_vector list_sizes_; - raft::device_matrix centers_; - std::optional> center_norms_; - - // Computed members - raft::device_vector data_ptrs_; - raft::device_vector inds_ptrs_; - IdxT total_size_; - - /** Throw an error if the index content is inconsistent. 
*/ - void check_consistency() - { - auto n_lists = lists_.size(); - RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen"); - RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS( // - (centers_.extent(0) == list_sizes_.extent(0)) && // - (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), - "inconsistent number of lists (clusters)"); - } - - static auto calculate_veclen(uint32_t dim) -> uint32_t - { - // TODO: consider padding the dimensions and fixing veclen to its maximum possible value as a - // template parameter (https://github.com/rapidsai/raft/issues/711) - - // NOTE: keep this consistent with the select_interleaved_scan_kernel logic - // in detail/ivf_flat_interleaved_scan-inl.cuh. - uint32_t veclen = std::max(1, 16 / sizeof(T)); - if (dim % veclen != 0) { veclen = 1; } - return veclen; - } -}; - -/** @} */ - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/include/cuvs/neighbors/ivf_list.hpp b/cpp/include/cuvs/neighbors/ivf_list.hpp deleted file mode 100644 index c395980de..000000000 --- a/cpp/include/cuvs/neighbors/ivf_list.hpp +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace cuvs::neighbors::ivf { - -/** The data for a single IVF list. */ -template