From 478ec50edf302a338db043039abad6a2560144ea Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Mon, 13 Jan 2025 15:19:44 -0600 Subject: [PATCH] Precompute AST arity (#17234) This PR precomputes AST arity on the host, to reduce the complexity in device-side arity lookup. Authors: - Bradley Dice (https://github.com/bdice) - Basit Ayantunde (https://github.com/lamarrr) Approvers: - Basit Ayantunde (https://github.com/lamarrr) - Kyle Edwards (https://github.com/KyleFromNVIDIA) URL: https://github.com/rapidsai/cudf/pull/17234 --- cpp/CMakeLists.txt | 1 + .../cudf/ast/detail/expression_evaluator.cuh | 4 +- .../cudf/ast/detail/expression_parser.hpp | 50 ++- cpp/include/cudf/ast/detail/operators.hpp | 418 +++--------------- cpp/src/ast/expression_parser.cpp | 3 +- cpp/src/ast/operators.cpp | 293 ++++++++++++ 6 files changed, 391 insertions(+), 378 deletions(-) create mode 100644 cpp/src/ast/operators.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 252cc7897d8..4d83cbd907c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -336,6 +336,7 @@ add_library( src/aggregation/result_cache.cpp src/ast/expression_parser.cpp src/ast/expressions.cpp + src/ast/operators.cpp src/binaryop/binaryop.cpp src/binaryop/compiled/ATan2.cu src/binaryop/compiled/Add.cu diff --git a/cpp/include/cudf/ast/detail/expression_evaluator.cuh b/cpp/include/cudf/ast/detail/expression_evaluator.cuh index 9d8762555d7..001b604814c 100644 --- a/cpp/include/cudf/ast/detail/expression_evaluator.cuh +++ b/cpp/include/cudf/ast/detail/expression_evaluator.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2021-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -452,7 +452,7 @@ struct expression_evaluator { ++operator_index) { // Execute operator auto const op = plan.operators[operator_index]; - auto const arity = ast_operator_arity(op); + auto const arity = plan.operator_arities[operator_index]; if (arity == 1) { // Unary operator auto const& input = diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index b5973d0ace9..d2e8c1cd41f 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -88,6 +89,7 @@ struct expression_device_view { device_span data_references; device_span literals; device_span operators; + device_span operator_arities; device_span operator_source_indices; cudf::size_type num_intermediates; }; @@ -229,39 +231,55 @@ class expression_parser { * @param[in] v The `std::vector` containing components (operators, literals, etc). * @param[in,out] sizes The `std::vector` containing the size of each data buffer. * @param[in,out] data_pointers The `std::vector` containing pointers to each data buffer. 
+ * @param[in,out] alignment The maximum alignment needed for all the extracted sizes and pointers */ template void extract_size_and_pointer(std::vector const& v, std::vector& sizes, - std::vector& data_pointers) + std::vector& data_pointers, + cudf::size_type& alignment) { + // sub-type alignment will only work provided the alignment is less than or equal to + // alignof(max_align_t) which is the maximum alignment provided by rmm's device buffers + static_assert(alignof(T) <= alignof(max_align_t)); auto const data_size = sizeof(T) * v.size(); sizes.push_back(data_size); data_pointers.push_back(v.data()); + alignment = std::max(alignment, static_cast(alignof(T))); } void move_to_device(rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { std::vector sizes; std::vector data_pointers; + // use a minimum of 4-byte alignment + cudf::size_type buffer_alignment = 4; - extract_size_and_pointer(_data_references, sizes, data_pointers); - extract_size_and_pointer(_literals, sizes, data_pointers); - extract_size_and_pointer(_operators, sizes, data_pointers); - extract_size_and_pointer(_operator_source_indices, sizes, data_pointers); + extract_size_and_pointer(_data_references, sizes, data_pointers, buffer_alignment); + extract_size_and_pointer(_literals, sizes, data_pointers, buffer_alignment); + extract_size_and_pointer(_operators, sizes, data_pointers, buffer_alignment); + extract_size_and_pointer(_operator_arities, sizes, data_pointers, buffer_alignment); + extract_size_and_pointer(_operator_source_indices, sizes, data_pointers, buffer_alignment); // Create device buffer - auto const buffer_size = std::accumulate(sizes.cbegin(), sizes.cend(), 0); - auto buffer_offsets = std::vector(sizes.size()); - thrust::exclusive_scan(sizes.cbegin(), sizes.cend(), buffer_offsets.begin(), 0); + auto buffer_offsets = std::vector(sizes.size()); + thrust::exclusive_scan(sizes.cbegin(), + sizes.cend(), + buffer_offsets.begin(), + cudf::size_type{0}, + [buffer_alignment](auto a, 
auto b) { + // align each component of the AST program + return cudf::util::round_up_safe(a + b, buffer_alignment); + }); + + auto const buffer_size = buffer_offsets.empty() ? 0 : (buffer_offsets.back() + sizes.back()); + auto host_data_buffer = std::vector(buffer_size); - auto h_data_buffer = std::vector(buffer_size); for (unsigned int i = 0; i < data_pointers.size(); ++i) { - std::memcpy(h_data_buffer.data() + buffer_offsets[i], data_pointers[i], sizes[i]); + std::memcpy(host_data_buffer.data() + buffer_offsets[i], data_pointers[i], sizes[i]); } - _device_data_buffer = rmm::device_buffer(h_data_buffer.data(), buffer_size, stream, mr); - + _device_data_buffer = rmm::device_buffer(host_data_buffer.data(), buffer_size, stream, mr); stream.synchronize(); // Create device pointers to components of plan @@ -277,8 +295,11 @@ class expression_parser { device_expression_data.operators = device_span( reinterpret_cast(device_data_buffer_ptr + buffer_offsets[2]), _operators.size()); - device_expression_data.operator_source_indices = device_span( + device_expression_data.operator_arities = device_span( reinterpret_cast(device_data_buffer_ptr + buffer_offsets[3]), + _operators.size()); + device_expression_data.operator_source_indices = device_span( + reinterpret_cast(device_data_buffer_ptr + buffer_offsets[4]), _operator_source_indices.size()); device_expression_data.num_intermediates = _intermediate_counter.get_max_used(); shmem_per_thread = static_cast( @@ -322,6 +343,7 @@ class expression_parser { bool _has_nulls; std::vector _data_references; std::vector _operators; + std::vector _operator_arities; std::vector _operator_source_indices; std::vector _literals; }; diff --git a/cpp/include/cudf/ast/detail/operators.hpp b/cpp/include/cudf/ast/detail/operators.hpp index 46507700e21..db04e1fe989 100644 --- a/cpp/include/cudf/ast/detail/operators.hpp +++ b/cpp/include/cudf/ast/detail/operators.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. 
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,159 +69,111 @@ constexpr bool is_valid_unary_op = cuda::std::is_invocable_v; * @param args Forwarded arguments to `operator()` of `f`. */ template -CUDF_HOST_DEVICE inline constexpr void ast_operator_dispatcher(ast_operator op, F&& f, Ts&&... args) +CUDF_HOST_DEVICE inline constexpr decltype(auto) ast_operator_dispatcher(ast_operator op, + F&& f, + Ts&&... args) { switch (op) { case ast_operator::ADD: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::SUB: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::MUL: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::DIV: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::TRUE_DIV: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::FLOOR_DIV: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::MOD: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::PYMOD: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::POW: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::EQUAL: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::NULL_EQUAL: - f.template 
operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::NOT_EQUAL: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::LESS: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::GREATER: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::LESS_EQUAL: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::GREATER_EQUAL: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::BITWISE_AND: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::BITWISE_OR: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::BITWISE_XOR: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::LOGICAL_AND: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::NULL_LOGICAL_AND: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::LOGICAL_OR: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::NULL_LOGICAL_OR: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::IDENTITY: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::IS_NULL: - 
f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::SIN: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::COS: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::TAN: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::ARCSIN: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::ARCCOS: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::ARCTAN: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::SINH: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::COSH: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::TANH: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::ARCSINH: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::ARCCOSH: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::ARCTANH: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::EXP: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::LOG: - f.template operator()(std::forward(args)...); - break; + return 
f.template operator()(std::forward(args)...); case ast_operator::SQRT: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::CBRT: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::CEIL: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::FLOOR: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::ABS: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::RINT: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::BIT_INVERT: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::NOT: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::CAST_TO_INT64: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::CAST_TO_UINT64: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); case ast_operator::CAST_TO_FLOAT64: - f.template operator()(std::forward(args)...); - break; + return f.template operator()(std::forward(args)...); default: { #ifndef __CUDA_ARCH__ CUDF_FAIL("Invalid operator."); @@ -955,231 +907,6 @@ struct operator_functor { } }; -/** - * @brief Functor used to single-type-dispatch binary operators. - * - * This functor's `operator()` is templated to validate calls to its operators based on the input - * type, as determined by the `is_valid_binary_op` trait. 
This function assumes that both inputs are - * the same type, and dispatches based on the type of the left input. - * - * @tparam OperatorFunctor Binary operator functor. - */ -template -struct single_dispatch_binary_operator_types { - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(F&& f, Ts&&... args) - { - f.template operator()(std::forward(args)...); - } - - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(F&& f, Ts&&... args) - { -#ifndef __CUDA_ARCH__ - CUDF_FAIL("Invalid binary operation."); -#else - CUDF_UNREACHABLE("Invalid binary operation."); -#endif - } -}; - -/** - * @brief Functor performing a type dispatch for a binary operator. - * - * This functor performs single dispatch, which assumes lhs_type == rhs_type. This may not be true - * for all binary operators but holds for all currently implemented operators. - */ -struct type_dispatch_binary_op { - /** - * @brief Performs type dispatch for a binary operator. - * - * @tparam op AST operator. - * @tparam F Type of forwarded functor. - * @tparam Ts Parameter pack of forwarded arguments. - * @param lhs_type Type of left input data. - * @param rhs_type Type of right input data. - * @param f Forwarded functor to be called. - * @param args Forwarded arguments to `operator()` of `f`. - */ - template - CUDF_HOST_DEVICE inline void operator()(cudf::data_type lhs_type, - cudf::data_type rhs_type, - F&& f, - Ts&&... args) - { - // Single dispatch (assume lhs_type == rhs_type) - type_dispatcher( - lhs_type, - // Always dispatch to the non-null operator for the purpose of type determination. - detail::single_dispatch_binary_operator_types>{}, - std::forward(f), - std::forward(args)...); - } -}; - -/** - * @brief Dispatches a runtime binary operator to a templated type dispatcher. - * - * @tparam F Type of forwarded functor. - * @tparam Ts Parameter pack of forwarded arguments. - * @param lhs_type Type of left input data. - * @param rhs_type Type of right input data. 
- * @param f Forwarded functor to be called. - * @param args Forwarded arguments to `operator()` of `f`. - */ -template -CUDF_HOST_DEVICE inline constexpr void binary_operator_dispatcher( - ast_operator op, cudf::data_type lhs_type, cudf::data_type rhs_type, F&& f, Ts&&... args) -{ - ast_operator_dispatcher(op, - detail::type_dispatch_binary_op{}, - lhs_type, - rhs_type, - std::forward(f), - std::forward(args)...); -} - -/** - * @brief Functor used to type-dispatch unary operators. - * - * This functor's `operator()` is templated to validate calls to its operators based on the input - * type, as determined by the `is_valid_unary_op` trait. - * - * @tparam OperatorFunctor Unary operator functor. - */ -template -struct dispatch_unary_operator_types { - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(F&& f, Ts&&... args) - { - f.template operator()(std::forward(args)...); - } - - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(F&& f, Ts&&... args) - { -#ifndef __CUDA_ARCH__ - CUDF_FAIL("Invalid unary operation."); -#else - CUDF_UNREACHABLE("Invalid unary operation."); -#endif - } -}; - -/** - * @brief Functor performing a type dispatch for a unary operator. - */ -struct type_dispatch_unary_op { - template - CUDF_HOST_DEVICE inline void operator()(cudf::data_type input_type, F&& f, Ts&&... args) - { - type_dispatcher( - input_type, - // Always dispatch to the non-null operator for the purpose of type determination. - detail::dispatch_unary_operator_types>{}, - std::forward(f), - std::forward(args)...); - } -}; - -/** - * @brief Dispatches a runtime unary operator to a templated type dispatcher. - * - * @tparam F Type of forwarded functor. - * @tparam Ts Parameter pack of forwarded arguments. - * @param input_type Type of input data. - * @param f Forwarded functor to be called. - * @param args Forwarded arguments to `operator()` of `f`. 
- */ -template -CUDF_HOST_DEVICE inline constexpr void unary_operator_dispatcher(ast_operator op, - cudf::data_type input_type, - F&& f, - Ts&&... args) -{ - ast_operator_dispatcher(op, - detail::type_dispatch_unary_op{}, - input_type, - std::forward(f), - std::forward(args)...); -} - -/** - * @brief Functor to determine the return type of an operator from its input types. - */ -struct return_type_functor { - /** - * @brief Callable for binary operators to determine return type. - * - * @tparam OperatorFunctor Operator functor to perform. - * @tparam LHS Left input type. - * @tparam RHS Right input type. - * @param result Reference whose value is assigned to the result data type. - */ - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(cudf::data_type& result) - { - using Out = cuda::std::invoke_result_t; - result = cudf::data_type(cudf::type_to_id()); - } - - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(cudf::data_type& result) - { -#ifndef __CUDA_ARCH__ - CUDF_FAIL("Invalid binary operation. Return type cannot be determined."); -#else - CUDF_UNREACHABLE("Invalid binary operation. Return type cannot be determined."); -#endif - } - - /** - * @brief Callable for unary operators to determine return type. - * - * @tparam OperatorFunctor Operator functor to perform. - * @tparam T Input type. - * @param result Pointer whose value is assigned to the result data type. - */ - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(cudf::data_type& result) - { - using Out = cuda::std::invoke_result_t; - result = cudf::data_type(cudf::type_to_id()); - } - - template >* = nullptr> - CUDF_HOST_DEVICE inline void operator()(cudf::data_type& result) - { -#ifndef __CUDA_ARCH__ - CUDF_FAIL("Invalid unary operation. Return type cannot be determined."); -#else - CUDF_UNREACHABLE("Invalid unary operation. Return type cannot be determined."); -#endif - } -}; - /** * @brief Gets the return type of an AST operator. 
* @@ -1187,34 +914,8 @@ struct return_type_functor { * @param operand_types Vector of input types to the operator. * @return cudf::data_type Return type of the operator. */ -inline cudf::data_type ast_operator_return_type(ast_operator op, - std::vector const& operand_types) -{ - auto result = cudf::data_type(cudf::type_id::EMPTY); - switch (operand_types.size()) { - case 1: - unary_operator_dispatcher(op, operand_types[0], detail::return_type_functor{}, result); - break; - case 2: - binary_operator_dispatcher( - op, operand_types[0], operand_types[1], detail::return_type_functor{}, result); - break; - default: CUDF_FAIL("Unsupported operator return type."); break; - } - return result; -} - -/** - * @brief Functor to determine the arity (number of operands) of an operator. - */ -struct arity_functor { - template - CUDF_HOST_DEVICE inline void operator()(cudf::size_type& result) - { - // Arity is not dependent on null handling, so just use the false implementation here. - result = operator_functor::arity; - } -}; +cudf::data_type ast_operator_return_type(ast_operator op, + std::vector const& operand_types); /** * @brief Gets the arity (number of operands) of an AST operator. @@ -1222,12 +923,7 @@ struct arity_functor { * @param op Operator used to determine arity. * @return Arity of the operator. */ -CUDF_HOST_DEVICE inline cudf::size_type ast_operator_arity(ast_operator op) -{ - auto result = cudf::size_type(0); - ast_operator_dispatcher(op, detail::arity_functor{}, result); - return result; -} +cudf::size_type ast_operator_arity(ast_operator op); } // namespace detail diff --git a/cpp/src/ast/expression_parser.cpp b/cpp/src/ast/expression_parser.cpp index d0e4c59ca54..b2cc134d9fa 100644 --- a/cpp/src/ast/expression_parser.cpp +++ b/cpp/src/ast/expression_parser.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -161,6 +161,7 @@ cudf::size_type expression_parser::visit(operation const& expr) auto const op = expr.get_operator(); auto const data_type = cudf::ast::detail::ast_operator_return_type(op, operand_types); _operators.push_back(op); + _operator_arities.push_back(cudf::ast::detail::ast_operator_arity(op)); // Push data reference auto const output = [&]() { if (expression_index == 0) { diff --git a/cpp/src/ast/operators.cpp b/cpp/src/ast/operators.cpp new file mode 100644 index 00000000000..b60a69a42d9 --- /dev/null +++ b/cpp/src/ast/operators.cpp @@ -0,0 +1,293 @@ +/* + * Copyright (c) 2021-2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include +#include + +#include + +namespace cudf { +namespace ast { +namespace detail { +namespace { + +struct arity_functor { + template + void operator()(cudf::size_type& result) + { + // Arity is not dependent on null handling, so just use the false implementation here. + result = operator_functor::arity; + } +}; + +/** + * @brief Functor to determine the return type of an operator from its input types. + */ +struct return_type_functor { + /** + * @brief Callable for binary operators to determine return type. + * + * @tparam OperatorFunctor Operator functor to perform. + * @tparam LHS Left input type. 
+ * @tparam RHS Right input type. + * @param result Reference whose value is assigned to the result data type. + */ + template >* = nullptr> + void operator()(cudf::data_type& result) + { + using Out = cuda::std::invoke_result_t; + result = cudf::data_type{cudf::type_to_id()}; + } + + template >* = nullptr> + void operator()(cudf::data_type& result) + { +#ifndef __CUDA_ARCH__ + CUDF_FAIL("Invalid binary operation. Return type cannot be determined."); +#else + CUDF_UNREACHABLE("Invalid binary operation. Return type cannot be determined."); +#endif + result = cudf::data_type{cudf::type_id::EMPTY}; + } + + /** + * @brief Callable for unary operators to determine return type. + * + * @tparam OperatorFunctor Operator functor to perform. + * @tparam T Input type. + * @param result Reference whose value is assigned to the result data type. + */ + template >* = nullptr> + void operator()(cudf::data_type& result) + { + using Out = cuda::std::invoke_result_t; + result = cudf::data_type{cudf::type_to_id()}; + } + + template >* = nullptr> + void operator()(cudf::data_type& result) + { +#ifndef __CUDA_ARCH__ + CUDF_FAIL("Invalid unary operation. Return type cannot be determined."); +#else + CUDF_UNREACHABLE("Invalid unary operation. Return type cannot be determined."); +#endif + result = cudf::data_type{cudf::type_id::EMPTY}; + } +}; + +/** + * @brief Functor used to single-type-dispatch binary operators. + * + * This functor's `operator()` is templated to validate calls to its operators based on the input + * type, as determined by the `is_valid_binary_op` trait. This function assumes that both inputs are + * the same type, and dispatches based on the type of the left input. + * + * @tparam OperatorFunctor Binary operator functor. + */ +template +struct single_dispatch_binary_operator_types { + template >* = nullptr> + inline void operator()(F&& f, Ts&&... 
args) + { + f.template operator()(std::forward(args)...); + } + + template >* = nullptr> + inline void operator()(F&& f, Ts&&... args) + { +#ifndef __CUDA_ARCH__ + CUDF_FAIL("Invalid binary operation."); +#else + CUDF_UNREACHABLE("Invalid binary operation."); +#endif + } +}; + +/** + * @brief Functor performing a type dispatch for a binary operator. + * + * This functor performs single dispatch, which assumes lhs_type == rhs_type. This may not be true + * for all binary operators but holds for all currently implemented operators. + */ +struct type_dispatch_binary_op { + /** + * @brief Performs type dispatch for a binary operator. + * + * @tparam op AST operator. + * @tparam F Type of forwarded functor. + * @tparam Ts Parameter pack of forwarded arguments. + * @param lhs_type Type of left input data. + * @param rhs_type Type of right input data. + * @param f Forwarded functor to be called. + * @param args Forwarded arguments to `operator()` of `f`. + */ + template + inline void operator()(cudf::data_type lhs_type, cudf::data_type rhs_type, F&& f, Ts&&... args) + { + // Single dispatch (assume lhs_type == rhs_type) + type_dispatcher( + lhs_type, + // Always dispatch to the non-null operator for the purpose of type determination. + detail::single_dispatch_binary_operator_types>{}, + std::forward(f), + std::forward(args)...); + } +}; + +/** + * @brief Dispatches a runtime binary operator to a templated type dispatcher. + * + * @tparam F Type of forwarded functor. + * @tparam Ts Parameter pack of forwarded arguments. + * @param lhs_type Type of left input data. + * @param rhs_type Type of right input data. + * @param f Forwarded functor to be called. + * @param args Forwarded arguments to `operator()` of `f`. + */ +template +inline constexpr void binary_operator_dispatcher( + ast_operator op, cudf::data_type lhs_type, cudf::data_type rhs_type, F&& f, Ts&&... 
args) +{ + ast_operator_dispatcher(op, + detail::type_dispatch_binary_op{}, + lhs_type, + rhs_type, + std::forward(f), + std::forward(args)...); +} + +/** + * @brief Functor used to type-dispatch unary operators. + * + * This functor's `operator()` is templated to validate calls to its operators based on the input + * type, as determined by the `is_valid_unary_op` trait. + * + * @tparam OperatorFunctor Unary operator functor. + */ +template +struct dispatch_unary_operator_types { + template >* = nullptr> + inline void operator()(F&& f, Ts&&... args) + { + f.template operator()(std::forward(args)...); + } + + template >* = nullptr> + inline void operator()(F&& f, Ts&&... args) + { +#ifndef __CUDA_ARCH__ + CUDF_FAIL("Invalid unary operation."); +#else + CUDF_UNREACHABLE("Invalid unary operation."); +#endif + } +}; + +/** + * @brief Functor performing a type dispatch for a unary operator. + */ +struct type_dispatch_unary_op { + template + inline void operator()(cudf::data_type input_type, F&& f, Ts&&... args) + { + type_dispatcher( + input_type, + // Always dispatch to the non-null operator for the purpose of type determination. + detail::dispatch_unary_operator_types>{}, + std::forward(f), + std::forward(args)...); + } +}; + +/** + * @brief Dispatches a runtime unary operator to a templated type dispatcher. + * + * @tparam F Type of forwarded functor. + * @tparam Ts Parameter pack of forwarded arguments. + * @param input_type Type of input data. + * @param f Forwarded functor to be called. + * @param args Forwarded arguments to `operator()` of `f`. + */ +template +inline constexpr void unary_operator_dispatcher(ast_operator op, + cudf::data_type input_type, + F&& f, + Ts&&... 
args) +{ + ast_operator_dispatcher(op, + detail::type_dispatch_unary_op{}, + input_type, + std::forward(f), + std::forward(args)...); +} + +} // namespace + +cudf::data_type ast_operator_return_type(ast_operator op, + std::vector const& operand_types) +{ + cudf::data_type result{cudf::type_id::EMPTY}; + switch (operand_types.size()) { + case 1: + unary_operator_dispatcher(op, operand_types[0], detail::return_type_functor{}, result); + break; + case 2: + binary_operator_dispatcher( + op, operand_types[0], operand_types[1], detail::return_type_functor{}, result); + break; + default: CUDF_FAIL("Unsupported operator return type."); break; + } + return result; +} + +cudf::size_type ast_operator_arity(ast_operator op) +{ + cudf::size_type result{}; + ast_operator_dispatcher(op, arity_functor{}, result); + return result; +} + +} // namespace detail + +} // namespace ast + +} // namespace cudf