diff --git a/cub/cub/device/device_segmented_sort.cuh b/cub/cub/device/device_segmented_sort.cuh index 1fb5656b82f..ebf16788ab2 100644 --- a/cub/cub/device/device_segmented_sort.cuh +++ b/cub/cub/device/device_segmented_sort.cuh @@ -45,8 +45,41 @@ #include #include +#include + CUB_NAMESPACE_BEGIN +template +class OffsetIteratorT : public THRUST_NS_QUALIFIER::iterator_adaptor, Iterator> +{ +public: + using super_t = THRUST_NS_QUALIFIER::iterator_adaptor, Iterator>; + + OffsetIteratorT() = default; + + _CCCL_HOST_DEVICE OffsetIteratorT(const Iterator& it, OffsetItT offset_it) + : super_t(it) + , offset_it(offset_it) + {} + + // befriend thrust::iterator_core_access to allow it access to the private interface below + friend class THRUST_NS_QUALIFIER::iterator_core_access; + +private: + OffsetItT offset_it; + + _CCCL_HOST_DEVICE typename super_t::reference dereference() const + { + return *(this->base() + (*offset_it)); + } +}; + +template +_CCCL_HOST_DEVICE OffsetIteratorT make_offset_iterator(const Iterator& it, OffsetItT offset_it) +{ + return OffsetIteratorT{it, offset_it}; +} + //! @rst //! DeviceSegmentedSort provides device-wide, parallel operations for //! computing a batched sort across multiple, non-overlapping sequences of @@ -124,6 +157,14 @@ CUB_NAMESPACE_BEGIN //! // d_values_out <-- [1, 2, 0, 5, 4, 3, 6] //! //! @endrst + +template +__global__ void TestKernel(EndOffsetIteratorT m_iterator, int n) +{ + EndOffsetIteratorT m_iterator2 = static_cast(m_iterator + n); + printf("Offset: %ld\n", static_cast(*m_iterator2)); +} + struct DeviceSegmentedSort { private: @@ -138,33 +179,26 @@ private: CUB_RUNTIME_FUNCTION static cudaError_t SortKeysNoNVTX( void* d_temp_storage, std::size_t& temp_storage_bytes, - const KeyT* d_keys_in, - KeyT* d_keys_out, + const KeyT*, + KeyT*, int num_items, int num_segments, - BeginOffsetIteratorT d_begin_offsets, + BeginOffsetIteratorT, EndOffsetIteratorT d_end_offsets, cudaStream_t stream = 0) { - constexpr bool is_descending = false; - constexpr bool is_overwrite_okay = false; - using DispatchT = - DispatchSegmentedSort; - - DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); - DoubleBuffer d_values; + if (d_temp_storage == nullptr) + { + temp_storage_bytes = 1; + return cudaSuccess; + } + if(num_items == 0 || num_segments == 0) + { + return cudaSuccess; + } - return DispatchT::Dispatch( - d_temp_storage, - temp_storage_bytes, - d_keys, - d_values, - num_items, - num_segments, - d_begin_offsets, - d_end_offsets, - is_overwrite_okay, - stream); + TestKernel<<<1, 1, 0, stream>>>(d_end_offsets, 0); + return cudaSuccess; } public: @@ -321,8 +355,8 @@ private: { constexpr bool is_descending = true; constexpr bool is_overwrite_okay = false; - using DispatchT = - DispatchSegmentedSort; + using offset_it_t = OffsetIteratorT>; + using DispatchT = DispatchSegmentedSort; DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; @@ -335,7 +369,7 @@ private: num_items, num_segments, d_begin_offsets, - d_end_offsets, + {d_end_offsets, thrust::make_constant_iterator(0)}, is_overwrite_okay, stream); } @@ -488,9 +522,9 @@ private: { constexpr bool is_descending = false; constexpr bool is_overwrite_okay = true; + using offset_it_t = OffsetIteratorT>; - using DispatchT = - DispatchSegmentedSort; + using DispatchT = DispatchSegmentedSort; DoubleBuffer d_values; @@ -502,7 +536,7 @@ private: num_items, num_segments, d_begin_offsets, - d_end_offsets, + {d_end_offsets, thrust::make_constant_iterator(0)}, is_overwrite_okay, stream); } @@ -658,9 +692,9 @@ private: { constexpr bool is_descending = true; constexpr bool is_overwrite_okay = true; + using offset_it_t = OffsetIteratorT>; - using DispatchT = - DispatchSegmentedSort; + using DispatchT = DispatchSegmentedSort; DoubleBuffer d_values; @@ -672,7 +706,7 @@ private: num_items, num_segments, d_begin_offsets, - d_end_offsets, + {d_end_offsets, thrust::make_constant_iterator(0)}, is_overwrite_okay, stream); } @@ -1379,7 +1413,8 @@ private: { constexpr bool is_descending = false; constexpr bool is_overwrite_okay = false; - using DispatchT = DispatchSegmentedSort; + using offset_it_t = OffsetIteratorT>; + using DispatchT = DispatchSegmentedSort; DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); @@ -1392,7 +1427,7 @@ private: num_items, num_segments, d_begin_offsets, - d_end_offsets, + {d_end_offsets, thrust::make_constant_iterator(0)}, is_overwrite_okay, stream); } @@ -1578,7 +1613,8 @@ private: { constexpr bool is_descending = true; constexpr bool is_overwrite_okay = false; - using DispatchT = DispatchSegmentedSort; + using offset_it_t = OffsetIteratorT>; + using DispatchT = DispatchSegmentedSort; DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); @@ -1591,7 +1627,7 @@ private: num_items, num_segments, d_begin_offsets, - d_end_offsets, + {d_end_offsets, thrust::make_constant_iterator(0)}, is_overwrite_okay, stream); } @@ -1771,7 +1807,8 @@ private: { constexpr bool is_descending = false; constexpr bool is_overwrite_okay = true; - using DispatchT = DispatchSegmentedSort; + using offset_it_t = OffsetIteratorT>; + using DispatchT = DispatchSegmentedSort; return DispatchT::Dispatch( d_temp_storage, @@ -1781,7 +1818,7 @@ private: num_items, num_segments, d_begin_offsets, - d_end_offsets, + {d_end_offsets, thrust::make_constant_iterator(0)}, is_overwrite_okay, stream); } @@ -1966,7 +2003,8 @@ private: { constexpr bool is_descending = true; constexpr bool is_overwrite_okay = true; - using DispatchT = DispatchSegmentedSort; + using offset_it_t = OffsetIteratorT>; + using DispatchT = DispatchSegmentedSort; return DispatchT::Dispatch( d_temp_storage, @@ -1976,7 +2014,7 @@ private: num_items, num_segments, d_begin_offsets, - d_end_offsets, + {d_end_offsets, thrust::make_constant_iterator(0)}, is_overwrite_okay, stream); }