Skip to content

Commit

Permalink
Replace cub::Traits by numeric_limits and deprecate
Browse files Browse the repository at this point in the history
Fixes: #3381
  • Loading branch information
bernhardmgruber committed Jan 22, 2025
1 parent b1422c0 commit d61de28
Show file tree
Hide file tree
Showing 23 changed files with 515 additions and 446 deletions.
27 changes: 2 additions & 25 deletions c2h/generators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#include <thrust/scan.h>
#include <thrust/tabulate.h>

#include <cuda/std/type_traits>
#include <cuda/type_traits>

#include <cstdint>

Expand Down Expand Up @@ -118,30 +118,7 @@ private:
c2h::device_vector<float> m_distribution;
};

// TODO(bgruber): modelled after cub::Traits. We should generalize this somewhere into libcu++.
template <typename T>
struct is_floating_point : ::cuda::std::is_floating_point<T>
{};
#ifdef _CCCL_HAS_NVFP16
template <>
struct is_floating_point<__half> : ::cuda::std::true_type
{};
#endif // _CCCL_HAS_NVFP16
#ifdef _CCCL_HAS_NVBF16
template <>
struct is_floating_point<__nv_bfloat16> : ::cuda::std::true_type
{};
#endif // _CCCL_HAS_NVBF16
#ifdef __CUDA_FP8_TYPES_EXIST__
template <>
struct is_floating_point<__nv_fp8_e4m3> : ::cuda::std::true_type
{};
template <>
struct is_floating_point<__nv_fp8_e5m2> : ::cuda::std::true_type
{};
#endif // __CUDA_FP8_TYPES_EXIST__

template <typename T, bool = is_floating_point<T>::value>
template <typename T, bool = ::cuda::is_floating_point_v<T>>
struct random_to_item_t
{
float m_min;
Expand Down
26 changes: 15 additions & 11 deletions c2h/include/c2h/bfloat16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ struct bfloat16_t
}
};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif

/******************************************************************************
* I/O stream overloads
******************************************************************************/
Expand All @@ -229,28 +233,28 @@ inline std::ostream& operator<<(std::ostream& out, const __nv_bfloat16& x)
}

/******************************************************************************
* Traits overloads
* limits
******************************************************************************/

_LIBCUDACXX_BEGIN_NAMESPACE_STD
template <>
struct CUB_NS_QUALIFIER::FpLimits<bfloat16_t>
class numeric_limits<bfloat16_t>
{
static __host__ __device__ __forceinline__ bfloat16_t Max()
public:
static __host__ __device__ __forceinline__ bfloat16_t max()
{
return bfloat16_t::max();
}

static __host__ __device__ __forceinline__ bfloat16_t Lowest()
static __host__ __device__ __forceinline__ bfloat16_t lowest()
{
return bfloat16_t::lowest();
}
};
_LIBCUDACXX_END_NAMESPACE_STD

template <>
struct CUB_NS_QUALIFIER::NumericTraits<bfloat16_t>
: CUB_NS_QUALIFIER::BaseTraits<FLOATING_POINT, true, false, unsigned short, bfloat16_t>
{};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif
struct CUB_NS_QUALIFIER::detail::unsigned_bits<bfloat16_t, void>
{
using type = unsigned short;
};
11 changes: 5 additions & 6 deletions c2h/include/c2h/custom_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,34 +178,33 @@ class accumulateable_t

} // namespace c2h

namespace std
{
_LIBCUDACXX_BEGIN_NAMESPACE_STD
template <template <typename> class... Policies>
class numeric_limits<c2h::custom_type_t<Policies...>>
{
public:
static c2h::custom_type_t<Policies...> max()
static __host__ __device__ c2h::custom_type_t<Policies...> max()
{
c2h::custom_type_t<Policies...> val;
val.key = std::numeric_limits<std::size_t>::max();
val.val = std::numeric_limits<std::size_t>::max();
return val;
}

static c2h::custom_type_t<Policies...> min()
static __host__ __device__ c2h::custom_type_t<Policies...> min()
{
c2h::custom_type_t<Policies...> val;
val.key = std::numeric_limits<std::size_t>::min();
val.val = std::numeric_limits<std::size_t>::min();
return val;
}

static c2h::custom_type_t<Policies...> lowest()
static __host__ __device__ c2h::custom_type_t<Policies...> lowest()
{
c2h::custom_type_t<Policies...> val;
val.key = std::numeric_limits<std::size_t>::lowest();
val.val = std::numeric_limits<std::size_t>::lowest();
return val;
}
};
} // namespace std
_LIBCUDACXX_END_NAMESPACE_STD
45 changes: 5 additions & 40 deletions c2h/include/c2h/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

#include <thrust/detail/config/device_system.h>

#include <limits>
#include <cuda/std/limits>

#include <c2h/custom_type.h>
#include <c2h/vector.h>
Expand All @@ -52,41 +52,6 @@ _CCCL_DIAG_PUSH
_CCCL_DIAG_POP
# endif // _CCCL_CUDACC_AT_LEAST(11, 8)
# endif // _CCCL_HAS_NVBF16

# if defined(__CUDA_FP8_TYPES_EXIST__)
namespace std
{
template <>
class numeric_limits<__nv_fp8_e4m3>
{
public:
static __nv_fp8_e4m3 max()
{
return cub::Traits<__nv_fp8_e4m3>::Max();
}

static __nv_fp8_e4m3 lowest()
{
return cub::Traits<__nv_fp8_e4m3>::Lowest();
}
};

template <>
class numeric_limits<__nv_fp8_e5m2>
{
public:
static __nv_fp8_e5m2 max()
{
return cub::Traits<__nv_fp8_e5m2>::Max();
}

static __nv_fp8_e5m2 lowest()
{
return cub::Traits<__nv_fp8_e5m2>::Lowest();
}
};
} // namespace std
# endif // defined(__CUDA_FP8_TYPES_EXIST__)
#endif // THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA

namespace c2h
Expand Down Expand Up @@ -157,8 +122,8 @@ void init_key_segments(const c2h::device_vector<OffsetT>& segment_offsets, KeyT*
template <template <typename> class... Ps>
void gen(seed_t seed,
c2h::device_vector<c2h::custom_type_t<Ps...>>& data,
c2h::custom_type_t<Ps...> min = std::numeric_limits<c2h::custom_type_t<Ps...>>::lowest(),
c2h::custom_type_t<Ps...> max = std::numeric_limits<c2h::custom_type_t<Ps...>>::max())
c2h::custom_type_t<Ps...> min = ::cuda::std::numeric_limits<c2h::custom_type_t<Ps...>>::lowest(),
c2h::custom_type_t<Ps...> max = ::cuda::std::numeric_limits<c2h::custom_type_t<Ps...>>::max())
{
detail::gen(seed,
reinterpret_cast<char*>(thrust::raw_pointer_cast(data.data())),
Expand All @@ -171,8 +136,8 @@ void gen(seed_t seed,
template <typename T>
void gen(seed_t seed,
c2h::device_vector<T>& data,
T min = std::numeric_limits<T>::lowest(),
T max = std::numeric_limits<T>::max());
T min = ::cuda::std::numeric_limits<T>::lowest(),
T max = ::cuda::std::numeric_limits<T>::max());

template <typename T>
void gen(modulo_t mod, c2h::device_vector<T>& data);
Expand Down
30 changes: 20 additions & 10 deletions c2h/include/c2h/half.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include <cub/util_type.cuh>

#include <cuda/std/limits>
#include <cuda/std/type_traits>

#include <cstdint>
Expand Down Expand Up @@ -306,6 +307,10 @@ struct half_t
}
};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif

/******************************************************************************
* I/O stream overloads
******************************************************************************/
Expand All @@ -324,28 +329,33 @@ inline std::ostream& operator<<(std::ostream& out, const __half& x)
}

/******************************************************************************
* Traits overloads
* limits
******************************************************************************/

_LIBCUDACXX_BEGIN_NAMESPACE_STD
template <>
struct CUB_NS_QUALIFIER::FpLimits<half_t>
class numeric_limits<half_t>
{
static __host__ __device__ __forceinline__ half_t Max()
public:
static __host__ __device__ __forceinline__ half_t max()
{
return (half_t::max)();
}

static __host__ __device__ __forceinline__ half_t Lowest()
static __host__ __device__ __forceinline__ half_t lowest()
{
return half_t::lowest();
}
};
_LIBCUDACXX_END_NAMESPACE_STD

template <>
struct CUB_NS_QUALIFIER::NumericTraits<half_t>
: CUB_NS_QUALIFIER::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t>
{};
struct CUB_NS_QUALIFIER::detail::unsigned_bits<half_t, void>
{
using type = unsigned short;
};

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif
// template <>
// struct CUB_NS_QUALIFIER::detail::NumericTraits<half_t>
// : CUB_NS_QUALIFIER::detail::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t>
// {};
Loading

0 comments on commit d61de28

Please sign in to comment.