Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into use-gcc-13-with-cuda-12-conda-builds
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice authored Jan 13, 2025
2 parents 630fc0f + 4ec389b commit a7d8e21
Show file tree
Hide file tree
Showing 16 changed files with 941 additions and 495 deletions.
4 changes: 2 additions & 2 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -601,7 +601,7 @@ std::unique_ptr<Base> make_udf_aggregation(udf_type type,
data_type output_type);

// Forward declaration of `host_udf_base` for the factory function of `HOST_UDF` aggregation.
struct host_udf_base;
class host_udf_base;

/**
* @brief Factory to create a HOST_UDF aggregation.
Expand Down
478 changes: 301 additions & 177 deletions cpp/include/cudf/aggregation/host_udf.hpp

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -967,7 +967,9 @@ class udf_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying host-based UDF aggregation.
*/
class host_udf_aggregation final : public groupby_aggregation {
class host_udf_aggregation final : public groupby_aggregation,
public reduce_aggregation,
public segmented_reduce_aggregation {
public:
std::unique_ptr<host_udf_base> udf_ptr;

Expand Down
9 changes: 6 additions & 3 deletions cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -145,8 +145,11 @@ struct empty_column_constructor {
}

if constexpr (k == aggregation::Kind::HOST_UDF) {
auto const& udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
return std::get<std::unique_ptr<column>>(udf_ptr->get_empty_output(std::nullopt, stream, mr));
auto const& udf_base_ptr =
dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<groupby_host_udf const*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for groupby aggregation.");
return udf_ptr->get_empty_output(stream, mr);
}

return make_empty_column(target_type(values.type(), k));
Expand Down
81 changes: 32 additions & 49 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -795,58 +795,41 @@ void aggregate_result_functor::operator()<aggregation::HOST_UDF>(aggregation con
{
if (cache.has_result(values, agg)) { return; }

auto const& udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const data_attrs = [&]() -> host_udf_base::data_attribute_set_t {
if (auto tmp = udf_ptr->get_required_data(); !tmp.empty()) { return tmp; }
// Empty attribute set means everything.
return {host_udf_base::groupby_data_attribute::INPUT_VALUES,
host_udf_base::groupby_data_attribute::GROUPED_VALUES,
host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES,
host_udf_base::groupby_data_attribute::NUM_GROUPS,
host_udf_base::groupby_data_attribute::GROUP_OFFSETS,
host_udf_base::groupby_data_attribute::GROUP_LABELS};
}();
auto const& udf_base_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<groupby_host_udf*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for groupby aggregation.");

// Do not cache udf_input, as the actual input data may change from run to run.
host_udf_base::input_map_t udf_input;
for (auto const& attr : data_attrs) {
CUDF_EXPECTS(std::holds_alternative<host_udf_base::groupby_data_attribute>(attr.value) ||
std::holds_alternative<std::unique_ptr<aggregation>>(attr.value),
"Invalid input data attribute for HOST_UDF groupby aggregation.");
if (std::holds_alternative<host_udf_base::groupby_data_attribute>(attr.value)) {
switch (std::get<host_udf_base::groupby_data_attribute>(attr.value)) {
case host_udf_base::groupby_data_attribute::INPUT_VALUES:
udf_input.emplace(attr, values);
break;
case host_udf_base::groupby_data_attribute::GROUPED_VALUES:
udf_input.emplace(attr, get_grouped_values());
break;
case host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES:
udf_input.emplace(attr, get_sorted_values());
break;
case host_udf_base::groupby_data_attribute::NUM_GROUPS:
udf_input.emplace(attr, helper.num_groups(stream));
break;
case host_udf_base::groupby_data_attribute::GROUP_OFFSETS:
udf_input.emplace(attr, helper.group_offsets(stream));
break;
case host_udf_base::groupby_data_attribute::GROUP_LABELS:
udf_input.emplace(attr, helper.group_labels(stream));
break;
default: CUDF_UNREACHABLE("Invalid input data attribute for HOST_UDF groupby aggregation.");
}
} else { // data is result from another aggregation
auto other_agg = std::get<std::unique_ptr<aggregation>>(attr.value)->clone();
if (!udf_ptr->callback_input_values) {
udf_ptr->callback_input_values = [&]() -> column_view { return values; };
}
if (!udf_ptr->callback_grouped_values) {
udf_ptr->callback_grouped_values = [&]() -> column_view { return get_grouped_values(); };
}
if (!udf_ptr->callback_sorted_grouped_values) {
udf_ptr->callback_sorted_grouped_values = [&]() -> column_view { return get_sorted_values(); };
}
if (!udf_ptr->callback_num_groups) {
udf_ptr->callback_num_groups = [&]() -> size_type { return helper.num_groups(stream); };
}
if (!udf_ptr->callback_group_offsets) {
udf_ptr->callback_group_offsets = [&]() -> device_span<size_type const> {
return helper.group_offsets(stream);
};
}
if (!udf_ptr->callback_group_labels) {
udf_ptr->callback_group_labels = [&]() -> device_span<size_type const> {
return helper.group_labels(stream);
};
}
if (!udf_ptr->callback_compute_aggregation) {
udf_ptr->callback_compute_aggregation =
[&](std::unique_ptr<aggregation> other_agg) -> column_view {
cudf::detail::aggregation_dispatcher(other_agg->kind, *this, *other_agg);
auto result = cache.get_result(values, *other_agg);
udf_input.emplace(std::move(other_agg), std::move(result));
}
return cache.get_result(values, *other_agg);
};
}

auto output = (*udf_ptr)(udf_input, stream, mr);
CUDF_EXPECTS(std::holds_alternative<std::unique_ptr<column>>(output),
"Invalid output type from HOST_UDF groupby aggregation.");
cache.add_result(values, agg, std::get<std::unique_ptr<column>>(std::move(output)));
cache.add_result(values, agg, (*udf_ptr)(stream, mr));
}

} // namespace detail
Expand Down
48 changes: 5 additions & 43 deletions cpp/src/groupby/sort/host_udf_aggregation.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,51 +16,9 @@

#include <cudf/aggregation/host_udf.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/utilities/visitor_overload.hpp>

namespace cudf {

/**
 * @brief Copy constructor: builds `value` by visiting the alternative held in `other.value`.
 *
 * A held `std::unique_ptr<aggregation>` cannot be copied, so it is duplicated through
 * `clone()`. Any other alternative is copied directly into a new `value_type`.
 *
 * @param other The attribute to copy from
 */
host_udf_base::data_attribute::data_attribute(data_attribute const& other)
: value{std::visit(cudf::detail::visitor_overload{[](auto const& val) { return value_type{val}; },
[](std::unique_ptr<aggregation> const& val) {
return value_type{val->clone()};
}},
other.value)}
{
}

/**
 * @brief Hash functor for `data_attribute`.
 *
 * The hash of the held alternative is XOR-ed with the hash of the variant index.
 * A held `std::unique_ptr<aggregation>` is hashed with the aggregation's `do_hash()`;
 * every other alternative is cast to `int` and hashed with `std::hash<int>`.
 *
 * @param attr The attribute to hash
 * @return The combined hash value
 */
std::size_t host_udf_base::data_attribute::hash::operator()(data_attribute const& attr) const
{
auto const hash_value =
std::visit(cudf::detail::visitor_overload{
[](auto const& val) { return std::hash<int>{}(static_cast<int>(val)); },
[](std::unique_ptr<aggregation> const& val) { return val->do_hash(); }},
attr.value);
// Mixing in the alternative index presumably keeps equal integer values held by different
// alternatives from hashing identically -- NOTE(review): confirm intent.
return std::hash<std::size_t>{}(attr.value.index()) ^ hash_value;
}

/**
 * @brief Equality functor for `data_attribute`.
 *
 * Two attributes are equal only when they hold the same variant alternative and the held
 * values compare equal: aggregations are compared with `is_equal()`, all other
 * alternatives with `operator==`.
 *
 * @param lhs The left-hand attribute
 * @param rhs The right-hand attribute
 * @return `true` if both attributes hold equal values of the same alternative
 */
bool host_udf_base::data_attribute::equal_to::operator()(data_attribute const& lhs,
data_attribute const& rhs) const
{
auto const& lhs_val = lhs.value;
auto const& rhs_val = rhs.value;
// Different alternatives can never be equal.
if (lhs_val.index() != rhs_val.index()) { return false; }
return std::visit(
cudf::detail::visitor_overload{
[](auto const& lhs_val, auto const& rhs_val) {
// A two-variant std::visit instantiates this lambda for every pair of alternatives,
// so mismatched pairs must still compile even though the index check above already
// rejected them.
if constexpr (std::is_same_v<decltype(lhs_val), decltype(rhs_val)>) {
return lhs_val == rhs_val;
} else {
return false;
}
},
[](std::unique_ptr<aggregation> const& lhs_val, std::unique_ptr<aggregation> const& rhs_val) {
return lhs_val->is_equal(*rhs_val);
}},
lhs_val,
rhs_val);
}

namespace detail {

host_udf_aggregation::host_udf_aggregation(std::unique_ptr<host_udf_base> udf_ptr_)
Expand Down Expand Up @@ -99,5 +57,9 @@ template CUDF_EXPORT std::unique_ptr<aggregation> make_host_udf_aggregation<aggr
std::unique_ptr<host_udf_base>);
template CUDF_EXPORT std::unique_ptr<groupby_aggregation>
make_host_udf_aggregation<groupby_aggregation>(std::unique_ptr<host_udf_base>);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_host_udf_aggregation<reduce_aggregation>(std::unique_ptr<host_udf_base>);
template CUDF_EXPORT std::unique_ptr<segmented_reduce_aggregation>
make_host_udf_aggregation<segmented_reduce_aggregation>(std::unique_ptr<host_udf_base>);

} // namespace cudf
16 changes: 13 additions & 3 deletions cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <cudf/aggregation/host_udf.hpp>
#include <cudf/column/column.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/copy.hpp>
Expand Down Expand Up @@ -144,6 +145,13 @@ struct reduce_dispatch_functor {
auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
}
case aggregation::HOST_UDF: {
auto const& udf_base_ptr =
dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<reduce_host_udf const*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for reduction.");
return (*udf_ptr)(col, output_dtype, init, stream, mr);
} // case aggregation::HOST_UDF
default: CUDF_FAIL("Unsupported reduction operator");
}
}
Expand All @@ -161,9 +169,11 @@ std::unique_ptr<scalar> reduce(column_view const& col,
cudf::data_type_error);
if (init.has_value() && !(agg.kind == aggregation::SUM || agg.kind == aggregation::PRODUCT ||
agg.kind == aggregation::MIN || agg.kind == aggregation::MAX ||
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL)) {
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL ||
agg.kind == aggregation::HOST_UDF)) {
CUDF_FAIL(
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, and ALL aggregation types");
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, ALL, and HOST_UDF "
"aggregation types");
}

// Returns default scalar if input column is empty or all null
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/reductions/segmented/reductions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/aggregation/host_udf.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
Expand Down Expand Up @@ -98,6 +100,13 @@ struct segmented_reduce_dispatch_functor {
}
case segmented_reduce_aggregation::NUNIQUE:
return segmented_nunique(col, offsets, null_handling, stream, mr);
case aggregation::HOST_UDF: {
auto const& udf_base_ptr =
dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<segmented_reduce_host_udf const*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for segmented reduction.");
return (*udf_ptr)(col, offsets, output_dtype, null_handling, init, stream, mr);
} // case aggregation::HOST_UDF
default: CUDF_FAIL("Unsupported aggregation type.");
}
}
Expand All @@ -117,9 +126,11 @@ std::unique_ptr<column> segmented_reduce(column_view const& segmented_values,
cudf::data_type_error);
if (init.has_value() && !(agg.kind == aggregation::SUM || agg.kind == aggregation::PRODUCT ||
agg.kind == aggregation::MIN || agg.kind == aggregation::MAX ||
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL)) {
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL ||
agg.kind == aggregation::HOST_UDF)) {
CUDF_FAIL(
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, and ALL aggregation types");
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, ALL, and HOST_UDF "
"aggregation types");
}

if (segmented_values.is_empty() && offsets.empty()) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,12 @@ ConfigureTest(
REDUCTIONS_TEST
reductions/collect_ops_tests.cpp
reductions/ewm_tests.cpp
reductions/host_udf_example_tests.cu
reductions/list_rank_test.cpp
reductions/rank_tests.cpp
reductions/reduction_tests.cpp
reductions/scan_tests.cpp
reductions/segmented_reduction_tests.cpp
reductions/list_rank_test.cpp
reductions/tdigest_tests.cu
GPUS 1
PERCENT 70
Expand Down
Loading

0 comments on commit a7d8e21

Please sign in to comment.