Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement HOST_UDF aggregation for reduction and segmented reduction #17645

Merged
merged 43 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
2da9137
Implement `HOST_UDF` aggregation for reduction and segmented reduction
ttnghia Dec 20, 2024
fca442a
Merge branch 'branch-25.02' into host_udf_reduction
ttnghia Dec 20, 2024
4b5aa95
Fix example
ttnghia Dec 20, 2024
5982f3e
Update docs
ttnghia Dec 20, 2024
edaa007
Separate one base class into 4 base classes
ttnghia Dec 24, 2024
07579d0
Fix compile errors
ttnghia Dec 24, 2024
2a497e1
Refactor the base classes
ttnghia Dec 27, 2024
86ad3bb
Simplify the interface for reduction and segmented reduction
ttnghia Dec 28, 2024
f27c9fd
Rewrite tests
ttnghia Dec 28, 2024
2deeb3b
Merge branch 'branch-25.02' into host_udf_reduction
ttnghia Dec 28, 2024
41cf444
Rewrite comments and reformat
ttnghia Dec 28, 2024
e4cabfd
Reformat
ttnghia Dec 28, 2024
ca83ecf
Change return type to `const&`
ttnghia Dec 28, 2024
54a1e2d
Fix docs
ttnghia Dec 29, 2024
b3e5ce6
Fix inheritant property
ttnghia Dec 29, 2024
9b3e4fb
Fix forward declaration
ttnghia Dec 29, 2024
51d4aa6
Update copyright years
ttnghia Jan 2, 2025
dafe5ab
Merge branch 'branch-25.02' into host_udf_reduction
ttnghia Jan 4, 2025
f59d6b1
Refactor groupby base class, further simplifying it
ttnghia Jan 6, 2025
f315813
Merge branch 'branch-25.02' into host_udf_reduction
ttnghia Jan 6, 2025
8111e64
Fix style
ttnghia Jan 6, 2025
d7686aa
Fix typo
ttnghia Jan 7, 2025
98c5e77
Update cpp/include/cudf/aggregation/host_udf.hpp
ttnghia Jan 7, 2025
28329a7
Fix typo
ttnghia Jan 7, 2025
c3b52da
Fix typo
ttnghia Jan 7, 2025
89e8380
Remove unused header
ttnghia Jan 7, 2025
4f6d340
Update cpp/include/cudf/aggregation/host_udf.hpp
ttnghia Jan 8, 2025
9d9429a
Update cpp/include/cudf/aggregation/host_udf.hpp
ttnghia Jan 8, 2025
76b5687
Change parameter order
ttnghia Jan 8, 2025
38290ac
Extract reduction and segmented reduction tests
ttnghia Jan 10, 2025
1cb656b
Rename base classes
ttnghia Jan 10, 2025
78dc6a6
Rewrite callbacks
ttnghia Jan 10, 2025
9810e6e
Update copyright header
ttnghia Jan 10, 2025
cecb7f7
Misc
ttnghia Jan 10, 2025
1f24ab0
Rename variables
ttnghia Jan 10, 2025
a6b9699
Merge branch 'branch-25.02' into host_udf_reduction
ttnghia Jan 10, 2025
8b88982
Update cpp/include/cudf/aggregation/host_udf.hpp
ttnghia Jan 10, 2025
ad96501
Try to fix docs
ttnghia Jan 10, 2025
b274eec
Merge branch 'branch-25.02' into host_udf_reduction
ttnghia Jan 10, 2025
4b7c874
Try to fix docs build
ttnghia Jan 10, 2025
76e3a87
Test docs
ttnghia Jan 10, 2025
49b676a
Try fixing docs
ttnghia Jan 10, 2025
f68e7fd
Minimal fix for doc builds plus some minor typo fixes
vyasr Jan 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -601,7 +601,7 @@ std::unique_ptr<Base> make_udf_aggregation(udf_type type,
data_type output_type);

// Forward declaration of `host_udf_base` for the factory function of `HOST_UDF` aggregation.
struct host_udf_base;
class host_udf_base;

/**
* @brief Factory to create a HOST_UDF aggregation.
Expand Down
483 changes: 306 additions & 177 deletions cpp/include/cudf/aggregation/host_udf.hpp

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -967,7 +967,9 @@ class udf_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying host-based UDF aggregation.
*/
class host_udf_aggregation final : public groupby_aggregation {
class host_udf_aggregation final : public groupby_aggregation,
public reduce_aggregation,
public segmented_reduce_aggregation {
public:
std::unique_ptr<host_udf_base> udf_ptr;

Expand Down
9 changes: 6 additions & 3 deletions cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -145,8 +145,11 @@ struct empty_column_constructor {
}

if constexpr (k == aggregation::Kind::HOST_UDF) {
auto const& udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
return std::get<std::unique_ptr<column>>(udf_ptr->get_empty_output(std::nullopt, stream, mr));
auto const& udf_base_ptr =
dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<groupby_host_udf const*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for groupby aggregation.");
return udf_ptr->get_empty_output(stream, mr);
}

return make_empty_column(target_type(values.type(), k));
Expand Down
81 changes: 32 additions & 49 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -795,58 +795,41 @@ void aggregate_result_functor::operator()<aggregation::HOST_UDF>(aggregation con
{
if (cache.has_result(values, agg)) { return; }

auto const& udf_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const data_attrs = [&]() -> host_udf_base::data_attribute_set_t {
if (auto tmp = udf_ptr->get_required_data(); !tmp.empty()) { return tmp; }
// Empty attribute set means everything.
return {host_udf_base::groupby_data_attribute::INPUT_VALUES,
host_udf_base::groupby_data_attribute::GROUPED_VALUES,
host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES,
host_udf_base::groupby_data_attribute::NUM_GROUPS,
host_udf_base::groupby_data_attribute::GROUP_OFFSETS,
host_udf_base::groupby_data_attribute::GROUP_LABELS};
}();
auto const& udf_base_ptr = dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<groupby_host_udf*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for groupby aggregation.");

// Do not cache udf_input, as the actual input data may change from run to run.
host_udf_base::input_map_t udf_input;
for (auto const& attr : data_attrs) {
CUDF_EXPECTS(std::holds_alternative<host_udf_base::groupby_data_attribute>(attr.value) ||
std::holds_alternative<std::unique_ptr<aggregation>>(attr.value),
"Invalid input data attribute for HOST_UDF groupby aggregation.");
if (std::holds_alternative<host_udf_base::groupby_data_attribute>(attr.value)) {
switch (std::get<host_udf_base::groupby_data_attribute>(attr.value)) {
case host_udf_base::groupby_data_attribute::INPUT_VALUES:
udf_input.emplace(attr, values);
break;
case host_udf_base::groupby_data_attribute::GROUPED_VALUES:
udf_input.emplace(attr, get_grouped_values());
break;
case host_udf_base::groupby_data_attribute::SORTED_GROUPED_VALUES:
udf_input.emplace(attr, get_sorted_values());
break;
case host_udf_base::groupby_data_attribute::NUM_GROUPS:
udf_input.emplace(attr, helper.num_groups(stream));
break;
case host_udf_base::groupby_data_attribute::GROUP_OFFSETS:
udf_input.emplace(attr, helper.group_offsets(stream));
break;
case host_udf_base::groupby_data_attribute::GROUP_LABELS:
udf_input.emplace(attr, helper.group_labels(stream));
break;
default: CUDF_UNREACHABLE("Invalid input data attribute for HOST_UDF groupby aggregation.");
}
} else { // data is result from another aggregation
auto other_agg = std::get<std::unique_ptr<aggregation>>(attr.value)->clone();
if (!udf_ptr->callback_input_values) {
udf_ptr->callback_input_values = [&]() -> column_view { return values; };
}
if (!udf_ptr->callback_grouped_values) {
udf_ptr->callback_grouped_values = [&]() -> column_view { return get_grouped_values(); };
}
if (!udf_ptr->callback_sorted_grouped_values) {
udf_ptr->callback_sorted_grouped_values = [&]() -> column_view { return get_sorted_values(); };
}
if (!udf_ptr->callback_num_groups) {
udf_ptr->callback_num_groups = [&]() -> size_type { return helper.num_groups(stream); };
}
if (!udf_ptr->callback_group_offsets) {
udf_ptr->callback_group_offsets = [&]() -> device_span<size_type const> {
return helper.group_offsets(stream);
};
}
if (!udf_ptr->callback_group_labels) {
udf_ptr->callback_group_labels = [&]() -> device_span<size_type const> {
return helper.group_labels(stream);
};
}
if (!udf_ptr->callback_compute_aggregation) {
udf_ptr->callback_compute_aggregation =
[&](std::unique_ptr<aggregation> other_agg) -> column_view {
cudf::detail::aggregation_dispatcher(other_agg->kind, *this, *other_agg);
auto result = cache.get_result(values, *other_agg);
udf_input.emplace(std::move(other_agg), std::move(result));
}
return cache.get_result(values, *other_agg);
};
}

auto output = (*udf_ptr)(udf_input, stream, mr);
CUDF_EXPECTS(std::holds_alternative<std::unique_ptr<column>>(output),
"Invalid output type from HOST_UDF groupby aggregation.");
cache.add_result(values, agg, std::get<std::unique_ptr<column>>(std::move(output)));
cache.add_result(values, agg, (*udf_ptr)(stream, mr));
}

} // namespace detail
Expand Down
48 changes: 5 additions & 43 deletions cpp/src/groupby/sort/host_udf_aggregation.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,51 +16,9 @@

#include <cudf/aggregation/host_udf.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/utilities/visitor_overload.hpp>

namespace cudf {

host_udf_base::data_attribute::data_attribute(data_attribute const& other)
: value{std::visit(cudf::detail::visitor_overload{[](auto const& val) { return value_type{val}; },
[](std::unique_ptr<aggregation> const& val) {
return value_type{val->clone()};
}},
other.value)}
{
}

std::size_t host_udf_base::data_attribute::hash::operator()(data_attribute const& attr) const
{
auto const hash_value =
std::visit(cudf::detail::visitor_overload{
[](auto const& val) { return std::hash<int>{}(static_cast<int>(val)); },
[](std::unique_ptr<aggregation> const& val) { return val->do_hash(); }},
attr.value);
return std::hash<std::size_t>{}(attr.value.index()) ^ hash_value;
}

bool host_udf_base::data_attribute::equal_to::operator()(data_attribute const& lhs,
data_attribute const& rhs) const
{
auto const& lhs_val = lhs.value;
auto const& rhs_val = rhs.value;
if (lhs_val.index() != rhs_val.index()) { return false; }
return std::visit(
cudf::detail::visitor_overload{
[](auto const& lhs_val, auto const& rhs_val) {
if constexpr (std::is_same_v<decltype(lhs_val), decltype(rhs_val)>) {
return lhs_val == rhs_val;
} else {
return false;
}
},
[](std::unique_ptr<aggregation> const& lhs_val, std::unique_ptr<aggregation> const& rhs_val) {
return lhs_val->is_equal(*rhs_val);
}},
lhs_val,
rhs_val);
}

namespace detail {

host_udf_aggregation::host_udf_aggregation(std::unique_ptr<host_udf_base> udf_ptr_)
Expand Down Expand Up @@ -99,5 +57,9 @@ template CUDF_EXPORT std::unique_ptr<aggregation> make_host_udf_aggregation<aggr
std::unique_ptr<host_udf_base>);
template CUDF_EXPORT std::unique_ptr<groupby_aggregation>
make_host_udf_aggregation<groupby_aggregation>(std::unique_ptr<host_udf_base>);
template CUDF_EXPORT std::unique_ptr<reduce_aggregation>
make_host_udf_aggregation<reduce_aggregation>(std::unique_ptr<host_udf_base>);
template CUDF_EXPORT std::unique_ptr<segmented_reduce_aggregation>
make_host_udf_aggregation<segmented_reduce_aggregation>(std::unique_ptr<host_udf_base>);

} // namespace cudf
16 changes: 13 additions & 3 deletions cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <cudf/aggregation/host_udf.hpp>
#include <cudf/column/column.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/copy.hpp>
Expand Down Expand Up @@ -144,6 +145,13 @@ struct reduce_dispatch_functor {
auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
}
case aggregation::HOST_UDF: {
auto const& udf_base_ptr =
dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<reduce_host_udf const*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for reduction.");
return (*udf_ptr)(col, output_dtype, init, stream, mr);
} // case aggregation::HOST_UDF
default: CUDF_FAIL("Unsupported reduction operator");
}
}
Expand All @@ -161,9 +169,11 @@ std::unique_ptr<scalar> reduce(column_view const& col,
cudf::data_type_error);
if (init.has_value() && !(agg.kind == aggregation::SUM || agg.kind == aggregation::PRODUCT ||
agg.kind == aggregation::MIN || agg.kind == aggregation::MAX ||
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL)) {
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL ||
agg.kind == aggregation::HOST_UDF)) {
CUDF_FAIL(
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, and ALL aggregation types");
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, ALL, and HOST_UDF "
"aggregation types");
}

// Returns default scalar if input column is empty or all null
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/reductions/segmented/reductions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/aggregation/host_udf.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
Expand Down Expand Up @@ -98,6 +100,13 @@ struct segmented_reduce_dispatch_functor {
}
case segmented_reduce_aggregation::NUNIQUE:
return segmented_nunique(col, offsets, null_handling, stream, mr);
case aggregation::HOST_UDF: {
auto const& udf_base_ptr =
dynamic_cast<cudf::detail::host_udf_aggregation const&>(agg).udf_ptr;
auto const udf_ptr = dynamic_cast<segmented_reduce_host_udf const*>(udf_base_ptr.get());
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid HOST_UDF instance for segmented reduction.");
return (*udf_ptr)(col, offsets, output_dtype, null_handling, init, stream, mr);
} // case aggregation::HOST_UDF
default: CUDF_FAIL("Unsupported aggregation type.");
}
}
Expand All @@ -117,9 +126,11 @@ std::unique_ptr<column> segmented_reduce(column_view const& segmented_values,
cudf::data_type_error);
if (init.has_value() && !(agg.kind == aggregation::SUM || agg.kind == aggregation::PRODUCT ||
agg.kind == aggregation::MIN || agg.kind == aggregation::MAX ||
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL)) {
agg.kind == aggregation::ANY || agg.kind == aggregation::ALL ||
agg.kind == aggregation::HOST_UDF)) {
CUDF_FAIL(
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, and ALL aggregation types");
"Initial value is only supported for SUM, PRODUCT, MIN, MAX, ANY, ALL, and HOST_UDF "
"aggregation types");
}

if (segmented_values.is_empty() && offsets.empty()) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,12 @@ ConfigureTest(
REDUCTIONS_TEST
reductions/collect_ops_tests.cpp
reductions/ewm_tests.cpp
reductions/host_udf_example_tests.cu
reductions/list_rank_test.cpp
reductions/rank_tests.cpp
reductions/reduction_tests.cpp
reductions/scan_tests.cpp
reductions/segmented_reduction_tests.cpp
reductions/list_rank_test.cpp
reductions/tdigest_tests.cu
GPUS 1
PERCENT 70
Expand Down
Loading
Loading