Skip to content

Commit

Permalink
Deprecate cub::DeviceSpmv (#3320)
Browse files Browse the repository at this point in the history
Fixes: #896
  • Loading branch information
bernhardmgruber authored Jan 14, 2025
1 parent 0e63552 commit d5d3aa6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 21 deletions.
15 changes: 12 additions & 3 deletions cub/cub/agent/agent_spmv_orig.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ template <int _BLOCK_THREADS,
CacheLoadModifier _VECTOR_VALUES_LOAD_MODIFIER,
bool _DIRECT_LOAD_NONZEROS,
BlockScanAlgorithm _SCAN_ALGORITHM>
struct AgentSpmvPolicy
struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") AgentSpmvPolicy
{
enum
{
Expand Down Expand Up @@ -148,7 +148,12 @@ struct AgentSpmvPolicy
* Signed integer type for sequence offsets
*/
template <typename ValueT, typename OffsetT>
struct SpmvParams
struct
// with NVHPC, we get a deprecation warning in the implementation of cudaLaunchKernelEx, which we cannot suppress :/
#if !_CCCL_COMPILER(NVHPC)
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
#endif
SpmvParams
{
/// Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix
/// <b>A</b>.
Expand Down Expand Up @@ -211,7 +216,7 @@ template <typename AgentSpmvPolicyT,
bool HAS_ALPHA,
bool HAS_BETA,
int LEGACY_PTX_ARCH = 0>
struct AgentSpmv
struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") AgentSpmv
{
//---------------------------------------------------------------------
// Types and constants
Expand Down Expand Up @@ -308,7 +313,9 @@ struct AgentSpmv
/// Reference to temp_storage
_TempStorage& temp_storage;

_CCCL_SUPPRESS_DEPRECATED_PUSH
SpmvParams<ValueT, OffsetT>& spmv_params;
_CCCL_SUPPRESS_DEPRECATED_POP

/// Wrapped pointer to the array of \p num_nonzeros values of the corresponding nonzero elements
/// of matrix <b>A</b>.
Expand Down Expand Up @@ -341,6 +348,7 @@ struct AgentSpmv
* @param spmv_params
* SpMV input parameter bundle
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
_CCCL_DEVICE _CCCL_FORCEINLINE AgentSpmv(TempStorage& temp_storage, SpmvParams<ValueT, OffsetT>& spmv_params)
: temp_storage(temp_storage.Alias())
, spmv_params(spmv_params)
Expand All @@ -350,6 +358,7 @@ struct AgentSpmv
, wd_vector_x(spmv_params.d_vector_x)
, wd_vector_y(spmv_params.d_vector_y)
{}
_CCCL_SUPPRESS_DEPRECATED_POP

/**
* @brief Consume a merge tile, specialized for direct-load of nonzeros
Expand Down
29 changes: 16 additions & 13 deletions cub/cub/device/device_spmv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ CUB_NAMESPACE_BEGIN
//! @cdp_class{DeviceSpmv}
//!
//! @endrst
struct DeviceSpmv
struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DeviceSpmv
{
//! @name CSR matrix operations
//! @{
Expand Down Expand Up @@ -177,18 +177,19 @@ struct DeviceSpmv
//! **[optional]** CUDA stream to launch kernels within. Default is stream\ :sub:`0`.
//! @endrst
template <typename ValueT>
CUB_RUNTIME_FUNCTION static cudaError_t CsrMV(
void* d_temp_storage,
size_t& temp_storage_bytes,
const ValueT* d_values,
const int* d_row_offsets,
const int* d_column_indices,
const ValueT* d_vector_x,
ValueT* d_vector_y,
int num_rows,
int num_cols,
int num_nonzeros,
cudaStream_t stream = 0)
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
CUB_RUNTIME_FUNCTION static cudaError_t
CsrMV(void* d_temp_storage,
size_t& temp_storage_bytes,
const ValueT* d_values,
const int* d_row_offsets,
const int* d_column_indices,
const ValueT* d_vector_x,
ValueT* d_vector_y,
int num_rows,
int num_cols,
int num_nonzeros,
cudaStream_t stream = 0)
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceSpmv::CsrMV");

Expand All @@ -204,7 +205,9 @@ struct DeviceSpmv
spmv_params.alpha = ValueT{1};
spmv_params.beta = ValueT{0};

_CCCL_SUPPRESS_DEPRECATED_PUSH
return DispatchSpmv<ValueT, int>::Dispatch(d_temp_storage, temp_storage_bytes, spmv_params, stream);
_CCCL_SUPPRESS_DEPRECATED_POP
}

//! @} end member group
Expand Down
21 changes: 16 additions & 5 deletions cub/cub/device/dispatch/dispatch_spmv_orig.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,11 @@ CUB_NAMESPACE_BEGIN
* @param[in] spmv_params
* SpMV input parameter bundle
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename AgentSpmvPolicyT, typename ValueT, typename OffsetT>
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmv1ColKernel(SpmvParams<ValueT, OffsetT> spmv_params)
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmv1ColKernel(SpmvParams<ValueT, OffsetT> spmv_params) //
_CCCL_SUPPRESS_DEPRECATED_POP
{
using VectorValueIteratorT =
CacheModifiedInputIterator<AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER, ValueT, OffsetT>;
Expand Down Expand Up @@ -132,8 +135,9 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmv1ColKernel(SpmvParams<ValueT, Offset
* SpMV input parameter bundle
*/
template <typename SpmvPolicyT, typename OffsetT, typename CoordinateT, typename SpmvParamsT>
CUB_DETAIL_KERNEL_ATTRIBUTES void
DeviceSpmvSearchKernel(int num_merge_tiles, CoordinateT* d_tile_coordinates, SpmvParamsT spmv_params)
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvSearchKernel(
int num_merge_tiles, CoordinateT* d_tile_coordinates, SpmvParamsT spmv_params)
{
/// Constants
enum
Expand Down Expand Up @@ -217,6 +221,7 @@ template <typename SpmvPolicyT,
typename CoordinateT,
bool HAS_ALPHA,
bool HAS_BETA>
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
__launch_bounds__(int(SpmvPolicyT::BLOCK_THREADS)) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvKernel(
SpmvParams<ValueT, OffsetT> spmv_params,
CoordinateT* d_tile_coordinates,
Expand All @@ -226,7 +231,9 @@ __launch_bounds__(int(SpmvPolicyT::BLOCK_THREADS)) CUB_DETAIL_KERNEL_ATTRIBUTES
int num_segment_fixup_tiles)
{
// Spmv agent type specialization
_CCCL_SUPPRESS_DEPRECATED_PUSH
using AgentSpmvT = AgentSpmv<SpmvPolicyT, ValueT, OffsetT, HAS_ALPHA, HAS_BETA>;
_CCCL_SUPPRESS_DEPRECATED_POP

// Shared memory for AgentSpmv
__shared__ typename AgentSpmvT::TempStorage temp_storage;
Expand All @@ -248,6 +255,7 @@ __launch_bounds__(int(SpmvPolicyT::BLOCK_THREADS)) CUB_DETAIL_KERNEL_ATTRIBUTES
* Whether the input parameter Beta is 0
*/
template <typename ValueT, typename OffsetT, bool HAS_BETA>
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvEmptyMatrixKernel(SpmvParams<ValueT, OffsetT> spmv_params)
{
const int row = static_cast<int>(threadIdx.x + blockIdx.x * blockDim.x);
Expand Down Expand Up @@ -298,18 +306,21 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvEmptyMatrixKernel(SpmvParams<ValueT,
* @param[in] tile_state
* Tile status interface
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename AgentSegmentFixupPolicyT,
typename PairsInputIteratorT,
typename AggregatesOutputIteratorT,
typename OffsetT,
typename ScanTileStateT>
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
__launch_bounds__(int(AgentSegmentFixupPolicyT::BLOCK_THREADS))
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentFixupKernel(
PairsInputIteratorT d_pairs_in,
AggregatesOutputIteratorT d_aggregates_out,
OffsetT num_items,
int num_tiles,
ScanTileStateT tile_state)
ScanTileStateT tile_state) //
_CCCL_SUPPRESS_DEPRECATED_POP
{
// Thread block type for reducing tiles of value segments
using AgentSegmentFixupT =
Expand Down Expand Up @@ -342,7 +353,7 @@ __launch_bounds__(int(AgentSegmentFixupPolicyT::BLOCK_THREADS))
* Signed integer type for global offsets
*/
template <typename ValueT, typename OffsetT>
struct DispatchSpmv
struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DispatchSpmv
{
//---------------------------------------------------------------------
// Constants and Types
Expand Down
4 changes: 4 additions & 0 deletions cub/test/test_device_spmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
#include <c2h/device_policy.h>
#include <c2h/vector.h>

_CCCL_SUPPRESS_DEPRECATED_PUSH

bool g_verbose = false;

//==============================================================================
Expand Down Expand Up @@ -605,3 +607,5 @@ int main(int argc, char** argv)

test_types();
}

_CCCL_SUPPRESS_DEPRECATED_POP

0 comments on commit d5d3aa6

Please sign in to comment.