From 5c826d7320486852c30a18f6e039d0cda83c5c62 Mon Sep 17 00:00:00 2001 From: Micka Date: Tue, 14 Jan 2025 00:34:20 +0100 Subject: [PATCH] Add support for different data type of bitset (#2535) This PR is useful for Milvus. Previously the `bitset_view` object only supported the data type used to create the bitset. With the proposed modifications, a `bitset_view` object can support any data type used to create the bitset by specifying the `original_nbits` parameter in the class constructor. Authors: - Micka (https://github.com/lowener) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - rhdong (https://github.com/rhdong) URL: https://github.com/rapidsai/raft/pull/2535 --- cpp/include/raft/core/bitmap.hpp | 24 ++++++-- cpp/include/raft/core/bitset.cuh | 53 ++++++++++++++--- cpp/include/raft/core/bitset.hpp | 34 +++++++++-- cpp/test/core/bitset.cu | 98 ++++++++++++++++++++++++++++++-- 4 files changed, 188 insertions(+), 21 deletions(-) diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp index 86b2d77478..5a6656f572 100644 --- a/cpp/include/raft/core/bitmap.hpp +++ b/cpp/include/raft/core/bitmap.hpp @@ -53,9 +53,18 @@ struct bitmap_view : public bitset_view { * @param bitmap_ptr Device raw pointer * @param rows Number of row in the matrix. * @param cols Number of col in the matrix. + * @param original_nbits Original number of bits used when the bitmap was created, to handle + * potential mismatches of data types. This is useful for using ANN indexes when a bitmap was + * originally created with a different data type than the ones currently supported in cuVS ANN + * indexes. */ - _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) - : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, + index_t rows, + index_t cols, + index_t original_nbits = 0) + : bitset_view(bitmap_ptr, rows * cols, original_nbits), + rows_(rows), + cols_(cols) { } @@ -65,11 +74,18 @@ struct bitmap_view : public bitset_view { * @param bitmap_span Device vector view of the bitmap * @param rows Number of row in the matrix. * @param cols Number of col in the matrix. + * @param original_nbits Original number of bits used when the bitmap was created, to handle + * potential mismatches of data types. This is useful for using ANN indexes when a bitmap was + * originally created with a different data type than the ones currently supported in cuVS ANN + * indexes. */ _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, index_t rows, - index_t cols) - : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + index_t cols, + index_t original_nbits = 0) + : bitset_view(bitmap_span, rows * cols, original_nbits), + rows_(rows), + cols_(cols) { } diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index d1bffdb81e..feaef1a172 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -32,12 +32,41 @@ namespace raft::core { +template +_RAFT_HOST_DEVICE void inline compute_original_nbits_position(const index_t original_nbits, + const index_t nbits, + const index_t sample_index, + index_t& new_bit_index, + index_t& new_bit_offset) +{ + const index_t original_bit_index = sample_index / original_nbits; + const index_t original_bit_offset = sample_index % original_nbits; + new_bit_index = original_bit_index * original_nbits / nbits; + new_bit_offset = 0; + if (original_nbits > nbits) { + new_bit_index += original_bit_offset / nbits; + new_bit_offset = original_bit_offset % nbits; + } else { + index_t ratio = nbits / original_nbits; + new_bit_offset += (original_bit_index % ratio) * original_nbits; + new_bit_offset += original_bit_offset % nbits; + } +} + template _RAFT_HOST_DEVICE inline bool bitset_view::test(const index_t sample_index) const { - const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size]; - const index_t bit_index = sample_index % bitset_element_size; - const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; + const index_t nbits = sizeof(bitset_t) * 8; + index_t bit_index = 0; + index_t bit_offset = 0; + if (original_nbits_ == 0 || nbits == original_nbits_) { + bit_index = sample_index / bitset_element_size; + bit_offset = sample_index % bitset_element_size; + } else { + compute_original_nbits_position(original_nbits_, nbits, sample_index, bit_index, bit_offset); + } + const bitset_t bit_element = bitset_ptr_[bit_index]; + const bool is_bit_set = (bit_element & (bitset_t{1} << bit_offset)) != 0; return is_bit_set; } @@ -51,14 +80,22 @@ template _RAFT_DEVICE void bitset_view::set(const index_t sample_index, bool set_value) const { - const index_t bit_element = sample_index / bitset_element_size; - const index_t bit_index = sample_index % bitset_element_size; - const bitset_t bitmask = bitset_t{1} << bit_index; + const index_t nbits = sizeof(bitset_t) * 8; + index_t bit_index = 0; + index_t bit_offset = 0; + + if (original_nbits_ == 0 || nbits == original_nbits_) { + bit_index = sample_index / bitset_element_size; + bit_offset = sample_index % bitset_element_size; + } else { + compute_original_nbits_position(original_nbits_, nbits, sample_index, bit_index, bit_offset); + } + const bitset_t bitmask = bitset_t{1} << bit_offset; if (set_value) { - atomicOr(bitset_ptr_ + bit_element, bitmask); + atomicOr(bitset_ptr_ + bit_index, bitmask); } else { const bitset_t bitmask2 = ~bitmask; - atomicAnd(bitset_ptr_ + bit_element, bitmask2); + atomicAnd(bitset_ptr_ + bit_index, bitmask2); } } diff --git a/cpp/include/raft/core/bitset.hpp b/cpp/include/raft/core/bitset.hpp index be828def87..e4bea2c0c5 100644 --- a/cpp/include/raft/core/bitset.hpp +++ b/cpp/include/raft/core/bitset.hpp @@ -42,8 +42,20 @@ template struct bitset_view { static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; - _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len) - : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} + /** + * @brief Create a bitset view from a device pointer to the bitset. + * + * @param bitset_ptr Device pointer to the bitset + * @param bitset_len Number of bits in the bitset + * @param original_nbits Original number of bits used when the bitset was created, to handle + * potential mismatches of data types. This is useful for using ANN indexes when a bitset was + * originally created with a different data type than the ones currently supported in cuVS ANN + * indexes. + */ + _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, + index_t bitset_len, + index_t original_nbits = 0) + : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}, original_nbits_{original_nbits} { } /** @@ -51,10 +63,17 @@ struct bitset_view { * * @param bitset_span Device vector view of the bitset * @param bitset_len Number of bits in the bitset + * @param original_nbits Original number of bits used when the bitset was created, to handle + * potential mismatches of data types. This is useful for using ANN indexes when a bitset was + * originally created with a different data type than the ones currently supported in cuVS ANN + * indexes. */ _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span, - index_t bitset_len) - : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len} + index_t bitset_len, + index_t original_nbits = 0) + : bitset_ptr_{bitset_span.data_handle()}, + bitset_len_{bitset_len}, + original_nbits_{original_nbits} { } /** @@ -180,9 +199,16 @@ struct bitset_view { return (bitset_len + bits_per_element - 1) / bits_per_element; } + /** + * @brief Get the original number of bits of the bitset. + */ + auto get_original_nbits() const -> index_t { return original_nbits_; } + void set_original_nbits(index_t original_nbits) { original_nbits_ = original_nbits; } + private: bitset_t* bitset_ptr_; index_t bitset_len_; + index_t original_nbits_; }; /** diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index ac601274c1..f094f60ded 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -24,6 +24,8 @@ #include #include +#include +#include #include namespace raft::core { @@ -73,6 +75,40 @@ void test_cpu_bitset(const std::vector& bitset, } } +template +void test_cpu_bitset_nbits(const bitset_t* bitset, + const std::vector& queries, + std::vector& result, + unsigned original_nbits_) +{ + constexpr size_t nbits = sizeof(bitset_t) * 8; + if (original_nbits_ == nbits) { + for (size_t i = 0; i < queries.size(); i++) { + result[i] = + uint8_t((bitset[queries[i] / nbits] & (bitset_t{1} << (queries[i] % nbits))) != 0); + } + } + for (size_t i = 0; i < queries.size(); i++) { + const index_t sample_index = queries[i]; + const index_t original_bit_index = sample_index / original_nbits_; + const index_t original_bit_offset = sample_index % original_nbits_; + index_t new_bit_index = original_bit_index * original_nbits_ / nbits; + index_t new_bit_offset = 0; + if (original_nbits_ > nbits) { + new_bit_index += original_bit_offset / nbits; + new_bit_offset = original_bit_offset % nbits; + } else { + index_t ratio = nbits / original_nbits_; + new_bit_offset += (original_bit_index % ratio) * original_nbits_; + new_bit_offset += original_bit_offset % nbits; + } + const bitset_t bit_element = bitset[new_bit_index]; + const bool is_bit_set = (bit_element & (bitset_t{1} << new_bit_offset)) != 0; + + result[i] = uint8_t(is_bit_set); + } +} + template void flip_cpu_bitset(std::vector& bitset) { @@ -168,11 +204,12 @@ class BitsetTest : public testing::TestWithParam { resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); - auto query_device = raft::make_device_vector(res, spec.query_len); - auto result_device = raft::make_device_vector(res, spec.query_len); - auto query_cpu = std::vector(spec.query_len); - auto result_cpu = std::vector(spec.query_len); - auto result_ref = std::vector(spec.query_len); + auto query_device = raft::make_device_vector(res, spec.query_len); + auto result_device = raft::make_device_vector(res, spec.query_len); + auto query_cpu = std::vector(spec.query_len); + auto result_cpu = std::vector(spec.query_len); + auto result_ref_nbits = std::vector(spec.query_len); + auto result_ref = std::vector(spec.query_len); // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); @@ -194,6 +231,57 @@ class BitsetTest : public testing::TestWithParam { resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + // Reinterpret the bitset as uint8_t, uint32 then uint64_t + { + // Test CPU logic + test_cpu_bitset(bitset_ref, query_cpu, result_ref); + uint8_t* bitset_cpu_uint8 = (uint8_t*)std::malloc(sizeof(bitset_t) * bitset_ref.size()); + std::memcpy(bitset_cpu_uint8, bitset_ref.data(), sizeof(bitset_t) * bitset_ref.size()); + test_cpu_bitset_nbits(bitset_cpu_uint8, query_cpu, result_ref_nbits, sizeof(bitset_t) * 8); + ASSERT_TRUE(hostVecMatch(result_ref, result_ref_nbits, raft::Compare())); + std::free(bitset_cpu_uint8); + + // Test GPU uint8_t, uint32_t, uint64_t + auto my_bitset_view_uint8_t = raft::core::bitset_view( + reinterpret_cast(my_bitset.data()), my_bitset.size(), sizeof(bitset_t) * 8); + raft::linalg::map( + res, + result_device.view(), + [my_bitset_view_uint8_t] __device__(index_t query) { + return my_bitset_view_uint8_t.test(query); + }, + raft::make_const_mdspan(query_device.view())); + update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(result_ref, result_cpu, Compare())); + + auto my_bitset_view_uint32_t = raft::core::bitset_view( + reinterpret_cast(my_bitset.data()), my_bitset.size(), sizeof(bitset_t) * 8); + raft::linalg::map( + res, + result_device.view(), + [my_bitset_view_uint32_t] __device__(index_t query) { + return my_bitset_view_uint32_t.test(query); + }, + raft::make_const_mdspan(query_device.view())); + update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(result_ref, result_cpu, Compare())); + + auto my_bitset_view_uint64_t = raft::core::bitset_view( + reinterpret_cast(my_bitset.data()), my_bitset.size(), sizeof(bitset_t) * 8); + raft::linalg::map( + res, + result_device.view(), + [my_bitset_view_uint64_t] __device__(index_t query) { + return my_bitset_view_uint64_t.test(query); + }, + raft::make_const_mdspan(query_device.view())); + update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(result_ref, result_cpu, Compare())); + } + // test sparsity, repeat and eval_n_elements { auto my_bitset_view = my_bitset.view();