Skip to content

Commit

Permalink
Extract scan kernels into NVRTC-compilable header (#3334)
Browse files Browse the repository at this point in the history
* Extract scan kernels into NVRTC-compilable header

* Update cub/cub/device/dispatch/dispatch_scan.cuh

Co-authored-by: Georgii Evtushenko <[email protected]>

---------

Co-authored-by: Ashwin Srinath <[email protected]>
Co-authored-by: Georgii Evtushenko <[email protected]>
  • Loading branch information
3 people authored Jan 13, 2025
1 parent cda5501 commit 3421002
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 153 deletions.
6 changes: 2 additions & 4 deletions cub/cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN

/******************************************************************************
Expand Down Expand Up @@ -162,15 +160,15 @@ struct AgentScan
// Wrap the native input pointer with CacheModifiedInputIterator
// or directly use the supplied input iterator type
using WrappedInputIteratorT =
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
::cuda::std::_If<::cuda::std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;

// Constants
enum
{
// Inclusive scan if no init_value type is provided
HAS_INIT = !std::is_same<InitValueT, NullType>::value,
HAS_INIT = !::cuda::std::is_same<InitValueT, NullType>::value,
IS_INCLUSIVE = ForceInclusive || !HAS_INIT, // We are relying on either initial value not being `NullType`
// or the ForceInclusive tag to be true for inclusive scan
// to get picked up.
Expand Down
2 changes: 0 additions & 2 deletions cub/cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@

#include <cuda/std/type_traits>

#include <iterator>

#include <nv/target>

CUB_NAMESPACE_BEGIN
Expand Down
2 changes: 0 additions & 2 deletions cub/cub/detail/strong_load.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>

#include <iterator>

CUB_NAMESPACE_BEGIN

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
Expand Down
144 changes: 3 additions & 141 deletions cub/cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

#include <cub/config.cuh>

#include <cub/util_namespace.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
Expand All @@ -46,6 +48,7 @@
#endif // no system header

#include <cub/agent/agent_scan.cuh>
#include <cub/device/dispatch/kernels/scan.cuh>
#include <cub/device/dispatch/tuning/tuning_scan.cuh>
#include <cub/grid/grid_queue.cuh>
#include <cub/thread/thread_operators.cuh>
Expand All @@ -57,149 +60,8 @@

#include <cuda/std/type_traits>

#include <iterator>

CUB_NAMESPACE_BEGIN

/******************************************************************************
* Kernel entry points
*****************************************************************************/

/**
* @brief Initialization kernel for tile status initialization (multi-block)
*
* @tparam ScanTileStateT
* Tile status interface type
*
* @param[in] tile_state
* Tile status interface
*
* @param[in] num_tiles
* Number of tiles
*/
template <typename ScanTileStateT>
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanInitKernel(ScanTileStateT tile_state, int num_tiles)
{
// Initialize tile status
tile_state.InitializeStatus(num_tiles);
}

/**
* Initialization kernel for tile status initialization (multi-block)
*
* @tparam ScanTileStateT
* Tile status interface type
*
* @tparam NumSelectedIteratorT
* Output iterator type for recording the number of items selected
*
* @param[in] tile_state
* Tile status interface
*
* @param[in] num_tiles
* Number of tiles
*
* @param[out] d_num_selected_out
* Pointer to the total number of items selected
* (i.e., length of `d_selected_out`)
*/
template <typename ScanTileStateT, typename NumSelectedIteratorT>
CUB_DETAIL_KERNEL_ATTRIBUTES void
DeviceCompactInitKernel(ScanTileStateT tile_state, int num_tiles, NumSelectedIteratorT d_num_selected_out)
{
// Initialize tile status
tile_state.InitializeStatus(num_tiles);

// Initialize d_num_selected_out
if ((blockIdx.x == 0) && (threadIdx.x == 0))
{
*d_num_selected_out = 0;
}
}

/**
* @brief Scan kernel entry point (multi-block)
*
*
* @tparam ChainedPolicyT
* Chained tuning policy
*
* @tparam InputIteratorT
* Random-access input iterator type for reading scan inputs @iterator
*
* @tparam OutputIteratorT
* Random-access output iterator type for writing scan outputs @iterator
*
* @tparam ScanTileStateT
* Tile status interface type
*
* @tparam ScanOpT
* Binary scan functor type having member
* `auto operator()(const T &a, const U &b)`
*
* @tparam InitValueT
* Initial value to seed the exclusive scan
* (cub::NullType for inclusive scans)
*
* @tparam OffsetT
* Unsigned integer type for global offsets
*
* @paramInput d_in
* data
*
* @paramOutput d_out
* data
*
* @paramTile tile_state
* status interface
*
* @paramThe start_tile
* starting tile for the current grid
*
* @paramBinary scan_op
* scan functor
*
* @paramInitial init_value
* value to seed the exclusive scan
*
* @paramTotal num_items
* number of scan items for the entire problem
*/
template <typename ChainedPolicyT,
typename InputIteratorT,
typename OutputIteratorT,
typename ScanTileStateT,
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT,
bool ForceInclusive>
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanKernel(
InputIteratorT d_in,
OutputIteratorT d_out,
ScanTileStateT tile_state,
int start_tile,
ScanOpT scan_op,
InitValueT init_value,
OffsetT num_items)
{
using RealInitValueT = typename InitValueT::value_type;
using ScanPolicyT = typename ChainedPolicyT::ActivePolicy::ScanPolicyT;

// Thread block type for scanning input tiles
using AgentScanT =
AgentScan<ScanPolicyT, InputIteratorT, OutputIteratorT, ScanOpT, RealInitValueT, OffsetT, AccumT, ForceInclusive>;

// Shared memory for AgentScan
__shared__ typename AgentScanT::TempStorage temp_storage;

RealInitValueT real_init_value = init_value;

// Process tiles
AgentScanT(temp_storage, d_in, d_out, scan_op, real_init_value).ConsumeRange(num_items, tile_state, start_tile);
}

/******************************************************************************
* Dispatch
******************************************************************************/
Expand Down
184 changes: 184 additions & 0 deletions cub/cub/device/dispatch/kernels/scan.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/******************************************************************************
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cub/agent/agent_scan.cuh>
#include <cub/util_macro.cuh>

CUB_NAMESPACE_BEGIN

/******************************************************************************
* Kernel entry points
*****************************************************************************/

/**
* @brief Initialization kernel for tile status initialization (multi-block)
*
* @tparam ScanTileStateT
* Tile status interface type
*
* @param[in] tile_state
* Tile status interface
*
* @param[in] num_tiles
* Number of tiles
*/
template <typename ScanTileStateT>
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanInitKernel(ScanTileStateT tile_state, int num_tiles)
{
// Initialize tile status
tile_state.InitializeStatus(num_tiles);
}

/**
* Initialization kernel for tile status initialization (multi-block)
*
* @tparam ScanTileStateT
* Tile status interface type
*
* @tparam NumSelectedIteratorT
* Output iterator type for recording the number of items selected
*
* @param[in] tile_state
* Tile status interface
*
* @param[in] num_tiles
* Number of tiles
*
* @param[out] d_num_selected_out
* Pointer to the total number of items selected
* (i.e., length of `d_selected_out`)
*/
template <typename ScanTileStateT, typename NumSelectedIteratorT>
CUB_DETAIL_KERNEL_ATTRIBUTES void
DeviceCompactInitKernel(ScanTileStateT tile_state, int num_tiles, NumSelectedIteratorT d_num_selected_out)
{
// Initialize tile status
tile_state.InitializeStatus(num_tiles);

// Initialize d_num_selected_out
if ((blockIdx.x == 0) && (threadIdx.x == 0))
{
*d_num_selected_out = 0;
}
}

/**
* @brief Scan kernel entry point (multi-block)
*
*
* @tparam ChainedPolicyT
* Chained tuning policy
*
* @tparam InputIteratorT
* Random-access input iterator type for reading scan inputs @iterator
*
* @tparam OutputIteratorT
* Random-access output iterator type for writing scan outputs @iterator
*
* @tparam ScanTileStateT
* Tile status interface type
*
* @tparam ScanOpT
* Binary scan functor type having member
* `auto operator()(const T &a, const U &b)`
*
* @tparam InitValueT
* Initial value to seed the exclusive scan
* (cub::NullType for inclusive scans)
*
* @tparam OffsetT
* Unsigned integer type for global offsets
*
* @paramInput d_in
* data
*
* @paramOutput d_out
* data
*
* @paramTile tile_state
* status interface
*
* @paramThe start_tile
* starting tile for the current grid
*
* @paramBinary scan_op
* scan functor
*
* @paramInitial init_value
* value to seed the exclusive scan
*
* @paramTotal num_items
* number of scan items for the entire problem
*/
template <typename ChainedPolicyT,
typename InputIteratorT,
typename OutputIteratorT,
typename ScanTileStateT,
typename ScanOpT,
typename InitValueT,
typename OffsetT,
typename AccumT,
bool ForceInclusive>
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanKernel(
InputIteratorT d_in,
OutputIteratorT d_out,
ScanTileStateT tile_state,
int start_tile,
ScanOpT scan_op,
InitValueT init_value,
OffsetT num_items)
{
using RealInitValueT = typename InitValueT::value_type;
using ScanPolicyT = typename ChainedPolicyT::ActivePolicy::ScanPolicyT;

// Thread block type for scanning input tiles
using AgentScanT =
AgentScan<ScanPolicyT, InputIteratorT, OutputIteratorT, ScanOpT, RealInitValueT, OffsetT, AccumT, ForceInclusive>;

// Shared memory for AgentScan
__shared__ typename AgentScanT::TempStorage temp_storage;

RealInitValueT real_init_value = init_value;

// Process tiles
AgentScanT(temp_storage, d_in, d_out, scan_op, real_init_value).ConsumeRange(num_items, tile_state, start_tile);
}

CUB_NAMESPACE_END
Loading

0 comments on commit 3421002

Please sign in to comment.