Skip to content
This repository has been archived by the owner on May 9, 2024. It is now read-only.

Commit

Permalink
Add useless sort nodes removal.
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich committed Oct 12, 2023
1 parent a45c880 commit 7a9198a
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
73 changes: 73 additions & 0 deletions omniscidb/QueryOptimizer/CanonicalizeQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,78 @@ bool isCompoundAggregate(const AggExpr* agg) {
return agg->aggType() == AggType::kStdDevSamp || agg->aggType() == AggType::kCorr;
}

/**
 * Walk down the inputs of the given node and unregister users that do not need
 * their sorted input to stay sorted.
 *
 * For every Sort found directly below a visited node, that node is erased from the
 * sort's entry in users. Aggregate inputs stop the walk; any other input is visited
 * recursively.
 *
 * Every Sort reached here must already have an entry in users, otherwise
 * std::map::at throws std::out_of_range.
 *
 * NOTE(review): an intermediate (non-sort, non-aggregate) node reached by the
 * recursion is erased from the sort's users even if that node has another consumer
 * which might rely on the order - confirm such shared nodes cannot occur.
 */
void removeFromValidSortUsers(Node* node, std::map<NodePtr, std::set<Node*>>& users) {
  for (size_t i = 0; i < node->inputCount(); ++i) {
    auto child = node->getAndOwnInput(i);

    if (child->is<Sort>()) {
      // The visited node is the direct consumer of this sort.
      users.at(child).erase(node);
      continue;
    }

    if (child->is<Aggregate>()) {
      // Do not descend below an aggregate.
      continue;
    }

    removeFromValidSortUsers(child.get(), users);
  }
}

/**
 * Find sort nodes whose result's order is going to be ignored later and remove them
 * from the DAG. It happens if the result is later aggregated or sorted again.
 *
 * The root node and sorts with a limit and/or offset are never removed. Users of a
 * removed sort are re-linked to the nearest surviving node below it, so a chain of
 * several dead sorts is collapsed correctly regardless of the order in which the
 * dead sorts are processed.
 *
 * @param dag Query DAG to transform in place. Its node list is replaced only when
 *            at least one sort has been removed.
 */
void dropDeadSorts(QueryDag& dag) {
  auto nodes = dag.getNodes();
  std::list<NodePtr> node_list(nodes.begin(), nodes.end());
  std::map<NodePtr, std::list<NodePtr>::iterator> sorts;
  std::map<NodePtr, std::set<Node*>> valid_sort_users;
  std::map<NodePtr, std::set<Node*>> all_sort_users;

  // Collect sort nodes and their users.
  for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) {
    const auto& node = *node_itr;

    // Store positions of all sort nodes to be able to remove them later from the nodes
    // list.
    if (node->is<Sort>()) {
      sorts[node] = node_itr;
      // Root node is always considered to have a valid user. We also cannot drop nodes
      // with limit and/or offset specified. The nullptr entry acts as a sentinel user
      // that is never erased.
      if (dag.getRootNode() == node.get() || node->as<Sort>()->getLimit() ||
          node->as<Sort>()->getOffset()) {
        valid_sort_users[node].insert(nullptr);
      }
    }

    for (size_t input_idx = 0; input_idx < node->inputCount(); ++input_idx) {
      auto input = node->getAndOwnInput(input_idx);
      if (input->is<Sort>()) {
        all_sort_users[input].insert(node.get());
        valid_sort_users[input].insert(node.get());
      }
    }
  }

  // Find sort and aggregate nodes and remove their inputs from valid sort users.
  for (const auto& node : node_list) {
    if (node->is<Aggregate>() || node->is<Sort>()) {
      removeFromValidSortUsers(node.get(), valid_sort_users);
    }
  }

  // Determine the full set of dead sorts before touching the DAG. This is required to
  // re-link users correctly when a dead sort's input is a dead sort too.
  std::set<NodePtr> dead_sorts;
  for (const auto& entry : valid_sort_users) {
    if (entry.second.empty()) {
      dead_sorts.insert(entry.first);
    }
  }

  // Remove sorts with no valid users.
  for (const auto& sort : dead_sorts) {
    // Skip through any dead sorts below this one so that users never end up
    // referencing a node that is going to be removed from the nodes list.
    auto replacement = sort->getAndOwnInput(0);
    while (dead_sorts.count(replacement)) {
      replacement = replacement->getAndOwnInput(0);
    }
    for (auto user : all_sort_users.at(sort)) {
      user->replaceInput(sort, replacement);
    }
    node_list.erase(sorts.at(sort));
  }

  // Any applied transformation always decreases the number of nodes.
  if (node_list.size() != nodes.size()) {
    nodes.assign(node_list.begin(), node_list.end());
    dag.setNodes(std::move(nodes));
  }
}

/**
* Base class holding interface for compound aggregate expansion.
* Compound aggregate is expanded in three steps.
Expand Down Expand Up @@ -470,6 +542,7 @@ void addWindowFunctionPreProject(
} // namespace

/**
 * Apply canonicalization passes to the query DAG in place.
 *
 * The passes are called in a fixed order: dead sort removal, compound aggregate
 * expansion, and then window function pre-project insertion.
 */
void canonicalizeQuery(QueryDag& dag) {
// Drop sorts whose order would be ignored before running the other passes.
dropDeadSorts(dag);
expandCompoundAggregates(dag);
addWindowFunctionPreProject(dag);
}
Expand Down
24 changes: 24 additions & 0 deletions omniscidb/Tests/PartitionedGroupByTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,30 @@ TEST_F(PartitionedGroupByTest, AggregationWithSort) {
compare_res_data(res, id1_vals, id2_vals, id3_vals, id4_vals, v1_sums, v2_sums);
}

// Runs an aggregation followed by two consecutive sorts with partitioned group-by
// enabled and checks the result against the expected fixture data.
TEST_F(PartitionedGroupByTest, Issue695) {
  // Snapshot the execution config; the guard is set up to write it back.
  auto saved_exec = config().exec;
  ScopeGuard restore_exec([&saved_exec]() { config().exec = saved_exec; });

  // Group-by tuning used for this scenario.
  auto& group_by = config().exec.group_by;
  group_by.default_max_groups_buffer_entry_guess = 1;
  group_by.big_group_threshold = 1;
  group_by.enable_cpu_partitioned_groupby = true;
  group_by.partitioning_buffer_size_threshold = 10;
  group_by.partitioning_group_size_threshold = 1.5;
  group_by.min_partitions = 2;
  group_by.max_partitions = 8;
  group_by.partitioning_buffer_target_size = 612;
  config().exec.enable_multifrag_execution_result = true;

  QueryBuilder query_builder(ctx(), getSchemaProvider(), configPtr());
  auto table_scan = query_builder.scan("test1");
  // Aggregate, sort by the first column, then sort again by all four key columns.
  auto query_dag =
      table_scan.agg({"id1"s, "id2"s, "id3"s, "id4"s}, {"sum(v1)"s, "sum(v2)"s})
          .sort(0)
          .sort({0, 1, 2, 3})
          .finalize();
  auto result = runQuery(std::move(query_dag));
  compare_res_data(result, id1_vals, id2_vals, id3_vals, id4_vals, v1_sums, v2_sums);
}

int main(int argc, char* argv[]) {
TestHelpers::init_logger_stderr_only(argc, argv);
testing::InitGoogleTest(&argc, argv);
Expand Down

0 comments on commit 7a9198a

Please sign in to comment.