diff --git a/dali/kernels/slice/slice_hwc2chw_normalize_gpu.cu b/dali/kernels/slice/slice_hwc2chw_normalize_gpu.cu index 16c4bbc2728..b13bd25f283 100644 --- a/dali/kernels/slice/slice_hwc2chw_normalize_gpu.cu +++ b/dali/kernels/slice/slice_hwc2chw_normalize_gpu.cu @@ -17,6 +17,8 @@ #include "dali/core/backend_tags.h" #include "dali/core/convert.h" #include "dali/core/error_handling.h" +#include "dali/core/fast_div.h" +#include "dali/core/float16.h" #include "dali/core/permute.h" #include "dali/core/static_switch.h" #include "dali/kernels/slice/slice_hwc2chw_normalize_gpu.h" @@ -27,7 +29,7 @@ namespace kernels { namespace slice_flip_normalize { template -struct Hwc2ChwSampleDesc { +struct Hwc2HwcChwSampleDesc { Out *__restrict__ out; const In *__restrict__ in; const float *__restrict__ norm_add; @@ -46,8 +48,7 @@ struct Hwc2ChwSampleDesc { }; // TODO(klecki): Generalize the utility for binsearch indexing of thread blocks with Cast kernel. -inline __device__ uint32_t FindSampleIdx(const uint32_t *first_blocks, - uint32_t num_samples) { +inline __device__ uint32_t FindSampleIdx(const uint32_t *first_blocks, uint32_t num_samples) { uint32_t i = 0; for (uint32_t jump = (1 << (32 - __clz(num_samples) - 1)); jump; jump >>= 1) { if (i + jump < num_samples && first_blocks[i + jump] <= blockIdx.x) @@ -56,180 +57,140 @@ inline __device__ uint32_t FindSampleIdx(const uint32_t *first_blocks, return i; } -/** @defgroup Hwc2Chw The Slice Hwc2Chw Normalize Mirror-x Pad-channel kernel +/** @defgroup Hwc2HwcChwLoad Data loading for slice Hwc2{Hwc,Chw} Normalize Mirror-x Pad-channel + * kernel Load the data from linear chunk of HWC u8 image into a tile in shared memory. The loading + * loop consists of three stages: + * 1. Prologue - read from the start of the tile to the address that is multiple of 4 byte alignment + * 2. Main loop - read most of the tile via uchar4, utilizing 4-byte read instructions. + * 3. Epilogue - read the remainder of data that is not covered by the two previous loops. * - * Kernel that reads a HWC u8 image and outputs a CHW normalized float image, that can be cropped - * in Y, X coordinates, mirrored in X coordinate, and the channels can be padded. - * - * Overview of the kernel: - * The image is processed in flattened coordinates. The Y, X stays the same between the interleaved - * input layout and planar output layout. Assuming 3-channel input, we can look at the input as - * a sequential stream of values, where we distribute them (sequentially) into 3 output planes. - * Use a thread block size, that is divisible both by channel number (for the output loop), - * and 4 (for input loop). - * The processing steps: - * 1. [Input loop] Load the linear chunk of input into shared memory, utilizing 4-byte aligned loads - * and cast it to float. - * a. Unaligned prologue loop - reads the first chunk till we get to address that is aligned with - * 32 * 4. - * b. Main loop - do as many aligned 4byte reads as possible - * c. Epilogue loop - read the remaining values that were not possible to read as one 4byte read. - * 2. Synchronize - * 3. [Output loop] Each thread corresponds to a (Y, X) sequential offset into a plane, computes - * the values for all the channels and writes them. - * a. Optionally, mirroring is performed by inverting the X-coordinate in the output offset. - * b. Padding the output channels is performed by filling additional planes with fill values. 
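A host-side sketch of the lookup performed by `FindSampleIdx` above may help review: `first_blocks` is the exclusive prefix sum of per-sample block counts (built later in `Run()` via `div_ceil`), and the power-of-two jumps are a branch-light binary search for the last entry not greater than the current block index. Illustrative only; plain C++ with my own names, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Highest power of two not greater than n (mirrors 1 << (31 - __clz(n)) on the device).
uint32_t highest_pow2(uint32_t n) {
  uint32_t p = 1;
  while (p * 2 <= n) p *= 2;
  return p;
}

// first_blocks[i] holds the index of the first block assigned to sample i;
// the jumps locate the last entry <= block_idx, i.e. the sample owning that block.
uint32_t find_sample_idx(const std::vector<uint32_t> &first_blocks, uint32_t block_idx) {
  uint32_t num_samples = static_cast<uint32_t>(first_blocks.size());
  uint32_t i = 0;
  for (uint32_t jump = highest_pow2(num_samples); jump; jump >>= 1) {
    if (i + jump < num_samples && first_blocks[i + jump] <= block_idx)
      i += jump;
  }
  return i;
}

int main() {
  // Samples needing 3, 1 and 4 blocks respectively -> first blocks at 0, 3, 4.
  std::vector<uint32_t> first_blocks = {0, 3, 4};
  assert(find_sample_idx(first_blocks, 0) == 0);
  assert(find_sample_idx(first_blocks, 2) == 0);
  assert(find_sample_idx(first_blocks, 3) == 1);
  assert(find_sample_idx(first_blocks, 7) == 2);
}
```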
- * @{ + * The slicing variant is addressed reads only the values required by the output, proceeding + * row by row, using the same pattern as above for each row. + * Samples are adjusted so that rows slices start at 0, and only the end of row is sliced. + * @{ */ /** - * @brief Hwc2Chw Normalize Mirror-x Pad-channel kernel - * This kernel does not support cropping the x coordinate, so the reads are fully linear. + * @brief Load the linear tile into linear smem buffer. + * + * @tparam kBlockSize Tile size + * @tparam kStaticChannels Number of input channels + * @tparam Tile Type of the data kept after loading in the smem tile. + * @tparam Out Output data type + * @tparam In Input data type + * @tparam kLoadAlign - Alignment (in bytes) of the main loop. The access to smem is also aligned + * to this value, so depending on the prologue length, the data after loading may not start + * at the tile[0]. The start of actual data is returned. + * The smem tile must hold at least kBlockSize + kLoadAlign elements. + * @param tile Shared memory where to load the data. + * @param sample Sample description + * @return Tile * - the pointer to the smem where the start of the loaded data is. */ -template -__global__ void Hwc2ChwNormalize(const Hwc2ChwSampleDesc *samples, uint32_t *first_blocks, - uint32_t num_samples) { - // TODO(klecki): generalize for wider input types - static_assert(std::is_same::value, "Only uint8_t supported as input"); +template +__device__ __forceinline__ Tile *load_linear_tile(Tile *tile, + const Hwc2HwcChwSampleDesc sample) { + static_assert(std::is_same_v, "Only uint8_t types allowed now."); + static_assert(kStaticChannels == 3, "Only 3 input channels allowed now."); + static_assert(kLoadAlign % 4 == 0, "The loading alignment should be divisible by 4."); - int sample_idx = FindSampleIdx(first_blocks, num_samples); - const auto sample = samples[sample_idx]; int64_t start_x = (blockIdx.x - sample.first_block) * kBlockSize; int64_t end_x = ::min(start_x + kBlockSize, sample.sample_size); - __shared__ float tile[kBlockSize + 32 * 4]; - - // Preload the norm values so they are accessed via registers and not from gmem via pointer. 
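To make the three-stage loading loop described above easier to follow, here is the prologue/main/epilogue arithmetic as a host-side sketch, assuming kLoadAlign is a multiple of 4 (as the static_assert requires). It also shows why the smem tile needs kBlockSize + kLoadAlign elements: the aligned part of the data always lands at tile + kLoadAlign, and the prologue occupies up to kLoadAlign elements right before it. Names below are illustrative, not from the patch.

```cpp
#include <cstdint>
#include <cstdio>

struct LoadSplit {
  uint32_t prologue;   // single-byte loads until the input pointer is load_align-aligned
  uint32_t main_vec4;  // number of uchar4 (4-byte) loads in the main loop
  uint32_t epilogue;   // single-byte loads for the remainder
};

LoadSplit split_tile(uintptr_t in_start, uint32_t tile_size, uint32_t load_align) {
  uintptr_t aligned_start = (in_start + load_align - 1) / load_align * load_align;  // align_up
  uint32_t prologue = static_cast<uint32_t>(aligned_start - in_start);
  if (prologue > tile_size) prologue = tile_size;  // short tile: the prologue covers everything
  uint32_t left = tile_size - prologue;
  return {prologue, left >> 2, left - (left & ~3u)};
}

int main() {
  // A 1024-element tile starting 3 bytes past a 128-byte boundary:
  LoadSplit s = split_tile(/*in_start=*/128 * 7 + 3, /*tile_size=*/1024, /*load_align=*/128);
  std::printf("prologue=%u main=%u epilogue=%u\n", s.prologue, s.main_vec4, s.epilogue);
  // prints: prologue=125 main=224 epilogue=3  (125 + 224*4 + 3 == 1024)
}
```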
- float norm_mul[kStaticChannels], norm_add[kStaticChannels]; - - #pragma unroll kStaticChannels - for (int c = 0; c < kStaticChannels; c++) { - norm_mul[c] = sample.norm_mul[c]; - norm_add[c] = sample.norm_add[c]; - } auto in_start = reinterpret_cast(sample.in + start_x); - auto aligned_in_start = align_up(in_start, 32 * 4); - auto bytes_skipped = - ::min(static_cast(aligned_in_start - in_start), end_x - start_x); + auto aligned_in_start = align_up(in_start, kLoadAlign); + + // In case if end_x - start_x < kLoadAlign, we never get to the aligned main loop + uint32_t bytes_to_alignment = ::min(aligned_in_start - in_start, end_x - start_x); - float *aligned_tile = tile + 32 * 4; - float *prologue_tile = aligned_tile - bytes_skipped; + Tile *aligned_tile = tile + kLoadAlign; + Tile *prologue_tile = aligned_tile - bytes_to_alignment; const In *prologue_in = sample.in + start_x; - const uchar4 *aligned_in_char4 = - reinterpret_cast(sample.in + start_x + bytes_skipped); + const uchar4 *aligned_in_uchar4 = + reinterpret_cast(sample.in + start_x + bytes_to_alignment); // prologue - for (int64_t idx = threadIdx.x; idx < bytes_skipped; idx += blockDim.x) { + for (uint32_t idx = threadIdx.x; idx < bytes_to_alignment; idx += blockDim.x) { prologue_tile[idx] = prologue_in[idx]; } - int64_t left_after_prologue = end_x - start_x - bytes_skipped; - // aligned load - for (int64_t idx = threadIdx.x; idx < left_after_prologue >> 2; idx += blockDim.x) { - uchar4 in = aligned_in_char4[idx]; + // this might be 0, as the prologue may be the full extend of the tile + uint32_t left_after_prologue = end_x - start_x - bytes_to_alignment; + + // We read 4 values in each iteration + uint32_t main_loop_length = left_after_prologue >> 2; + + // main loop: aligned load + for (uint32_t idx = threadIdx.x; idx < main_loop_length; idx += blockDim.x) { + uchar4 in = aligned_in_uchar4[idx]; aligned_tile[idx * 4 + 0] = in.x; aligned_tile[idx * 4 + 1] = in.y; aligned_tile[idx * 4 + 2] = in.z; aligned_tile[idx * 4 + 3] = in.w; } - int64_t processed_in_main = left_after_prologue & -4; // equivalent to (x / 4) * 4 - int64_t left_after_main = left_after_prologue - processed_in_main; + + uint32_t processed_in_main = left_after_prologue & -4; // equivalent to (x / 4) * 4 + uint32_t left_after_main = left_after_prologue - processed_in_main; // epilogue - float *epilogue_tile = aligned_tile + processed_in_main; - const In *epilogue_in = reinterpret_cast(aligned_in_char4 + (processed_in_main >> 2)); + Tile *epilogue_tile = aligned_tile + processed_in_main; + const In *epilogue_in = reinterpret_cast(aligned_in_uchar4 + main_loop_length); - for (int64_t idx = threadIdx.x; idx < left_after_main; idx++) { + for (uint32_t idx = threadIdx.x; idx < left_after_main; idx++) { epilogue_tile[idx] = epilogue_in[idx]; } - __syncthreads(); - const auto *__restrict__ fill_values = static_cast(sample.fill_values); - - // idx is not divided by the static channels (mostly the start_x) - for (int64_t idx = threadIdx.x + start_x / kStaticChannels, base_x = threadIdx.x; - idx < end_x / kStaticChannels; idx += blockDim.x, base_x += blockDim.x) { - // TODO(klecki): forceinline device function - int64_t out_offset; - if constexpr (enable_mirror) { - if (sample.flip_x) { - int y = idx / sample.W; - int x = idx - (int64_t)y * sample.W; - int target_x = sample.W - 1 - x; - out_offset = (int64_t)y * sample.W + target_x; - } else { - out_offset = idx; - } - } else { - out_offset = idx; - } - - #pragma unroll kStaticChannels - for (int c = 0; c < kStaticChannels; 
c++) { - // the kStaticChannels == input_C - float fpin = prologue_tile[base_x * sample.input_C + c]; - float fpout = fmaf(fpin, norm_mul[c], norm_add[c]); - sample.out[c * sample.H * sample.W + out_offset] = ConvertSat(fpout); - } - - if constexpr (enable_pad) { - for (int c = kStaticChannels; c < sample.C; c++) { - sample.out[c * sample.H * sample.W + out_offset] = fill_values[c]; - } - } - } + // Return the start of the tile + return prologue_tile; } /** - * @brief Slice Hwc2Chw Normalize Mirror-x Pad-channel kernel - * This kernel supports cropping in x-coordinate. - * It extends the input loop, by utilizing the (unaligned prologue, aligned main loop, epilogue) - * pattern in a row-by-row loop. - * Indexing is based on the output coordinates, specifically we read rows for coordinate X - * between 0 and output H. (The samples are shifted so they always start from 0 in X). + * @brief Load the slices of linear tile into linear smem buffer. + * + * The kernel proceeds row-by-row, reading the output width elements/pixels, skipping the remaining + * input_width - output_width pixels. + * + * @tparam kBlockSize Tile size + * @tparam kStaticChannels Number of input channels + * @tparam Tile Type of the data kept after loading in the smem tile. + * @tparam Out Output data type + * @tparam In Input data type + * @tparam kLoadAlign - Alignment (in bytes) of the main loop. + * The smem tile must hold at least kBlockSize + kLoadAlign elements. + * @param tile Shared memory where to load the data. + * @param sample Sample description + * @return Tile * - the pointer to the smem where the start of the loaded data is. */ -template -__global__ void SliceHwc2ChwNormalize(const Hwc2ChwSampleDesc *samples, - uint32_t *first_blocks, uint32_t num_samples) { - // TODO(klecki): generalize for wider input types - static_assert(std::is_same::value, "Only uint8_t supported as input"); - - int sample_idx = FindSampleIdx(first_blocks, num_samples); +template +__device__ __forceinline__ Tile *slice_load_linear_tile( + Tile *tile, const Hwc2HwcChwSampleDesc sample) { + static_assert(std::is_same_v, "Only uint8_t types allowed now."); + static_assert(kStaticChannels == 3, "Only 3 input channels allowed now."); + static_assert(kLoadAlign % 4 == 0, "The loading alignment should be divisible by 4."); - const auto sample = samples[sample_idx]; int64_t start_x = (blockIdx.x - sample.first_block) * kBlockSize; int64_t end_x = ::min(start_x + kBlockSize, sample.sample_size); - __shared__ float tile[kBlockSize + 32 * 4]; - - // Preload the norm values so they are accessed via registers and not from gmem via pointer. 
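The row-by-row slicing documented above boils down to splitting the block's flat [start_x, end_x) range (in output coordinates) into per-row reads at input addresses. A host-side sketch of that decomposition, with illustrative names and a small check that the per-row lengths add up to the tile extent:

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// tile_stride = output_W * C, in_stride = input_W * C; the crop is assumed to start at x = 0,
// as the samples are realigned before launch.
std::vector<std::pair<int64_t, int>> sliced_row_reads(int64_t start_x, int64_t end_x,
                                                      int tile_stride, int in_stride) {
  std::vector<std::pair<int64_t, int>> reads;  // (offset into the input, number of values)
  int y_start = static_cast<int>(start_x / tile_stride);
  int y_end = static_cast<int>(end_x / tile_stride) + 1;
  for (int y = y_start; y < y_end; y++) {
    int xc_start = (y == y_start) ? static_cast<int>(start_x - int64_t{y_start} * tile_stride) : 0;
    int xc_end = (y == y_end - 1) ? static_cast<int>(end_x - int64_t{y_end - 1} * tile_stride)
                                  : tile_stride;
    if (xc_end > xc_start)
      reads.emplace_back(int64_t{y} * in_stride + xc_start, xc_end - xc_start);
  }
  return reads;
}

int main() {
  // 3-channel input 100 px wide, cropped to 64 px; one 1024-value tile starting at offset 100.
  auto reads = sliced_row_reads(/*start_x=*/100, /*end_x=*/1124,
                                /*tile_stride=*/64 * 3, /*in_stride=*/100 * 3);
  int total = 0;
  for (auto [offset, len] : reads) {
    std::printf("in offset %lld, %d values\n", static_cast<long long>(offset), len);
    total += len;
  }
  std::printf("total = %d\n", total);  // 1024, the full tile extent
}
```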
- float norm_mul[kStaticChannels], norm_add[kStaticChannels]; - - #pragma unroll kStaticChannels - for (int c = 0; c < kStaticChannels; c++) { - norm_mul[c] = sample.norm_mul[c]; - norm_add[c] = sample.norm_add[c]; - } - // Strides use the input number of channels without the padding int in_stride = sample.input_W * sample.input_C; - int out_stride = sample.W * sample.input_C; + // this is intermediate stride, as if we were never padding the data, + // so it is useful for filling the linear tile, keeping the xy offset + int tile_stride = sample.W * sample.input_C; // The rows we start and end with, we are indexed by output coordinates - int y_start = start_x / out_stride; - int y_end = end_x / out_stride + 1; + int y_start = start_x / tile_stride; + int y_end = end_x / tile_stride + 1; - float *tile_row = tile; + Tile *tile_row = tile; for (int y = y_start; y < y_end; y++) { int xc_start, xc_end; // The first row doesn't start with 0 due to tiling, the rest do. if (y == y_start) { - xc_start = start_x - y_start * out_stride; + xc_start = start_x - y_start * tile_stride; } else { xc_start = 0; @@ -237,78 +198,113 @@ __global__ void SliceHwc2ChwNormalize(const Hwc2ChwSampleDesc *samples, // Similarly for the end of row for last row if (y == y_end - 1) { - xc_end = end_x - (y_end - 1) * out_stride; + xc_end = end_x - (y_end - 1) * tile_stride; } else { - xc_end = out_stride; + xc_end = tile_stride; } const In *prologue_in = sample.in + y * in_stride + xc_start; auto in_start = reinterpret_cast(prologue_in); // align to 4 - auto aligned_in_start = align_up(in_start, 4); - auto bytes_skipped = + auto aligned_in_start = align_up(in_start, kLoadAlign); + uint32_t bytes_to_alignment = ::min(static_cast(aligned_in_start - in_start), xc_end - xc_start); - float *prologue_tile = tile_row; - float *aligned_tile = tile_row + bytes_skipped; + Tile *prologue_tile = tile_row; + Tile *aligned_tile = tile_row + bytes_to_alignment; - const uchar4 *aligned_in_char4 = reinterpret_cast(prologue_in + bytes_skipped); + const uchar4 *aligned_in_uchar4 = + reinterpret_cast(prologue_in + bytes_to_alignment); // prologue - for (int64_t idx = threadIdx.x; idx < bytes_skipped; idx += blockDim.x) { + for (uint32_t idx = threadIdx.x; idx < bytes_to_alignment; idx += blockDim.x) { prologue_tile[idx] = prologue_in[idx]; } - int64_t left_after_prologue = xc_end - xc_start - bytes_skipped; + + // this might be 0, as the prologue may be the full extend of the tile + uint32_t left_after_prologue = xc_end - xc_start - bytes_to_alignment; + + // We read 4 values in each iteration + uint32_t main_loop_length = left_after_prologue >> 2; // aligned load - for (int64_t idx = threadIdx.x; idx < left_after_prologue >> 2; idx += blockDim.x) { - uchar4 in = aligned_in_char4[idx]; + for (uint32_t idx = threadIdx.x; idx < main_loop_length; idx += blockDim.x) { + uchar4 in = aligned_in_uchar4[idx]; aligned_tile[idx * 4 + 0] = in.x; aligned_tile[idx * 4 + 1] = in.y; aligned_tile[idx * 4 + 2] = in.z; aligned_tile[idx * 4 + 3] = in.w; } - int64_t processed_in_main = left_after_prologue & -4; // equivalent to (x / 4) * 4 - int64_t left_after_main = left_after_prologue - processed_in_main; + uint32_t processed_in_main = left_after_prologue & -4; // equivalent to (x / 4) * 4 + uint32_t left_after_main = left_after_prologue - processed_in_main; // epilogue - float *epilogue_tile = aligned_tile + processed_in_main; - const In *epilogue_in = - reinterpret_cast(aligned_in_char4 + (processed_in_main >> 2)); + Tile *epilogue_tile = aligned_tile + 
processed_in_main; + const In *epilogue_in = reinterpret_cast(aligned_in_uchar4 + main_loop_length); - for (int64_t idx = threadIdx.x; idx < left_after_main; idx++) { + for (uint32_t idx = threadIdx.x; idx < left_after_main; idx++) { epilogue_tile[idx] = epilogue_in[idx]; } tile_row += (xc_end - xc_start); } + return tile; +} + +/** @} */ // end of Hwc2HwcChwLoad + + +/** @defgroup Hwc2HwcChwStore Data storing for slice Hwc2{Hwc,Chw} Normalize Mirror-x Pad-channel + * kernel + * @{ + */ + +/** + * @brief Calculate the planar output offset to take optional mirroring into account. + */ +template +__device__ __forceinline__ int64_t +calculate_offset_chw(int64_t planar_idx, const Hwc2HwcChwSampleDesc sample) { + if constexpr (enable_mirror) { + if (sample.flip_x) { + int y = planar_idx / sample.W; + int x = planar_idx - (int64_t)y * sample.W; + int target_x = sample.W - 1 - x; + return (int64_t)y * sample.W + target_x; + } + } + return planar_idx; +} + +template +__device__ __forceinline__ void store_chw(Tile *tile, const Hwc2HwcChwSampleDesc sample) { + int64_t start_x = (blockIdx.x - sample.first_block) * kBlockSize; + int64_t end_x = ::min(start_x + kBlockSize, sample.sample_size); - __syncthreads(); const auto *__restrict__ fill_values = static_cast(sample.fill_values); + // Preload the norm values so they are accessed via registers and not from gmem via pointer. + Compute norm_mul[kStaticChannels], norm_add[kStaticChannels]; + +#pragma unroll kStaticChannels + for (int c = 0; c < kStaticChannels; c++) { + norm_mul[c] = sample.norm_mul[c]; + norm_add[c] = sample.norm_add[c]; + } + + // idx is not divided by the static channels (mostly the start_x) for (int64_t idx = threadIdx.x + start_x / kStaticChannels, base_x = threadIdx.x; idx < end_x / kStaticChannels; idx += blockDim.x, base_x += blockDim.x) { - int64_t out_offset; - if constexpr (enable_mirror) { - if (sample.flip_x) { - int y = idx / sample.W; - int x = idx - (int64_t)y * sample.W; - int target_x = sample.W - 1 - x; - out_offset = (int64_t)y * sample.W + target_x; - } else { - out_offset = idx; - } - } else { - out_offset = idx; - } + int64_t out_offset = calculate_offset_chw(idx, sample); - #pragma unroll kStaticChannels +#pragma unroll kStaticChannels for (int c = 0; c < kStaticChannels; c++) { // the kStaticChannels == input_C - float fpin = tile[base_x * sample.input_C + c]; - float fpout = fmaf(fpin, norm_mul[c], norm_add[c]); + Compute fpin = tile[base_x * sample.input_C + c]; + Compute fpout = fmaf(fpin, norm_mul[c], norm_add[c]); sample.out[c * sample.H * sample.W + out_offset] = ConvertSat(fpout); } @@ -320,12 +316,220 @@ __global__ void SliceHwc2ChwNormalize(const Hwc2ChwSampleDesc *samples, } } -/** @} */ // end of Hwc2Chw +template +__device__ __forceinline__ int divide_by_channel(int xc) { + if constexpr (kOutChannels == 3) { + return xc / kOutChannels; + } + return xc >> 2; +} + +/** + * @brief Calculate the flat output offset for interleaved images to take optional mirroring into + * account. + */ +template +__device__ __forceinline__ int64_t +calculate_offset_hwc(int64_t flat_idx, int c, const Hwc2HwcChwSampleDesc sample) { + constexpr int kOutChannels = enable_pad ? 
4 : 3; + if constexpr (enable_mirror) { + if (sample.flip_x) { + int y = flat_idx / (sample.W * kOutChannels); + int xc = flat_idx - (int64_t)y * sample.W * kOutChannels; + int x = divide_by_channel(xc); + int target_x = sample.W - 1 - x; + return (int64_t)y * sample.W * kOutChannels + target_x * kOutChannels + c; + } + } + return flat_idx; +} + +// TODO(klecki): Prepare a generic version that supports the planar layout in smem and evaluate. +template +__device__ __forceinline__ void store_hwc(Tile *tile, const Hwc2HwcChwSampleDesc sample) { + int64_t start_x = (blockIdx.x - sample.first_block) * kBlockSize; + int64_t end_x = ::min(start_x + kBlockSize, sample.sample_size); + + const auto *__restrict__ fill_values = static_cast(sample.fill_values); + + // Preload the norm values so they are accessed via registers and not from gmem via pointer. + Compute norm_mul[kStaticChannels], norm_add[kStaticChannels]; + +#pragma unroll kStaticChannels + for (int c = 0; c < kStaticChannels; c++) { + norm_mul[c] = sample.norm_mul[c]; + norm_add[c] = sample.norm_add[c]; + } + + // Assuming all samples are padded + if constexpr (enable_pad) { + constexpr int kOutChannels = kStaticChannels + 1; + int64_t block_4 = (kBlockSize / kStaticChannels) * kOutChannels; + int64_t sample_size_4 = (sample.sample_size / kStaticChannels) * kOutChannels; + int64_t start_x_padded = static_cast(blockIdx.x - sample.first_block) * block_4; + int64_t end_x_padded = ::min(start_x_padded + block_4, sample_size_4); + + for (int64_t idx = threadIdx.x + start_x_padded, base_x = threadIdx.x; idx < end_x_padded; + idx += blockDim.x, base_x += blockDim.x) { + int base_offset = base_x >> 2; + int c = idx & 3; + + int64_t out_offset = calculate_offset_hwc(idx, c, sample); + + if (c < kStaticChannels) { + Compute fpin = tile[base_offset * sample.input_C + c]; + Compute fpout = fma(fpin, norm_mul[c], norm_add[c]); + sample.out[out_offset] = ConvertSat(fpout); + } else { + sample.out[out_offset] = fill_values[c]; + } + } + } else { + // No padding, we just with the same offset (or mirrored x offset) + fast_div channels(kStaticChannels); + for (int64_t idx = threadIdx.x + start_x, base_x = threadIdx.x; idx < end_x; + idx += blockDim.x, base_x += blockDim.x) { + int c = idx % channels; + + int64_t out_offset = calculate_offset_hwc(idx, c, sample); + + Compute fpin = tile[base_x]; + Compute fpout = fma(fpin, norm_mul[c], norm_add[c]); + sample.out[out_offset] = ConvertSat(fpout); + } + } +} + + +/** @} */ // end of Hwc2HwcChwStore + +/** @defgroup Hwc2HwcChw The Slice Hwc2{Hwc,Chw} Normalize Mirror-x Pad-channel kernel + * + * Kernel that reads a HWC u8 image and outputs a HWC or CHW normalized float image, that can be + * cropped in Y, X coordinates, mirrored in X coordinate, and the channels can be padded. + * + * High level structure of the kernel: + * 1. Load tile of linear data from the image into shared memory, doing a cast to floating type. + * a. Note, that the tile in shared memory can be represented either as an linear chunk with + * interleaved channels or as separate channel planes. See the loading functions for details. + * b. Each thread in loader loop maps to one value of the loaded image. + * c. Tile in shared memory doesn't take the padded channels into account, it stores only the + * input channels. + * 2. Synchronize + * 3. Output the data in correct layout, reading from the shared memory. + * a. 
For CHW output each thread corresponds to a (Y, X) sequential offset into a plane, computes + * the values for all the channels and writes them. Assuming 3-channel input, we can look + * at the input as a sequential stream of values, where we distribute them (sequentially) + * into 3 output planes. + * b. Padding the output channels for CHW is done by filling additional planes with fill values. + * c. For HWC output, in the simples case we can store the linear tile in the same order + * as it was read. In case of padding, fill values must be inserted. + * d. Mirroring is done by swapping the X-coordinate and recomputing the target offset for both + * layouts. + * + * The kernel use a thread block size, that is divisible both by channel number: 3 (for the + * non-padded output loop), and 4 (alignment for input loop and padded output loop). + * + * For better throughput, the read and write accesses to global memory are sequential, + * using aligned 4-byte-wide access when possible. + * @{ + */ + +// TODO(klecki): generalize for wider input types + +/** + * @brief Hwc2HwcChw Normalize Mirror-x Pad-channel kernel + * This kernel does not support cropping the x coordinate, so the reads are fully linear. + */ +template +__global__ void Hwc2HwcChwNormalize(const Hwc2HwcChwSampleDesc *samples, + uint32_t *first_blocks, uint32_t num_samples) { + static_assert(std::is_same::value, "Only uint8_t supported as input"); + + int sample_idx = FindSampleIdx(first_blocks, num_samples); + const auto sample = samples[sample_idx]; + + __shared__ float tile[kBlockSize + 32 * 4]; + + float *loaded_tile = load_linear_tile(tile, sample); + + __syncthreads(); + + store_chw(loaded_tile, sample); +} + +/** + * @brief Slice Hwc2HwcChw Normalize [Mirror-x] [Pad-channel] kernel + * This kernel supports cropping in x-coordinate. + */ +template +__global__ void SliceHwc2HwcChwNormalize(const Hwc2HwcChwSampleDesc *samples, + uint32_t *first_blocks, uint32_t num_samples) { + static_assert(std::is_same::value, "Only uint8_t supported as input"); + + int sample_idx = FindSampleIdx(first_blocks, num_samples); + const auto sample = samples[sample_idx]; + + __shared__ float tile[kBlockSize + 32 * 4]; + float *loaded_tile = slice_load_linear_tile(tile, sample); + + __syncthreads(); + + store_chw(loaded_tile, sample); +} + +/** + * @brief Hwc2Hwc Normalize [Mirror-x] [Pad-channel] kernel + * This kernel does not support cropping the x coordinate, so the reads are fully linear. + */ +template +__global__ void Hwc2HwcNormalize(const Hwc2HwcChwSampleDesc *samples, + uint32_t *first_blocks, uint32_t num_samples) { + static_assert(std::is_same::value, "Only uint8_t supported as input"); + + int sample_idx = FindSampleIdx(first_blocks, num_samples); + const auto sample = samples[sample_idx]; + + __shared__ float tile[kBlockSize + 32 * 4]; + float *loaded_tile = load_linear_tile(tile, sample); + + __syncthreads(); + + store_hwc(loaded_tile, sample); +} + +/** + * @brief Slice Hwc2Hwc Normalize [Mirror-x] [Pad-channel] kernel + * This kernel supports cropping in x-coordinate. 
+ */ +template +__global__ void SliceHwc2HwcNormalize(const Hwc2HwcChwSampleDesc *samples, + uint32_t *first_blocks, uint32_t num_samples) { + static_assert(std::is_same::value, "Only uint8_t supported as input"); + + int sample_idx = FindSampleIdx(first_blocks, num_samples); + const auto sample = samples[sample_idx]; + + __shared__ float tile[kBlockSize + 32 * 4]; + float *loaded_tile = slice_load_linear_tile(tile, sample); + + __syncthreads(); + + store_hwc(loaded_tile, sample); +} + +/** @} */ // end of Hwc2HwcChw template -KernelRequirements SliceHwc2ChwNormalizeGPU::Setup(KernelContext &ctx, - const TensorListShape &input_shape, - span args) { +KernelRequirements SliceHwc2HwcChwNormalizeGPU::Setup(KernelContext &ctx, + const TensorListShape &input_shape, + span args, + TensorLayout output_layout) { (void)ctx; int num_samples = input_shape.num_samples(); DALI_ENFORCE(num_samples == static_cast(args.size()), @@ -333,7 +537,12 @@ KernelRequirements SliceHwc2ChwNormalizeGPU::Setup(KernelContext &ctx, out_shape_ = TensorListShape(num_samples, ndim); collapsed_tiling_shape_ = TensorListShape<1>(num_samples, 1); + perm_ = output_layout == "HWC" ? std::array{0, 1, 2} : std::array{2, 0, 1}; + output_layout_ = output_layout; + SetupNumChannels(input_shape, args); + DALI_ENFORCE(output_layout == "HWC" || output_layout == "CHW", + "Only CHW and HWC output layouts allowed"); for (int i = 0; i < num_samples; i++) { // N.B. this function produces a HWC shape, that's why we need the permute @@ -353,7 +562,7 @@ KernelRequirements SliceHwc2ChwNormalizeGPU::Setup(KernelContext &ctx, } template -std::tuple SliceHwc2ChwNormalizeGPU::SetupParams( +std::tuple SliceHwc2HwcChwNormalizeGPU::SetupParams( KernelContext &ctx, span args) { int num_samples = args.size(); float *norm_add_cpu = ctx.scratchpad->AllocatePinned(num_samples * nchannels_); @@ -390,8 +599,8 @@ std::tuple SliceHwc2ChwNormalizeGPU::SetupParams( } template -auto SliceHwc2ChwNormalizeGPU::RealignSample(TensorView in_sample, - Roi roi) +auto SliceHwc2HwcChwNormalizeGPU::RealignSample( + TensorView in_sample, Roi roi) -> std::tuple, Roi> { const auto *data = in_sample.data; auto shape = in_sample.shape; @@ -407,8 +616,8 @@ auto SliceHwc2ChwNormalizeGPU::RealignSample(TensorView -void SliceHwc2ChwNormalizeGPU::SetupNumChannels(const TensorListShape &input_shape, - span args) { +void SliceHwc2HwcChwNormalizeGPU::SetupNumChannels(const TensorListShape &input_shape, + span args) { if (input_shape.num_samples() == 0) { return; } @@ -422,22 +631,29 @@ void SliceHwc2ChwNormalizeGPU::SetupNumChannels(const TensorListShape } DALI_ENFORCE( input_shape.num_samples() == static_cast(args.size()), - "Number of samples in the arguments should match the number of samples in the shape"); + "Number of samples in the arguments should match the number of samples in the shape."); out_nchannels_ = std::max(nchannels_, static_cast(args[0].fill_values.size())); for (int i = 1; i < input_shape.num_samples(); i++) { DALI_ENFORCE(args[i].fill_values.size() == args[0].fill_values.size(), - "All sample arguments should have the same number of fill values"); + "All sample arguments should have the same number of fill values."); + } + DALI_ENFORCE(nchannels_ == kStaticChannels, "Only 3 input channels are supported."); + if (output_layout_ == "HWC") { + // Padding in the operator cannot go higher than the closest power of 2, + // but better have the check in place. 
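For reference, the layout-dependent shape permutation selected in Setup() above ({0, 1, 2} for HWC, {2, 0, 1} for CHW) can be sketched on plain std::array standing in for DALI's shape types; illustrative only.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <string>

std::array<int64_t, 3> permute_shape(const std::array<int64_t, 3> &hwc_shape,
                                     const std::string &output_layout) {
  // HWC -> HWC is the identity, HWC -> CHW moves the channel extent to the front.
  std::array<int, 3> perm = output_layout == "HWC" ? std::array<int, 3>{0, 1, 2}
                                                   : std::array<int, 3>{2, 0, 1};
  std::array<int64_t, 3> out;
  for (int d = 0; d < 3; d++)
    out[d] = hwc_shape[perm[d]];
  return out;
}

int main() {
  std::array<int64_t, 3> hwc = {720, 1280, 4};  // H, W, padded C
  assert((permute_shape(hwc, "HWC") == std::array<int64_t, 3>{720, 1280, 4}));
  assert((permute_shape(hwc, "CHW") == std::array<int64_t, 3>{4, 720, 1280}));
}
```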
+ DALI_ENFORCE(out_nchannels_ == kStaticChannels || out_nchannels_ == kStaticChannels + 1, + "Only 3 or 4 output channels are supported for HWC output layout."); } } template -void SliceHwc2ChwNormalizeGPU::Run(KernelContext &ctx, - const TensorListView &out, - const TensorListView &in, - span args) { - using SampleDesc = Hwc2ChwSampleDesc; +void SliceHwc2HwcChwNormalizeGPU::Run(KernelContext &ctx, + const TensorListView &out, + const TensorListView &in, + span args) { + using SampleDesc = Hwc2HwcChwSampleDesc; int num_samples = in.num_samples(); SampleDesc *sample_descs_cpu = ctx.scratchpad->AllocatePinned(num_samples); @@ -473,9 +689,15 @@ void SliceHwc2ChwNormalizeGPU::Run(KernelContext &ctx, offset_blk += div_ceil(sample_size, kBlockSizeMul * kBlockWidth); // The output shape here is after the permutation - sample_desc.H = out.tensor_shape(sample_id)[1]; - sample_desc.W = out.tensor_shape(sample_id)[2]; - sample_desc.C = out.tensor_shape(sample_id)[0]; // out_nchannels_ + if (output_layout_ == "CHW") { + sample_desc.H = out.tensor_shape(sample_id)[1]; + sample_desc.W = out.tensor_shape(sample_id)[2]; + sample_desc.C = out.tensor_shape(sample_id)[0]; // out_nchannels_ + } else { + sample_desc.H = out.tensor_shape(sample_id)[0]; + sample_desc.W = out.tensor_shape(sample_id)[1]; + sample_desc.C = out.tensor_shape(sample_id)[2]; // out_nchannels_ + } sample_desc.input_W = in_sample.shape[1]; sample_desc.input_C = in_sample.shape[2]; // nchannels_ @@ -495,39 +717,72 @@ void SliceHwc2ChwNormalizeGPU::Run(KernelContext &ctx, ctx.scratchpad->ToContiguousGPU(ctx.gpu.stream, make_span(sample_descs_cpu, nonempty_samples), make_span(first_blocks_cpu, nonempty_samples)); - auto dispatch = [samples = sample_descs_gpu, blocks = first_blocks_gpu, &ctx, need_crop_x, - offset_blk, nonempty_samples](auto pad_v, auto flip_x_v) { - if (need_crop_x) { - SliceHwc2ChwNormalize - <<>>(samples, blocks, nonempty_samples); + // TODO(klecki): Maybe this selection can be simplified, but making the output layout + // a parameter would probably make it even less readable. + // This version allows utilizing specialized implementations for every layout more easily. 
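The dispatch code that follows maps the runtime booleans (need_pad, need_flip_x) onto std::true_type / std::false_type, so padding and mirroring become compile-time template parameters of the launched kernels. A minimal stand-alone sketch of the pattern, with illustrative names:

```cpp
#include <cstdio>
#include <type_traits>

template <bool enable_pad, bool enable_mirror>
void launch_kernel() {
  // In the patch this is where the templated __global__ kernel is launched.
  std::printf("pad=%d mirror=%d\n", enable_pad, enable_mirror);
}

int main() {
  bool need_pad = true, need_flip_x = false;

  // Each std::true_type / std::false_type argument selects a different specialization,
  // so the per-feature branches are resolved at compile time inside the kernel.
  auto dispatch = [](auto pad_v, auto flip_x_v) {
    launch_kernel<decltype(pad_v)::value, decltype(flip_x_v)::value>();
  };
  auto dispatch_flip = [&](auto pad_v, bool flip_x) {
    if (flip_x)
      dispatch(pad_v, std::true_type{});
    else
      dispatch(pad_v, std::false_type{});
  };

  if (need_pad)
    dispatch_flip(std::true_type{}, need_flip_x);
  else
    dispatch_flip(std::false_type{}, need_flip_x);
}
```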
+ if (output_layout_ == "CHW") { + auto dispatch = [samples = sample_descs_gpu, blocks = first_blocks_gpu, &ctx, need_crop_x, + offset_blk, nonempty_samples](auto pad_v, auto flip_x_v) { + if (need_crop_x) { + SliceHwc2HwcChwNormalize + <<>>(samples, blocks, + nonempty_samples); + } else { + Hwc2HwcChwNormalize<<>>( + samples, blocks, nonempty_samples); + } + }; + + auto dispatch_flip = [&](auto pad_v, bool flip_x) { + if (flip_x) { + dispatch(pad_v, std::true_type{}); + } else { + dispatch(pad_v, std::false_type{}); + } + }; + + if (need_pad) { + dispatch_flip(std::true_type{}, need_flip_x); } else { - Hwc2ChwNormalize - <<>>(samples, blocks, nonempty_samples); + dispatch_flip(std::false_type{}, need_flip_x); } - }; + } else { + auto dispatch = [samples = sample_descs_gpu, blocks = first_blocks_gpu, &ctx, need_crop_x, + offset_blk, nonempty_samples](auto pad_v, auto flip_x_v) { + if (need_crop_x) { + SliceHwc2HwcNormalize<<>>( + samples, blocks, nonempty_samples); + } else { + Hwc2HwcNormalize<<>>( + samples, blocks, nonempty_samples); + } + }; + + auto dispatch_flip = [&](auto pad_v, bool flip_x) { + if (flip_x) { + dispatch(pad_v, std::true_type{}); + } else { + dispatch(pad_v, std::false_type{}); + } + }; - auto dispatch_flip = [&](auto pad_v, bool flip_x) { - if (flip_x) { - dispatch(pad_v, std::true_type{}); + if (need_pad) { + dispatch_flip(std::true_type{}, need_flip_x); } else { - dispatch(pad_v, std::false_type{}); + dispatch_flip(std::false_type{}, need_flip_x); } - }; - - if (need_pad) { - dispatch_flip(std::true_type{}, need_flip_x); - } else { - dispatch_flip(std::false_type{}, need_flip_x); } CUDA_CALL(cudaGetLastError()); } -template class DLL_PUBLIC SliceHwc2ChwNormalizeGPU; -template class DLL_PUBLIC SliceHwc2ChwNormalizeGPU; +template class DLL_PUBLIC SliceHwc2HwcChwNormalizeGPU; +template class DLL_PUBLIC SliceHwc2HwcChwNormalizeGPU; } // namespace slice_flip_normalize diff --git a/dali/kernels/slice/slice_hwc2chw_normalize_gpu.h b/dali/kernels/slice/slice_hwc2chw_normalize_gpu.h index df5f495e1f8..8e4f0328a71 100644 --- a/dali/kernels/slice/slice_hwc2chw_normalize_gpu.h +++ b/dali/kernels/slice/slice_hwc2chw_normalize_gpu.h @@ -21,6 +21,7 @@ #include "dali/core/common.h" #include "dali/core/geom/vec.h" #include "dali/core/small_vector.h" +#include "dali/core/tensor_layout.h" #include "dali/core/tensor_shape.h" #include "dali/kernels/common/block_setup.h" #include "dali/kernels/imgproc/roi.h" @@ -32,19 +33,40 @@ namespace kernels { namespace slice_flip_normalize { /** - * @brief Specialized version of SliceFlipNormalize for HWC->CHW conversion and normalization. + * @brief Specialized version of SliceFlipNormalize that reads a HWC u8 image (with 3 channels) + * and outputs a HWC or CHW normalized float image, that can be cropped in Y, X coordinates, + * mirrored in X coordinate, and the channels can be padded. * - * Optionally allows for cropping the input in y, x (HW) coordinates, flipping in x (W) coordinate - * and padding the channels to the multiple of 2. + * Cropping the input in y, x (HW) coordinates, flipping in x (W) coordinate + * and padding the channels (from 3 to 4 in HWC->HWC variant) are optional, optimized implementation + * will be selected when those features are not used across the batch. * - * The input is assumed to be u8. + * Overview of the kernel: + * The image is processed in flattened coordinates. The Y, X stays the same between the interleaved + * input layout and planar output layout. 
Assuming 3-channel input, we can look at the input as + * a sequential stream of values, where we distribute them (sequentially) into 3 output planes. + * Use a thread block size, that is divisible both by channel number (for the output loop), + * and 4 (for input loop). + * The processing steps: + * 1. [Input loop] Load the linear chunk of input into shared memory, utilizing 4-byte aligned loads + * and cast it to float. + * a. Unaligned prologue loop - reads the first chunk till we get to address that is aligned with + * 32 * 4. + * b. Main loop - do as many aligned 4byte reads as possible + * c. Epilogue loop - read the remaining values that were not possible to read as one 4byte read. + * 2. Synchronize + * 3. [Output loop] Each thread corresponds to a (Y, X) sequential offset into a plane, computes + * the values for all the channels and writes them. + * a. Optionally, mirroring is performed by inverting the X-coordinate in the output offset. + * b. Padding the output channels is performed by filling additional planes with fill values. * - * @tparam Out output type + * + * @tparam Out output type - fp16 and fp32 allowed. * * @details see SliceFlipNormalizeGPU::Args */ template -class DLL_PUBLIC SliceHwc2ChwNormalizeGPU { +class DLL_PUBLIC SliceHwc2HwcChwNormalizeGPU { public: static constexpr int spatial_dim = 2; static constexpr int channel_dim = 2; @@ -60,13 +82,13 @@ class DLL_PUBLIC SliceHwc2ChwNormalizeGPU { bool flip_x; // wether to mirror in x-axis, EnableFlipX must be true. }; - SliceHwc2ChwNormalizeGPU() = default; + SliceHwc2HwcChwNormalizeGPU() = default; - ~SliceHwc2ChwNormalizeGPU() = default; + ~SliceHwc2HwcChwNormalizeGPU() = default; DLL_PUBLIC KernelRequirements Setup(KernelContext &ctx, const TensorListShape &input_shape, - span args); + span args, TensorLayout output_layout); void Run(KernelContext &ctx, const TensorListView &out, const TensorListView &in, span args); @@ -115,8 +137,9 @@ class DLL_PUBLIC SliceHwc2ChwNormalizeGPU { int nchannels_ = -1; // number of channels in the output image (in case of padding) int out_nchannels_ = -1; - // HWC -> CHW permutation - static constexpr std::array perm_ = {2, 0, 1}; + // HWC -> {CHW, HWC} permutation + std::array perm_; + TensorLayout output_layout_; }; } // namespace slice_flip_normalize diff --git a/dali/operators/image/crop/new_crop_mirror_normalize.cu b/dali/operators/image/crop/new_crop_mirror_normalize.cu index b89e91d9e9c..27296da65ed 100644 --- a/dali/operators/image/crop/new_crop_mirror_normalize.cu +++ b/dali/operators/image/crop/new_crop_mirror_normalize.cu @@ -71,7 +71,7 @@ class NewCropMirrorNormalizeGPU : public Operator { protected: enum class CmnImplKind { SliceFlipNormalizeGpuGeneric, - SliceHwc2ChwNormalize, + SliceHwc2HwcChwNormalize, FallbackGeneric }; @@ -88,24 +88,24 @@ class NewCropMirrorNormalizeGPU : public Operator { } // check for optimized version if (in_type == DALI_UINT8 && (out_type == DALI_FLOAT || out_type == DALI_FLOAT16) && - in_layout == "HWC" && out_layout == "CHW" && + in_layout == "HWC" && (out_layout == "CHW" || out_layout == "HWC") && (oobp == OutOfBoundsPolicy::Error || oobp == OutOfBoundsPolicy::TrimToShape)) { - // Only 3-channels supported in this version if (in_shape.num_samples() > 0 && in_shape.tensor_shape_span(0)[2] == 3) - return CmnImplKind::SliceHwc2ChwNormalize; + return CmnImplKind::SliceHwc2HwcChwNormalize; } return CmnImplKind::SliceFlipNormalizeGpuGeneric; } - bool SetupSliceHwc2ChwNormalize(std::vector &output_desc, const Workspace &ws) { + bool 
SetupSliceHwc2HwcChwNormalize(std::vector &output_desc, const Workspace &ws) { TYPE_SWITCH(output_type_, type2id, OutputType, (float, float16), ( - return SetupSliceHwc2ChwNormalizeTyped(output_desc, ws); + return SetupSliceHwc2HwcChwNormalizeTyped(output_desc, ws); ), DALI_FAIL(make_string("Unsupported output type:", output_type_));); // NOLINT } template - bool SetupSliceHwc2ChwNormalizeTyped(std::vector &output_desc, const Workspace &ws) { - using Kernel = kernels::slice_flip_normalize::SliceHwc2ChwNormalizeGPU; + bool SetupSliceHwc2HwcChwNormalizeTyped(std::vector &output_desc, + const Workspace &ws) { + using Kernel = kernels::slice_flip_normalize::SliceHwc2HwcChwNormalizeGPU; if (!kernel_args_.has_value()) kernel_args_ = std::vector{}; auto &args = any_cast &>(kernel_args_); @@ -141,7 +141,7 @@ class NewCropMirrorNormalizeGPU : public Operator { // const auto &req = k.Setup(ctx, sh, cargs); // // k.test(); auto cargs = make_cspan(args); - auto &req = kmgr_.Setup(0, ctx, sh, cargs); + auto &req = kmgr_.Setup(0, ctx, sh, cargs, output_layout_); output_desc[0].type = output_type_; output_desc[0].shape = req.output_shapes[0]; return true; @@ -259,19 +259,19 @@ class NewCropMirrorNormalizeGPU : public Operator { return SetupSfnGpuGeneric(output_desc, ws); } - return SetupSliceHwc2ChwNormalize(output_desc, ws); + return SetupSliceHwc2HwcChwNormalize(output_desc, ws); } - void RunSliceHwc2ChwNormalize(Workspace &ws) { + void RunSliceHwc2HwcChwNormalize(Workspace &ws) { if (output_type_ == DALI_FLOAT) { - using Kernel = kernels::slice_flip_normalize::SliceHwc2ChwNormalizeGPU; + using Kernel = kernels::slice_flip_normalize::SliceHwc2HwcChwNormalizeGPU; auto &args = any_cast &>(kernel_args_); auto cargs = make_cspan(args); RunSfnKernel(ws, cargs); return; } else if (output_type_ == DALI_FLOAT16) { - using Kernel = kernels::slice_flip_normalize::SliceHwc2ChwNormalizeGPU; + using Kernel = kernels::slice_flip_normalize::SliceHwc2HwcChwNormalizeGPU; auto &args = any_cast &>(kernel_args_); auto cargs = make_cspan(args); @@ -322,7 +322,7 @@ class NewCropMirrorNormalizeGPU : public Operator { return; } - RunSliceHwc2ChwNormalize(ws); + RunSliceHwc2HwcChwNormalize(ws); } bool CanInferOutputs() const override { diff --git a/dali/test/python/operator_1/test_crop_mirror_normalize.py b/dali/test/python/operator_1/test_crop_mirror_normalize.py index 23b7f3d9eba..5fd5731bcf5 100644 --- a/dali/test/python/operator_1/test_crop_mirror_normalize.py +++ b/dali/test/python/operator_1/test_crop_mirror_normalize.py @@ -684,10 +684,11 @@ def test_crop_mirror_normalize_empty_layout(): pads = [False, True] mirrors = [False, True] crops = [(1.0, 0.25), (0.25, 0.25), (0.25, 1.0), (0.5, 0.75), (None, None)] +layouts = ["HWC", "CHW"] -@params(*itertools.product(batch_sizes, shapes, dtypes, pads, mirrors, crops)) -def test_cmn_optimized_vs_cpu(batch_size, shape, dtype, pad, mirror, crops): +@params(*itertools.product(batch_sizes, shapes, dtypes, pads, mirrors, crops, layouts)) +def test_cmn_optimized_vs_cpu(batch_size, shape, dtype, pad, mirror, crops, layout): @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) def pipe(device): @@ -706,7 +707,8 @@ def get_data(): return fn.crop_mirror_normalize(data, device=device, dtype=dtype, pad_output=pad, mirror=mirror, crop_h=crop_h_int, crop_w=crop_w_int, mean=[0.1, 0.2, 0.3], - fill_values=[0.0, 0.0, 0.0, 42.0] if pad else None) + fill_values=[0.0, 0.0, 0.0, 42.0] if pad else None, + output_layout=layout) pipe_baseline = pipe("cpu") pipe_opt = pipe("gpu")
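A few supplementary notes on the store path earlier in the patch. First, the mirrored-offset math from calculate_offset_chw / calculate_offset_hwc, reproduced as a host-side sketch (plain division instead of the specialized fast paths) with a small worked check; names are illustrative.

```cpp
#include <cassert>
#include <cstdint>

int64_t mirror_offset_chw(int64_t planar_idx, int W) {
  int64_t y = planar_idx / W;
  int64_t x = planar_idx - y * W;
  return y * W + (W - 1 - x);  // same row, X inverted
}

int64_t mirror_offset_hwc(int64_t flat_idx, int c, int W, int num_channels) {
  int64_t stride = int64_t{W} * num_channels;
  int64_t y = flat_idx / stride;
  int64_t x = (flat_idx - y * stride) / num_channels;
  return y * stride + (W - 1 - x) * num_channels + c;  // same row and channel, X inverted
}

int main() {
  // 4x4 planar image: element (y=1, x=0) mirrors to (y=1, x=3).
  assert(mirror_offset_chw(4, /*W=*/4) == 7);
  // Interleaved 3-channel image, W=4: value at (y=0, x=1, c=2) mirrors to (y=0, x=2, c=2).
  assert(mirror_offset_hwc(5, /*c=*/2, /*W=*/4, /*num_channels=*/3) == 8);
}
```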
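Second, the index mapping in the padded HWC store (store_hwc with enable_pad): the shared-memory tile holds only the 3 input channels, while the output chunk is addressed with 4 interleaved channels, so c = idx & 3 and pixel = base_x >> 2 recover the tile position. A sketch, with illustrative names:

```cpp
#include <cassert>

constexpr int kInChannels = 3;   // kStaticChannels in the patch
constexpr int kOutChannels = 4;  // padded output channels

// Given a flat index into the padded output chunk handled by this block,
// return the tile (smem) index to read from, or -1 if the channel is padding.
int padded_out_to_tile_index(int base_x) {
  int c = base_x & 3;       // output channel, valid because kOutChannels == 4
  int pixel = base_x >> 2;  // pixel index within the block
  if (c >= kInChannels)
    return -1;              // padded channel: written from fill_values, not from the tile
  return pixel * kInChannels + c;
}

int main() {
  // First pixel: output channels 0..2 come from tile[0..2], channel 3 is the fill value.
  assert(padded_out_to_tile_index(0) == 0);
  assert(padded_out_to_tile_index(2) == 2);
  assert(padded_out_to_tile_index(3) == -1);
  // Second pixel: padded output indices 4..6 map to tile[3..5].
  assert(padded_out_to_tile_index(5) == 4);
}
```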
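Finally, why a single fma per value is enough in the store loops: if norm_mul and norm_add are prepared as 1/stddev and -mean/stddev respectively (the usual crop-mirror-normalize convention; the exact values come from SetupParams, which is not shown in full here, so treat this as an assumption), then fma(in, norm_mul, norm_add) equals (in - mean)/stddev. A tiny host check of that identity:

```cpp
#include <cassert>
#include <cmath>

int main() {
  float in = 200.0f, mean = 128.0f, stddev = 64.0f;
  float norm_mul = 1.0f / stddev;
  float norm_add = -mean / stddev;
  float a = std::fma(in, norm_mul, norm_add);  // one fused multiply-add per value
  float b = (in - mean) / stddev;              // the normalization it implements
  assert(std::fabs(a - b) < 1e-6f);
}
```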