Skip to content

Commit

Permalink
Precompute AST arity (#17234)
Browse files Browse the repository at this point in the history
This PR precomputes the arity of each AST operator on the host, reducing the complexity of the device-side arity lookup.

Authors:
  - Bradley Dice (https://github.com/bdice)
  - Basit Ayantunde (https://github.com/lamarrr)

Approvers:
  - Basit Ayantunde (https://github.com/lamarrr)
  - Kyle Edwards (https://github.com/KyleFromNVIDIA)

URL: #17234
  • Loading branch information
bdice authored Jan 13, 2025
1 parent bbf4f78 commit 478ec50
Show file tree
Hide file tree
Showing 6 changed files with 391 additions and 378 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ add_library(
src/aggregation/result_cache.cpp
src/ast/expression_parser.cpp
src/ast/expressions.cpp
src/ast/operators.cpp
src/binaryop/binaryop.cpp
src/binaryop/compiled/ATan2.cu
src/binaryop/compiled/Add.cu
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/ast/detail/expression_evaluator.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -452,7 +452,7 @@ struct expression_evaluator {
++operator_index) {
// Execute operator
auto const op = plan.operators[operator_index];
auto const arity = ast_operator_arity(op);
auto const arity = plan.operator_arities[operator_index];
if (arity == 1) {
// Unary operator
auto const& input =
Expand Down
50 changes: 36 additions & 14 deletions cpp/include/cudf/ast/detail/expression_parser.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@

#include <cudf/ast/detail/operators.hpp>
#include <cudf/ast/expressions.hpp>
#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>
Expand Down Expand Up @@ -88,6 +89,7 @@ struct expression_device_view {
device_span<detail::device_data_reference const> data_references;
device_span<generic_scalar_device_view const> literals;
device_span<ast_operator const> operators;
device_span<cudf::size_type const> operator_arities;
device_span<cudf::size_type const> operator_source_indices;
cudf::size_type num_intermediates;
};
Expand Down Expand Up @@ -229,39 +231,55 @@ class expression_parser {
* @param[in] v The `std::vector` containing components (operators, literals, etc).
* @param[in,out] sizes The `std::vector` containing the size of each data buffer.
* @param[in,out] data_pointers The `std::vector` containing pointers to each data buffer.
* @param[in,out] alignment The maximum alignment required across all extracted sizes and pointers; raised to at least `alignof(T)` if it is currently smaller.
*/
template <typename T>
void extract_size_and_pointer(std::vector<T> const& v,
std::vector<cudf::size_type>& sizes,
std::vector<void const*>& data_pointers)
std::vector<void const*>& data_pointers,
cudf::size_type& alignment)
{
// Sub-type alignment only works provided alignof(T) is less than or equal to
// alignof(max_align_t), which is the maximum alignment provided by rmm's device buffers.
static_assert(alignof(T) <= alignof(max_align_t));
// Record this component's size in bytes and the address of its storage. Entries are appended
// in call order, so index i in `sizes` and `data_pointers` refers to the same component.
// NOTE(review): `data_size` is a std::size_t stored into a vector of cudf::size_type, which
// looks like an implicit narrowing conversion -- confirm component sizes always fit.
auto const data_size = sizeof(T) * v.size();
sizes.push_back(data_size);
data_pointers.push_back(v.data());
// Raise the running alignment requirement to at least alignof(T); it is never lowered here.
alignment = std::max(alignment, static_cast<cudf::size_type>(alignof(T)));
}

void move_to_device(rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr)
{
std::vector<cudf::size_type> sizes;
std::vector<void const*> data_pointers;
// use a minimum of 4-byte alignment
cudf::size_type buffer_alignment = 4;

extract_size_and_pointer(_data_references, sizes, data_pointers);
extract_size_and_pointer(_literals, sizes, data_pointers);
extract_size_and_pointer(_operators, sizes, data_pointers);
extract_size_and_pointer(_operator_source_indices, sizes, data_pointers);
extract_size_and_pointer(_data_references, sizes, data_pointers, buffer_alignment);
extract_size_and_pointer(_literals, sizes, data_pointers, buffer_alignment);
extract_size_and_pointer(_operators, sizes, data_pointers, buffer_alignment);
extract_size_and_pointer(_operator_arities, sizes, data_pointers, buffer_alignment);
extract_size_and_pointer(_operator_source_indices, sizes, data_pointers, buffer_alignment);

// Create device buffer
auto const buffer_size = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
auto buffer_offsets = std::vector<int>(sizes.size());
thrust::exclusive_scan(sizes.cbegin(), sizes.cend(), buffer_offsets.begin(), 0);
auto buffer_offsets = std::vector<cudf::size_type>(sizes.size());
thrust::exclusive_scan(sizes.cbegin(),
sizes.cend(),
buffer_offsets.begin(),
cudf::size_type{0},
[buffer_alignment](auto a, auto b) {
// align each component of the AST program
return cudf::util::round_up_safe(a + b, buffer_alignment);
});

auto const buffer_size = buffer_offsets.empty() ? 0 : (buffer_offsets.back() + sizes.back());
auto host_data_buffer = std::vector<char>(buffer_size);

auto h_data_buffer = std::vector<char>(buffer_size);
for (unsigned int i = 0; i < data_pointers.size(); ++i) {
std::memcpy(h_data_buffer.data() + buffer_offsets[i], data_pointers[i], sizes[i]);
std::memcpy(host_data_buffer.data() + buffer_offsets[i], data_pointers[i], sizes[i]);
}

_device_data_buffer = rmm::device_buffer(h_data_buffer.data(), buffer_size, stream, mr);

_device_data_buffer = rmm::device_buffer(host_data_buffer.data(), buffer_size, stream, mr);
stream.synchronize();

// Create device pointers to components of plan
Expand All @@ -277,8 +295,11 @@ class expression_parser {
device_expression_data.operators = device_span<ast_operator const>(
reinterpret_cast<ast_operator const*>(device_data_buffer_ptr + buffer_offsets[2]),
_operators.size());
device_expression_data.operator_source_indices = device_span<cudf::size_type const>(
device_expression_data.operator_arities = device_span<cudf::size_type const>(
reinterpret_cast<cudf::size_type const*>(device_data_buffer_ptr + buffer_offsets[3]),
_operators.size());
device_expression_data.operator_source_indices = device_span<cudf::size_type const>(
reinterpret_cast<cudf::size_type const*>(device_data_buffer_ptr + buffer_offsets[4]),
_operator_source_indices.size());
device_expression_data.num_intermediates = _intermediate_counter.get_max_used();
shmem_per_thread = static_cast<int>(
Expand Down Expand Up @@ -322,6 +343,7 @@ class expression_parser {
bool _has_nulls;
std::vector<detail::device_data_reference> _data_references;
std::vector<ast_operator> _operators;
std::vector<cudf::size_type> _operator_arities;
std::vector<cudf::size_type> _operator_source_indices;
std::vector<generic_scalar_device_view> _literals;
};
Expand Down
Loading

0 comments on commit 478ec50

Please sign in to comment.