Skip to content

Commit

Permalink
trying value type instead of ref type
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Jan 11, 2025
1 parent 0d89862 commit f5a8fbc
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions cub/cub/device/dispatch/dispatch_segmented_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,38 +80,47 @@ using per_invocation_segment_offset_t = int;
// Type used for total number of segments and to index within segments globally
using global_segment_offset_t = int;

template <typename T>
struct DereferenceHelper {
using type = decltype(*std::declval<T>().base());
};

template <typename Iterator, typename OffsetItT>
class OffsetIteratorT
: public THRUST_NS_QUALIFIER::iterator_adaptor<OffsetIteratorT<Iterator, OffsetItT>,
Iterator,
THRUST_NS_QUALIFIER::use_default,
THRUST_NS_QUALIFIER::any_system_tag,
THRUST_NS_QUALIFIER::random_access_traversal_tag>
THRUST_NS_QUALIFIER::random_access_traversal_tag,
typename ::cuda::std::iterator_traits<Iterator>::value_type>
{
public:
using super_t =
THRUST_NS_QUALIFIER::iterator_adaptor<OffsetIteratorT<Iterator, OffsetItT>,
Iterator,
THRUST_NS_QUALIFIER::use_default,
THRUST_NS_QUALIFIER::any_system_tag,
THRUST_NS_QUALIFIER::random_access_traversal_tag>;
THRUST_NS_QUALIFIER::random_access_traversal_tag,
typename ::cuda::std::iterator_traits<Iterator>::value_type>;

OffsetIteratorT() = default;

_CCCL_HOST_DEVICE OffsetIteratorT(const Iterator& it, OffsetItT offset_it)
: super_t(it)
, offset_it(offset_it)
, it(it)
{}

// befriend thrust::iterator_core_access to allow it access to the private interface below
friend class THRUST_NS_QUALIFIER::iterator_core_access;

private:
OffsetItT offset_it;
Iterator it;

_CCCL_HOST_DEVICE _CCCL_FORCEINLINE typename super_t::reference dereference() const
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE typename ::cuda::std::iterator_traits<Iterator>::value_type dereference() const
{
return *(this->base() + static_cast<typename super_t::difference_type>(*offset_it));
return *(it + static_cast<typename super_t::difference_type>(*offset_it));
}
};

Expand Down

0 comments on commit f5a8fbc

Please sign in to comment.