From b7d2f75b1c42047f37d283e4f348daec5538318c Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Fri, 10 Jan 2025 08:35:21 -0800 Subject: [PATCH] limits offset types for merge sort --- cub/cub/device/device_merge_sort.cuh | 25 ++++++++++------------- cub/test/catch2_test_device_merge_sort.cu | 1 - 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/cub/cub/device/device_merge_sort.cuh b/cub/cub/device/device_merge_sort.cuh index d42f6033a7e..b344fffab98 100644 --- a/cub/cub/device/device_merge_sort.cuh +++ b/cub/cub/device/device_merge_sort.cuh @@ -129,10 +129,10 @@ private: CompareOpT compare_op, cudaStream_t stream = 0) { - using PromotedOffsetT = detail::promote_small_offset_t; + using ChooseOffsetT = detail::choose_offset_t; using DispatchMergeSortT = - DispatchMergeSort; + DispatchMergeSort; return DispatchMergeSortT::Dispatch( d_temp_storage, temp_storage_bytes, d_keys, d_items, d_keys, d_items, num_items, compare_op, stream); @@ -393,10 +393,10 @@ public: cudaStream_t stream = 0) { CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName()); - using PromotedOffsetT = detail::promote_small_offset_t; + using ChooseOffsetT = detail::choose_offset_t; using DispatchMergeSortT = - DispatchMergeSort; + DispatchMergeSort; return DispatchMergeSortT::Dispatch( d_temp_storage, @@ -455,10 +455,10 @@ private: CompareOpT compare_op, cudaStream_t stream = 0) { - using PromotedOffsetT = detail::promote_small_offset_t; + using ChooseOffsetT = detail::choose_offset_t; using DispatchMergeSortT = - DispatchMergeSort; + DispatchMergeSort; return DispatchMergeSortT::Dispatch( d_temp_storage, @@ -599,10 +599,10 @@ private: CompareOpT compare_op, cudaStream_t stream = 0) { - using PromotedOffsetT = detail::promote_small_offset_t; + using ChooseOffsetT = detail::choose_offset_t; using DispatchMergeSortT = - DispatchMergeSort; + DispatchMergeSort; return DispatchMergeSortT::Dispatch( d_temp_storage, @@ -850,9 +850,8 @@ public: cudaStream_t stream = 0) { CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName()); - using PromotedOffsetT = detail::promote_small_offset_t; - return SortPairsNoNVTX( + return SortPairsNoNVTX( d_temp_storage, temp_storage_bytes, d_keys, d_items, num_items, compare_op, stream); } @@ -969,9 +968,8 @@ public: cudaStream_t stream = 0) { CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName()); - using PromotedOffsetT = detail::promote_small_offset_t; - return SortKeysNoNVTX( + return SortKeysNoNVTX( d_temp_storage, temp_storage_bytes, d_keys, num_items, compare_op, stream); } @@ -1101,8 +1099,7 @@ public: cudaStream_t stream = 0) { CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName()); - using PromotedOffsetT = detail::promote_small_offset_t; - return SortKeysCopyNoNVTX( + return SortKeysCopyNoNVTX( d_temp_storage, temp_storage_bytes, d_input_keys, d_output_keys, num_items, compare_op, stream); } }; diff --git a/cub/test/catch2_test_device_merge_sort.cu b/cub/test/catch2_test_device_merge_sort.cu index 86fa2331fee..f8ec150b845 100644 --- a/cub/test/catch2_test_device_merge_sort.cu +++ b/cub/test/catch2_test_device_merge_sort.cu @@ -78,7 +78,6 @@ struct type_tuple }; using offset_types = c2h::type_list, - type_tuple, type_tuple, type_tuple, type_tuple>;