Compute and use the initial string offset when building nested large string cols with chunked parquet reader (#17702)

Closes #17692.

This PR computes the `str_offset` needed to correctly build the offsets columns of nested large string columns in the chunked Parquet reader when a small `chunk_read_limit` results in multiple output table chunks per subpass.
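
For background, the fix amounts to seeding the sizes-to-offsets exclusive scan with the number of string bytes already emitted by earlier output chunks of the same subpass, instead of always starting at zero. A minimal, hypothetical sketch of that idea in plain Thrust (not the reader's actual code path):

// Build offsets for the *second* output chunk of a subpass. The scan is seeded
// with the bytes already written by the first chunk; seeding with 0 would make
// this chunk's offsets incorrectly restart at the beginning of the buffer.
#include <thrust/device_vector.h>
#include <thrust/scan.h>

#include <cstdint>
#include <vector>

int main()
{
  // Per-row string sizes for this chunk, plus a trailing 0 so the scan's last
  // element is the running total (the "count + 1" pattern used in libcudf).
  std::vector<int64_t> const h_sizes{5, 3, 7, 2, 0};
  thrust::device_vector<int64_t> sizes(h_sizes.begin(), h_sizes.end());
  thrust::device_vector<int64_t> offsets(sizes.size());

  int64_t const initial_offset = 1'000;  // bytes already written by the previous chunk

  thrust::exclusive_scan(sizes.begin(), sizes.end(), offsets.begin(), initial_offset);
  // offsets == {1000, 1005, 1008, 1015, 1017}; the last value seeds the next chunk.
  return 0;
}

Passing an initial offset of 0 reproduces the old behavior, which is what the existing non-chunked callers in the diff below continue to do.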

Authors:
  - Muhammad Haseeb (https://github.com/mhaseeb123)

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Ed Seidl (https://github.com/etseidl)
  - Vukasin Milovanovic (https://github.com/vuule)

URL: #17702
mhaseeb123 authored Jan 28, 2025
1 parent 03e1f64 commit e0fe51d
Showing 13 changed files with 232 additions and 53 deletions.
5 changes: 5 additions & 0 deletions cpp/benchmarks/io/parquet/parquet_reader_input.cpp
@@ -121,6 +121,10 @@ void BM_parquet_read_long_strings(nvbench::state& state)
cycle_dtypes(d_type, num_cols), table_size_bytes{data_size}, profile); // THIS
auto const view = tbl->view();

// set smaller threshold to reduce file size and execution time
auto const threshold = 1;
setenv("LIBCUDF_LARGE_STRINGS_THRESHOLD", std::to_string(threshold).c_str(), 1);

cudf::io::parquet_writer_options write_opts =
cudf::io::parquet_writer_options::builder(source_sink.make_sink_info(), view)
.compression(compression);
@@ -129,6 +133,7 @@ void BM_parquet_read_long_strings(nvbench::state& state)
}();

parquet_read_common(num_rows_written, num_cols, source_sink, state);
unsetenv("LIBCUDF_LARGE_STRINGS_THRESHOLD");
}

template <data_type DataType>
10 changes: 7 additions & 3 deletions cpp/include/cudf/detail/sizes_to_offsets_iterator.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -255,12 +255,14 @@ static sizes_to_offsets_iterator<ScanIterator, LastType> make_sizes_to_offsets_i
* @param begin Input iterator for scan
* @param end End of the input iterator
* @param result Output iterator for scan result
* @param initial_offset Initial offset to add to scan
* @return The last element of the scan
*/
template <typename SizesIterator, typename OffsetsIterator>
auto sizes_to_offsets(SizesIterator begin,
SizesIterator end,
OffsetsIterator result,
int64_t initial_offset,
rmm::cuda_stream_view stream)
{
using SizeType = typename thrust::iterator_traits<SizesIterator>::value_type;
@@ -273,7 +275,8 @@ auto sizes_to_offsets(SizesIterator begin,
make_sizes_to_offsets_iterator(result, result + std::distance(begin, end), last_element.data());
// This function uses the type of the initialization parameter as the accumulator type
// when computing the individual scan output elements.
thrust::exclusive_scan(rmm::exec_policy(stream), begin, end, output_itr, LastType{0});
thrust::exclusive_scan(
rmm::exec_policy_nosync(stream), begin, end, output_itr, static_cast<LastType>(initial_offset));
return last_element.value(stream);
}

@@ -319,7 +322,8 @@ std::pair<std::unique_ptr<column>, size_type> make_offsets_child_column(
});
auto input_itr = cudf::detail::make_counting_transform_iterator(0, map_fn);
// Use the sizes-to-offsets iterator to compute the total number of elements
auto const total_elements = sizes_to_offsets(input_itr, input_itr + count + 1, d_offsets, stream);
auto const total_elements =
sizes_to_offsets(input_itr, input_itr + count + 1, d_offsets, 0, stream);
CUDF_EXPECTS(
total_elements <= static_cast<decltype(total_elements)>(std::numeric_limits<size_type>::max()),
"Size of output exceeds the column size limit",
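
A hedged usage sketch of the updated `sizes_to_offsets` signature above; the row count, string sizes, and initial offset are placeholders invented for illustration, not values from the PR:

// Hypothetical caller of the detail API shown above (requires nvcc and libcudf's
// detail headers). Every string in this chunk is pretended to be 4 bytes long.
#include <cudf/detail/sizes_to_offsets_iterator.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/iterator/constant_iterator.h>

#include <cstdint>

int64_t offsets_for_later_chunk(rmm::cuda_stream_view stream)
{
  auto const num_rows = 100;
  auto sizes          = thrust::make_constant_iterator<int64_t>(4);
  rmm::device_uvector<int64_t> offsets(num_rows + 1, stream);

  // A previous output chunk of this subpass already wrote 1'000 bytes, so the
  // exclusive scan is seeded with that running offset instead of 0.
  auto const last_offset = cudf::detail::sizes_to_offsets(
    sizes, sizes + num_rows + 1, offsets.begin(), /*initial_offset=*/1'000, stream);
  return last_offset;  // 1'000 + 4 * num_rows; would seed the scan for the next chunk
}
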
5 changes: 3 additions & 2 deletions cpp/include/cudf/strings/detail/strings_children.cuh
@@ -152,7 +152,7 @@ std::pair<std::unique_ptr<column>, int64_t> make_offsets_child_column(
cudf::detail::make_counting_transform_iterator(0, string_offsets_fn{begin, strings_count});
// Use the sizes-to-offsets iterator to compute the total number of elements
auto const total_bytes =
cudf::detail::sizes_to_offsets(input_itr, input_itr + strings_count + 1, d_offsets, stream);
cudf::detail::sizes_to_offsets(input_itr, input_itr + strings_count + 1, d_offsets, 0, stream);

auto const threshold = cudf::strings::get_offset64_threshold();
CUDF_EXPECTS(cudf::strings::is_large_strings_enabled() || (total_bytes < threshold),
@@ -163,7 +163,8 @@ std::pair<std::unique_ptr<column>, int64_t> make_offsets_child_column(
offsets_column = make_numeric_column(
data_type{type_id::INT64}, strings_count + 1, mask_state::UNALLOCATED, stream, mr);
auto d_offsets64 = offsets_column->mutable_view().template data<int64_t>();
cudf::detail::sizes_to_offsets(input_itr, input_itr + strings_count + 1, d_offsets64, stream);
cudf::detail::sizes_to_offsets(
input_itr, input_itr + strings_count + 1, d_offsets64, 0, stream);
}

return std::pair(std::move(offsets_column), total_bytes);
18 changes: 12 additions & 6 deletions cpp/src/io/parquet/decode_fixed.cu
@@ -942,6 +942,7 @@ constexpr bool is_split_decode()
* @param chunks List of column chunks
* @param min_row Row index to start reading at
* @param num_rows Maximum number of rows to read
* @param initial_str_offsets Vector to store the initial offsets for large nested string cols
* @param error_code Error code to set if an error is encountered
*/
template <typename level_t, int decode_block_size_t, decode_kernel_mask kernel_mask_t>
@@ -950,6 +951,7 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t, 8)
device_span<ColumnChunkDesc const> chunks,
size_t min_row,
size_t num_rows,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code)
{
constexpr bool has_dict_t = has_dict<kernel_mask_t>();
@@ -1161,11 +1163,14 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t, 8)
valid_count = next_valid_count;
}

// Now turn the array of lengths into offsets, but skip if this is a large string column. In the
// latter case, offsets will be computed during string column creation.
if constexpr (has_strings_t) {
if (!s->col.is_large_string_col) {
convert_small_string_lengths_to_offsets<decode_block_size_t, has_lists_t>(s);
// For large strings, update the initial string buffer offset to be used during large string
// column construction. Otherwise, convert string sizes to final offsets.
if (s->col.is_large_string_col) {
compute_initial_large_strings_offset(
s, initial_str_offsets[pages[page_idx].chunk_idx], has_lists_t);
} else {
convert_small_string_lengths_to_offsets<decode_block_size_t>(s, has_lists_t);
}
}
if (t == 0 and s->error != 0) { set_error(s->error, error_code); }
@@ -1185,6 +1190,7 @@ void __host__ DecodePageData(cudf::detail::hostdevice_span<PageInfo> pages,
size_t min_row,
int level_type_size,
decode_kernel_mask kernel_mask,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code,
rmm::cuda_stream_view stream)
{
@@ -1199,11 +1205,11 @@ void __host__ DecodePageData(cudf::detail::hostdevice_span<PageInfo> pages,
if (level_type_size == 1) {
gpuDecodePageDataGeneric<uint8_t, decode_block_size, mask>
<<<dim_grid, dim_block, 0, stream.value()>>>(
pages.device_ptr(), chunks, min_row, num_rows, error_code);
pages.device_ptr(), chunks, min_row, num_rows, initial_str_offsets, error_code);
} else {
gpuDecodePageDataGeneric<uint16_t, decode_block_size, mask>
<<<dim_grid, dim_block, 0, stream.value()>>>(
pages.device_ptr(), chunks, min_row, num_rows, error_code);
pages.device_ptr(), chunks, min_row, num_rows, initial_str_offsets, error_code);
}
};

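
The kernels above only consume `initial_str_offsets`; the buffer itself is owned by host-side reader code that is not part of this excerpt. A hypothetical sketch of how such a buffer could be prepared (the helper name and policy choices are made up; only the "one slot per chunk, initialized to the maximum value" idea follows from the atomic minimum used by compute_initial_large_strings_offset, defined in page_string_utils.cuh further down):

// Hypothetical host-side helper: one size_t per column chunk, filled with the
// maximum value so that the kernel's atomic fetch_min can only lower it to the
// smallest page string offset seen for that chunk.
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/fill.h>

#include <cstddef>
#include <limits>

rmm::device_uvector<std::size_t> make_initial_str_offsets(std::size_t num_chunks,
                                                          rmm::cuda_stream_view stream)
{
  rmm::device_uvector<std::size_t> offsets(num_chunks, stream);
  thrust::fill(rmm::exec_policy(stream),
               offsets.begin(),
               offsets.end(),
               std::numeric_limits<std::size_t>::max());
  return offsets;  // viewed as a cudf::device_span<size_t> when launching the kernels
}
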
50 changes: 23 additions & 27 deletions cpp/src/io/parquet/page_delta_decode.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -435,6 +435,7 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size)
device_span<ColumnChunkDesc const> chunks,
size_t min_row,
size_t num_rows,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code)
{
using cudf::detail::warp_size;
@@ -579,17 +580,13 @@
__syncthreads();
}

// Now turn the array of lengths into offsets, but skip if this is a large string column. In the
// latter case, offsets will be computed during string column creation.
if (not s->col.is_large_string_col) {
int value_count = nesting_info_base[leaf_level_index].value_count;

// if no repetition we haven't calculated start/end bounds and instead just skipped
// values until we reach first_row. account for that here.
if (!has_repetition) { value_count -= s->first_row; }

auto const offptr = reinterpret_cast<size_type*>(nesting_info_base[leaf_level_index].data_out);
block_excl_sum<decode_block_size>(offptr, value_count, s->page.str_offset);
// For large strings, update the initial string buffer offset to be used during large string
// column construction. Otherwise, convert string sizes to final offsets.
if (s->col.is_large_string_col) {
compute_initial_large_strings_offset(
s, initial_str_offsets[pages[page_idx].chunk_idx], has_repetition);
} else {
convert_small_string_lengths_to_offsets<decode_block_size>(s, has_repetition);
}

if (t == 0 and s->error != 0) { set_error(s->error, error_code); }
@@ -603,6 +600,7 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size)
device_span<ColumnChunkDesc const> chunks,
size_t min_row,
size_t num_rows,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code)
{
using cudf::detail::warp_size;
@@ -741,17 +739,13 @@
__syncthreads();
}

// Now turn the array of lengths into offsets, but skip if this is a large string column. In the
// latter case, offsets will be computed during string column creation.
if (not s->col.is_large_string_col) {
int value_count = nesting_info_base[leaf_level_index].value_count;

// if no repetition we haven't calculated start/end bounds and instead just skipped
// values until we reach first_row. account for that here.
if (!has_repetition) { value_count -= s->first_row; }

auto const offptr = reinterpret_cast<size_type*>(nesting_info_base[leaf_level_index].data_out);
block_excl_sum<decode_block_size>(offptr, value_count, s->page.str_offset);
// For large strings, update the initial string buffer offset to be used during large string
// column construction. Otherwise, convert string sizes to final offsets.
if (s->col.is_large_string_col) {
compute_initial_large_strings_offset(
s, initial_str_offsets[pages[page_idx].chunk_idx], has_repetition);
} else {
convert_small_string_lengths_to_offsets<decode_block_size>(s, has_repetition);
}

// finally, copy the string data into place
@@ -797,6 +791,7 @@ void DecodeDeltaByteArray(cudf::detail::hostdevice_span<PageInfo> pages,
size_t num_rows,
size_t min_row,
int level_type_size,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code,
rmm::cuda_stream_view stream)
{
@@ -807,10 +802,10 @@ void DecodeDeltaByteArray(cudf::detail::hostdevice_span<PageInfo> pages,

if (level_type_size == 1) {
gpuDecodeDeltaByteArray<uint8_t><<<dim_grid, dim_block, 0, stream.value()>>>(
pages.device_ptr(), chunks, min_row, num_rows, error_code);
pages.device_ptr(), chunks, min_row, num_rows, initial_str_offsets, error_code);
} else {
gpuDecodeDeltaByteArray<uint16_t><<<dim_grid, dim_block, 0, stream.value()>>>(
pages.device_ptr(), chunks, min_row, num_rows, error_code);
pages.device_ptr(), chunks, min_row, num_rows, initial_str_offsets, error_code);
}
}

@@ -822,6 +817,7 @@ void DecodeDeltaLengthByteArray(cudf::detail::hostdevice_span<PageInfo> pages,
size_t num_rows,
size_t min_row,
int level_type_size,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code,
rmm::cuda_stream_view stream)
{
@@ -832,10 +828,10 @@ void DecodeDeltaLengthByteArray(cudf::detail::hostdevice_span<PageInfo> pages,

if (level_type_size == 1) {
gpuDecodeDeltaLengthByteArray<uint8_t><<<dim_grid, dim_block, 0, stream.value()>>>(
pages.device_ptr(), chunks, min_row, num_rows, error_code);
pages.device_ptr(), chunks, min_row, num_rows, initial_str_offsets, error_code);
} else {
gpuDecodeDeltaLengthByteArray<uint16_t><<<dim_grid, dim_block, 0, stream.value()>>>(
pages.device_ptr(), chunks, min_row, num_rows, error_code);
pages.device_ptr(), chunks, min_row, num_rows, initial_str_offsets, error_code);
}
}

49 changes: 42 additions & 7 deletions cpp/src/io/parquet/page_string_utils.cuh
@@ -20,6 +20,8 @@

#include <cudf/strings/detail/gather.cuh>

#include <cuda/atomic>

namespace cudf::io::parquet::detail {

// stole this from cudf/strings/detail/gather.cuh. modified to run on a single string on one warp.
@@ -98,21 +100,54 @@ __device__ inline void block_excl_sum(size_type* arr, size_type length, size_typ
}
}

template <int block_size, bool has_lists>
__device__ inline void convert_small_string_lengths_to_offsets(page_state_s* s)
/**
* @brief Converts string sizes to offsets if this is not a large string column. Otherwise,
* atomically update the initial string offset to be used during large string column construction
*/
template <int block_size>
__device__ void convert_small_string_lengths_to_offsets(page_state_s const* const state,
bool has_lists)
{
// Offsets are computed here only for small string columns; for large string columns
// they are computed during string column creation instead.
auto& ni = s->nesting_info[s->col.max_nesting_depth - 1];
auto& ni = state->nesting_info[state->col.max_nesting_depth - 1];
int value_count = ni.value_count;

// if no repetition we haven't calculated start/end bounds and instead just skipped
// values until we reach first_row. account for that here.
if constexpr (!has_lists) { value_count -= s->first_row; }
if (not has_lists) { value_count -= state->first_row; }

// Convert the array of lengths into offsets
if (value_count > 0) {
auto const offptr = reinterpret_cast<size_type*>(ni.data_out);
auto const initial_value = state->page.str_offset;
block_excl_sum<block_size>(offptr, value_count, initial_value);
}
}

auto const offptr = reinterpret_cast<size_type*>(ni.data_out);
auto const initial_value = s->page.str_offset;
block_excl_sum<block_size>(offptr, value_count, initial_value);
/**
* @brief Atomically update the initial string offset to be used during large string column
* construction
*/
inline __device__ void compute_initial_large_strings_offset(page_state_s const* const state,
size_t& initial_str_offset,
bool has_lists)
{
// Values decoded by this page.
int value_count = state->nesting_info[state->col.max_nesting_depth - 1].value_count;

// if no repetition we haven't calculated start/end bounds and instead just skipped
// values until we reach first_row. account for that here.
if (not has_lists) { value_count -= state->first_row; }

// Atomically update the initial string offset if this is a large string column. This initial
// offset will be used to compute (64-bit) offsets during large string column construction.
if (value_count > 0 and threadIdx.x == 0) {
auto const initial_value = state->page.str_offset;
cuda::atomic_ref<size_t, cuda::std::thread_scope_device> initial_str_offsets_ref{
initial_str_offset};
initial_str_offsets_ref.fetch_min(initial_value, cuda::std::memory_order_relaxed);
}
}

template <int block_size>
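
The `fetch_min` above is a common libcu++ atomic pattern for letting many blocks race to publish the smallest candidate value. A standalone toy example of just that pattern (made-up names, not libcudf code):

// Two "pages" of the same chunk report their starting string offsets; the
// smaller offset wins, analogous to compute_initial_large_strings_offset above.
#include <cuda/atomic>
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdio>
#include <limits>

__global__ void record_min(std::size_t* chunk_initial_offset, std::size_t candidate)
{
  // Mirror the kernel above: a single thread per block publishes the candidate.
  if (threadIdx.x == 0) {
    cuda::atomic_ref<std::size_t, cuda::std::thread_scope_device> ref{*chunk_initial_offset};
    ref.fetch_min(candidate, cuda::std::memory_order_relaxed);
  }
}

int main()
{
  std::size_t* d_offset{};
  cudaMalloc(&d_offset, sizeof(std::size_t));
  auto const sentinel = std::numeric_limits<std::size_t>::max();  // "nothing recorded yet"
  cudaMemcpy(d_offset, &sentinel, sizeof(std::size_t), cudaMemcpyHostToDevice);

  record_min<<<1, 128>>>(d_offset, 4096);  // a later page starting at byte 4096
  record_min<<<1, 128>>>(d_offset, 1024);  // an earlier page starting at byte 1024

  std::size_t result{};
  cudaMemcpy(&result, d_offset, sizeof(std::size_t), cudaMemcpyDeviceToHost);
  std::printf("initial string offset for the chunk: %zu\n", result);  // prints 1024
  cudaFree(d_offset);
  return 0;
}
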
6 changes: 6 additions & 0 deletions cpp/src/io/parquet/parquet_gpu.hpp
@@ -876,6 +876,7 @@ void DecodeDeltaBinary(cudf::detail::hostdevice_span<PageInfo> pages,
* @param[in] num_rows Total number of rows to read
* @param[in] min_row Minimum number of rows to read
* @param[in] level_type_size Size in bytes of the type for level decoding
* @param[out] initial_str_offsets Vector to store the initial offsets for large nested string cols
* @param[out] error_code Error code for kernel failures
* @param[in] stream CUDA stream to use
*/
@@ -884,6 +885,7 @@ void DecodeDeltaByteArray(cudf::detail::hostdevice_span<PageInfo> pages,
size_t num_rows,
size_t min_row,
int level_type_size,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code,
rmm::cuda_stream_view stream);

@@ -898,6 +900,7 @@ void DecodeDeltaByteArray(cudf::detail::hostdevice_span<PageInfo> pages,
* @param[in] num_rows Total number of rows to read
* @param[in] min_row Minimum number of rows to read
* @param[in] level_type_size Size in bytes of the type for level decoding
* @param[out] initial_str_offsets Vector to store the initial offsets for large nested string cols
* @param[out] error_code Error code for kernel failures
* @param[in] stream CUDA stream to use
*/
@@ -906,6 +909,7 @@ void DecodeDeltaLengthByteArray(cudf::detail::hostdevice_span<PageInfo> pages,
size_t num_rows,
size_t min_row,
int level_type_size,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code,
rmm::cuda_stream_view stream);

@@ -921,6 +925,7 @@ void DecodeDeltaLengthByteArray(cudf::detail::hostdevice_span<PageInfo> pages,
* @param[in] min_row Minimum number of rows to read
* @param[in] level_type_size Size in bytes of the type for level decoding
* @param[in] kernel_mask Mask indicating the type of decoding kernel to launch.
* @param[out] initial_str_offsets Vector to store the initial offsets for large nested string cols
* @param[out] error_code Error code for kernel failures
* @param[in] stream CUDA stream to use
*/
@@ -930,6 +935,7 @@ void DecodePageData(cudf::detail::hostdevice_span<PageInfo> pages,
size_t min_row,
int level_type_size,
decode_kernel_mask kernel_mask,
cudf::device_span<size_t> initial_str_offsets,
kernel_error::pointer error_code,
rmm::cuda_stream_view stream);

(Diff for the remaining changed files not shown.)