Skip to content

Commit

Permalink
limits offset types for merge sort (#3328)
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle authored Jan 11, 2025
1 parent 9a04941 commit cc7c1bb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
25 changes: 11 additions & 14 deletions cub/cub/device/device_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ private:
CompareOpT compare_op,
cudaStream_t stream = 0)
{
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;
using ChooseOffsetT = detail::choose_offset_t<OffsetT>;

using DispatchMergeSortT =
DispatchMergeSort<KeyIteratorT, ValueIteratorT, KeyIteratorT, ValueIteratorT, PromotedOffsetT, CompareOpT>;
DispatchMergeSort<KeyIteratorT, ValueIteratorT, KeyIteratorT, ValueIteratorT, ChooseOffsetT, CompareOpT>;

return DispatchMergeSortT::Dispatch(
d_temp_storage, temp_storage_bytes, d_keys, d_items, d_keys, d_items, num_items, compare_op, stream);
Expand Down Expand Up @@ -374,10 +374,10 @@ public:
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName());
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;
using ChooseOffsetT = detail::choose_offset_t<OffsetT>;

using DispatchMergeSortT =
DispatchMergeSort<KeyInputIteratorT, ValueInputIteratorT, KeyIteratorT, ValueIteratorT, PromotedOffsetT, CompareOpT>;
DispatchMergeSort<KeyInputIteratorT, ValueInputIteratorT, KeyIteratorT, ValueIteratorT, ChooseOffsetT, CompareOpT>;

return DispatchMergeSortT::Dispatch(
d_temp_storage,
Expand All @@ -402,10 +402,10 @@ private:
CompareOpT compare_op,
cudaStream_t stream = 0)
{
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;
using ChooseOffsetT = detail::choose_offset_t<OffsetT>;

using DispatchMergeSortT =
DispatchMergeSort<KeyIteratorT, NullType*, KeyIteratorT, NullType*, PromotedOffsetT, CompareOpT>;
DispatchMergeSort<KeyIteratorT, NullType*, KeyIteratorT, NullType*, ChooseOffsetT, CompareOpT>;

return DispatchMergeSortT::Dispatch(
d_temp_storage,
Expand Down Expand Up @@ -528,10 +528,10 @@ private:
CompareOpT compare_op,
cudaStream_t stream = 0)
{
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;
using ChooseOffsetT = detail::choose_offset_t<OffsetT>;

using DispatchMergeSortT =
DispatchMergeSort<KeyInputIteratorT, NullType*, KeyIteratorT, NullType*, PromotedOffsetT, CompareOpT>;
DispatchMergeSort<KeyInputIteratorT, NullType*, KeyIteratorT, NullType*, ChooseOffsetT, CompareOpT>;

return DispatchMergeSortT::Dispatch(
d_temp_storage,
Expand Down Expand Up @@ -760,9 +760,8 @@ public:
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName());
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

return SortPairsNoNVTX<KeyIteratorT, ValueIteratorT, PromotedOffsetT, CompareOpT>(
return SortPairsNoNVTX<KeyIteratorT, ValueIteratorT, OffsetT, CompareOpT>(
d_temp_storage, temp_storage_bytes, d_keys, d_items, num_items, compare_op, stream);
}

Expand Down Expand Up @@ -860,9 +859,8 @@ public:
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName());
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

return SortKeysNoNVTX<KeyIteratorT, PromotedOffsetT, CompareOpT>(
return SortKeysNoNVTX<KeyIteratorT, OffsetT, CompareOpT>(
d_temp_storage, temp_storage_bytes, d_keys, num_items, compare_op, stream);
}

Expand Down Expand Up @@ -974,8 +972,7 @@ public:
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, GetName());
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;
return SortKeysCopyNoNVTX<KeyInputIteratorT, KeyIteratorT, PromotedOffsetT, CompareOpT>(
return SortKeysCopyNoNVTX<KeyInputIteratorT, KeyIteratorT, OffsetT, CompareOpT>(
d_temp_storage, temp_storage_bytes, d_input_keys, d_output_keys, num_items, compare_op, stream);
}
};
Expand Down
1 change: 0 additions & 1 deletion cub/test/catch2_test_device_merge_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ struct type_tuple
};
using offset_types =
c2h::type_list<type_tuple<std::int16_t>,
type_tuple<std::int32_t>,
type_tuple<std::int32_t, std::uint32_t>,
type_tuple<std::uint32_t>,
type_tuple<std::uint64_t>>;
Expand Down

0 comments on commit cc7c1bb

Please sign in to comment.