diff --git a/omniscidb/QueryOptimizer/CanonicalizeQuery.cpp b/omniscidb/QueryOptimizer/CanonicalizeQuery.cpp index abf0d4b1e..bc368f3d0 100644 --- a/omniscidb/QueryOptimizer/CanonicalizeQuery.cpp +++ b/omniscidb/QueryOptimizer/CanonicalizeQuery.cpp @@ -22,6 +22,78 @@ bool isCompoundAggregate(const AggExpr* agg) { return agg->aggType() == AggType::kStdDevSamp || agg->aggType() == AggType::kCorr; } /** + * Walks the inputs of `node`. When an input passes the first is() check, `node` is + * erased from that input's entry in `users`; `users.at()` is used, so such an input + * is expected to be present in the map already. Any other input is walked + * recursively unless it passes the second is() check; the recursive call erases the + * intermediate input itself from the user sets of its own inputs. + * NOTE(review): judging by the caller, the first check presumably selects sort + * nodes; the type arguments are not visible here, confirm. + */ +void removeFromValidSortUsers(Node* node, std::map>& users) { + for (size_t input_idx = 0; input_idx < node->inputCount(); ++input_idx) { + auto input = node->getAndOwnInput(input_idx); + if (input->is()) { + users.at(input).erase(node); + } else if (!input->is()) { + removeFromValidSortUsers(input.get(), users); + } + } +} + +/** + * Find sort nodes whose result's order is going to be ignored later and remove them + * from the DAG. It happens if the result is later aggregated or sorted again. + * + * A sort that is the root node, or that has a limit and/or offset, is always kept. + * Each user of a removed sort is re-pointed to the sort's input 0, and the DAG node + * list is replaced only when at least one node was removed. + * + * NOTE(review): if a user of a removed sort is itself a removed sort, the rewiring + * result looks dependent on the iteration order of valid_sort_users; confirm that + * chained dead sorts are handled. + */ +void dropDeadSorts(QueryDag& dag) { + auto nodes = dag.getNodes(); + std::list node_list(nodes.begin(), nodes.end()); + std::map::iterator> sorts; + std::map> valid_sort_users; + std::map> all_sort_users; + + // Collect sort nodes and their users. + for (auto node_itr = node_list.begin(); node_itr != node_list.end(); ++node_itr) { + const auto node = *node_itr; + + // Store positions of all sort nodes to be able to remove them later from the nodes + // list. + if (node->is()) { + sorts[node] = node_itr; + // Root node is always considered to have a valid user. We also cannot drop nodes + // with limit and/or offset specified. + if (dag.getRootNode() == node.get() || node->as()->getLimit() || + node->as()->getOffset()) { + valid_sort_users[node].insert(nullptr); + } + } + + for (size_t input_idx = 0; input_idx < node->inputCount(); ++input_idx) { + auto input = node->getAndOwnInput(input_idx); + if (input->is()) { + all_sort_users[input].insert(node.get()); + valid_sort_users[input].insert(node.get()); + } + } + } + + // Find sort and aggregate nodes and remove their inputs from valid sort users. 
+ for (auto node : node_list) { + if (node->is() || node->is()) { + removeFromValidSortUsers(node.get(), valid_sort_users); + } + } + + // Remove sorts with no valid users. + for (auto& pr : valid_sort_users) { + if (pr.second.empty()) { + auto sort = pr.first; + for (auto user : all_sort_users.at(sort)) { + user->replaceInput(sort, sort->getAndOwnInput(0)); + } + node_list.erase(sorts.at(sort)); + } + } + + // Any applied transformation always decreases the number of nodes. + if (node_list.size() != nodes.size()) { + nodes.assign(node_list.begin(), node_list.end()); + dag.setNodes(std::move(nodes)); + } +} + /** * Base class holding interface for compound aggregate expansion. * Compound aggregate is expanded in three steps. @@ -470,6 +542,7 @@ void addWindowFunctionPreProject( } // namespace void canonicalizeQuery(QueryDag& dag) { + dropDeadSorts(dag); expandCompoundAggregates(dag); addWindowFunctionPreProject(dag); } diff --git a/omniscidb/Tests/PartitionedGroupByTest.cpp b/omniscidb/Tests/PartitionedGroupByTest.cpp index 8b01eb718..8afb772ef 100644 --- a/omniscidb/Tests/PartitionedGroupByTest.cpp +++ b/omniscidb/Tests/PartitionedGroupByTest.cpp @@ -173,6 +173,30 @@ TEST_F(PartitionedGroupByTest, AggregationWithSort) { compare_res_data(res, id1_vals, id2_vals, id3_vals, id4_vals, v1_sums, v2_sums); } +TEST_F(PartitionedGroupByTest, Issue695) { + auto old_exec = config().exec; + ScopeGuard g([&old_exec]() { config().exec = old_exec; }); + + config().exec.group_by.default_max_groups_buffer_entry_guess = 1; + config().exec.group_by.big_group_threshold = 1; + config().exec.group_by.enable_cpu_partitioned_groupby = true; + config().exec.group_by.partitioning_buffer_size_threshold = 10; + config().exec.group_by.partitioning_group_size_threshold = 1.5; + config().exec.group_by.min_partitions = 2; + config().exec.group_by.max_partitions = 8; + config().exec.group_by.partitioning_buffer_target_size = 612; + config().exec.enable_multifrag_execution_result = true; + 
QueryBuilder builder(ctx(), getSchemaProvider(), configPtr()); + auto scan = builder.scan("test1"); + auto dag1 = scan.agg({"id1"s, "id2"s, "id3"s, "id4"s}, {"sum(v1)"s, "sum(v2)"s}) + .sort(0) + .sort({0, 1, 2, 3}) + .finalize(); + auto res = runQuery(std::move(dag1)); + compare_res_data(res, id1_vals, id2_vals, id3_vals, id4_vals, v1_sums, v2_sums); +} + int main(int argc, char* argv[]) { TestHelpers::init_logger_stderr_only(argc, argv); testing::InitGoogleTest(&argc, argv);