-
Notifications
You must be signed in to change notification settings - Fork 179
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extract scan kernels into NVRTC-compilable header (#3334)
* Extract scan kernels into NVRTC-compilable header * Update cub/cub/device/dispatch/dispatch_scan.cuh Co-authored-by: Georgii Evtushenko <[email protected]> --------- Co-authored-by: Ashwin Srinath <[email protected]> Co-authored-by: Georgii Evtushenko <[email protected]>
- Loading branch information
1 parent
cda5501
commit 3421002
Showing
7 changed files
with
194 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,8 +53,6 @@ | |
|
||
#include <cuda/std/type_traits> | ||
|
||
#include <iterator> | ||
|
||
#include <nv/target> | ||
|
||
CUB_NAMESPACE_BEGIN | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* * Redistributions of source code must retain the above copyright | ||
* notice, this list of conditions and the following disclaimer. | ||
* * Redistributions in binary form must reproduce the above copyright | ||
* notice, this list of conditions and the following disclaimer in the | ||
* documentation and/or other materials provided with the distribution. | ||
* * Neither the name of the NVIDIA CORPORATION nor the | ||
* names of its contributors may be used to endorse or promote products | ||
* derived from this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
* | ||
******************************************************************************/ | ||
|
||
#pragma once | ||
|
||
#include <cub/config.cuh> | ||
|
||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
# pragma GCC system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
# pragma clang system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
# pragma system_header | ||
#endif // no system header | ||
|
||
#include <cub/agent/agent_scan.cuh> | ||
#include <cub/util_macro.cuh> | ||
|
||
CUB_NAMESPACE_BEGIN | ||
|
||
/****************************************************************************** | ||
* Kernel entry points | ||
*****************************************************************************/ | ||
|
||
/** | ||
* @brief Initialization kernel for tile status initialization (multi-block) | ||
* | ||
* @tparam ScanTileStateT | ||
* Tile status interface type | ||
* | ||
* @param[in] tile_state | ||
* Tile status interface | ||
* | ||
* @param[in] num_tiles | ||
* Number of tiles | ||
*/ | ||
template <typename ScanTileStateT> | ||
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanInitKernel(ScanTileStateT tile_state, int num_tiles) | ||
{ | ||
// Initialize tile status | ||
tile_state.InitializeStatus(num_tiles); | ||
} | ||
|
||
/** | ||
* Initialization kernel for tile status initialization (multi-block) | ||
* | ||
* @tparam ScanTileStateT | ||
* Tile status interface type | ||
* | ||
* @tparam NumSelectedIteratorT | ||
* Output iterator type for recording the number of items selected | ||
* | ||
* @param[in] tile_state | ||
* Tile status interface | ||
* | ||
* @param[in] num_tiles | ||
* Number of tiles | ||
* | ||
* @param[out] d_num_selected_out | ||
* Pointer to the total number of items selected | ||
* (i.e., length of `d_selected_out`) | ||
*/ | ||
template <typename ScanTileStateT, typename NumSelectedIteratorT> | ||
CUB_DETAIL_KERNEL_ATTRIBUTES void | ||
DeviceCompactInitKernel(ScanTileStateT tile_state, int num_tiles, NumSelectedIteratorT d_num_selected_out) | ||
{ | ||
// Initialize tile status | ||
tile_state.InitializeStatus(num_tiles); | ||
|
||
// Initialize d_num_selected_out | ||
if ((blockIdx.x == 0) && (threadIdx.x == 0)) | ||
{ | ||
*d_num_selected_out = 0; | ||
} | ||
} | ||
|
||
/** | ||
* @brief Scan kernel entry point (multi-block) | ||
* | ||
* | ||
* @tparam ChainedPolicyT | ||
* Chained tuning policy | ||
* | ||
* @tparam InputIteratorT | ||
* Random-access input iterator type for reading scan inputs @iterator | ||
* | ||
* @tparam OutputIteratorT | ||
* Random-access output iterator type for writing scan outputs @iterator | ||
* | ||
* @tparam ScanTileStateT | ||
* Tile status interface type | ||
* | ||
* @tparam ScanOpT | ||
* Binary scan functor type having member | ||
* `auto operator()(const T &a, const U &b)` | ||
* | ||
* @tparam InitValueT | ||
* Initial value to seed the exclusive scan | ||
* (cub::NullType for inclusive scans) | ||
* | ||
* @tparam OffsetT | ||
* Unsigned integer type for global offsets | ||
* | ||
* @paramInput d_in | ||
* data | ||
* | ||
* @paramOutput d_out | ||
* data | ||
* | ||
* @paramTile tile_state | ||
* status interface | ||
* | ||
* @paramThe start_tile | ||
* starting tile for the current grid | ||
* | ||
* @paramBinary scan_op | ||
* scan functor | ||
* | ||
* @paramInitial init_value | ||
* value to seed the exclusive scan | ||
* | ||
* @paramTotal num_items | ||
* number of scan items for the entire problem | ||
*/ | ||
template <typename ChainedPolicyT, | ||
typename InputIteratorT, | ||
typename OutputIteratorT, | ||
typename ScanTileStateT, | ||
typename ScanOpT, | ||
typename InitValueT, | ||
typename OffsetT, | ||
typename AccumT, | ||
bool ForceInclusive> | ||
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS)) | ||
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanKernel( | ||
InputIteratorT d_in, | ||
OutputIteratorT d_out, | ||
ScanTileStateT tile_state, | ||
int start_tile, | ||
ScanOpT scan_op, | ||
InitValueT init_value, | ||
OffsetT num_items) | ||
{ | ||
using RealInitValueT = typename InitValueT::value_type; | ||
using ScanPolicyT = typename ChainedPolicyT::ActivePolicy::ScanPolicyT; | ||
|
||
// Thread block type for scanning input tiles | ||
using AgentScanT = | ||
AgentScan<ScanPolicyT, InputIteratorT, OutputIteratorT, ScanOpT, RealInitValueT, OffsetT, AccumT, ForceInclusive>; | ||
|
||
// Shared memory for AgentScan | ||
__shared__ typename AgentScanT::TempStorage temp_storage; | ||
|
||
RealInitValueT real_init_value = init_value; | ||
|
||
// Process tiles | ||
AgentScanT(temp_storage, d_in, d_out, scan_op, real_init_value).ConsumeRange(num_items, tile_state, start_tile); | ||
} | ||
|
||
CUB_NAMESPACE_END |
Oops, something went wrong.