From bf3c5a4b66617a71f27a2db79a80a52901b24dc1 Mon Sep 17 00:00:00 2001 From: Muhammad Haseeb <14217455+mhaseeb123@users.noreply.github.com> Date: Fri, 13 Dec 2024 00:00:37 +0000 Subject: [PATCH] Simplify StatsAST expression transformer with ast::tree --- cpp/src/io/parquet/predicate_pushdown.cpp | 51 +++++++++++------------ 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index b0cbabf1c12..9047ff9169b 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -265,7 +265,6 @@ class stats_expression_converter : public ast::detail::expression_transformer { */ std::reference_wrapper visit(ast::literal const& expr) override { - _stats_expr = std::reference_wrapper(expr); return expr; } @@ -278,7 +277,6 @@ class stats_expression_converter : public ast::detail::expression_transformer { "Statistics AST supports only left table"); CUDF_EXPECTS(expr.get_column_index() < _num_columns, "Column index cannot be more than number of columns in the table"); - _stats_expr = std::reference_wrapper(expr); return expr; } @@ -307,6 +305,9 @@ class stats_expression_converter : public ast::detail::expression_transformer { CUDF_EXPECTS(dynamic_cast(&operands[1].get()) != nullptr, "Second operand of binary operation with column reference must be a literal"); v->accept(*this); + // Push literal into the ast::tree + auto const& literal = + _stats_expr.push(*dynamic_cast(&operands[1].get())); auto const col_index = v->get_column_index(); switch (op) { /* transform to stats conditions. op(col, literal) @@ -318,34 +319,33 @@ class stats_expression_converter : public ast::detail::expression_transformer { col1 <= val --> vmin <= val */ case ast_operator::EQUAL: { - auto const& vmin = _col_ref.emplace_back(col_index * 2); - auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1); - auto const& op1 = - _operators.emplace_back(ast_operator::LESS_EQUAL, vmin, operands[1].get()); - auto const& op2 = - _operators.emplace_back(ast_operator::GREATER_EQUAL, vmax, operands[1].get()); - _operators.emplace_back(ast::ast_operator::LOGICAL_AND, op1, op2); + auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2}); + auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1}); + _stats_expr.push(ast::operation{ + ast::ast_operator::LOGICAL_AND, + _stats_expr.push(ast::operation{ast_operator::GREATER_EQUAL, vmax, literal}), + _stats_expr.push(ast::operation{ast_operator::LESS_EQUAL, vmin, literal})}); break; } case ast_operator::NOT_EQUAL: { - auto const& vmin = _col_ref.emplace_back(col_index * 2); - auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1); - auto const& op1 = _operators.emplace_back(ast_operator::NOT_EQUAL, vmin, vmax); - auto const& op2 = - _operators.emplace_back(ast_operator::NOT_EQUAL, vmax, operands[1].get()); - _operators.emplace_back(ast_operator::LOGICAL_OR, op1, op2); + auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2}); + auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1}); + _stats_expr.push(ast::operation{ + ast_operator::LOGICAL_OR, + _stats_expr.push(ast::operation{ast_operator::NOT_EQUAL, vmin, vmax}), + _stats_expr.push(ast::operation{ast_operator::NOT_EQUAL, vmax, literal})}); break; } case ast_operator::LESS: [[fallthrough]]; case ast_operator::LESS_EQUAL: { - auto const& vmin = _col_ref.emplace_back(col_index * 2); - _operators.emplace_back(op, vmin, operands[1].get()); + auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2}); + _stats_expr.push(ast::operation{op, vmin, literal}); break; } case ast_operator::GREATER: [[fallthrough]]; case ast_operator::GREATER_EQUAL: { - auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1); - _operators.emplace_back(op, vmax, operands[1].get()); + auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1}); + _stats_expr.push(ast::operation{op, vmax, literal}); break; } default: CUDF_FAIL("Unsupported operation in Statistics AST"); @@ -353,13 +353,12 @@ class stats_expression_converter : public ast::detail::expression_transformer { } else { auto new_operands = visit_operands(operands); if (cudf::ast::detail::ast_operator_arity(op) == 2) { - _operators.emplace_back(op, new_operands.front(), new_operands.back()); + _stats_expr.push(ast::operation{op, new_operands.front(), new_operands.back()}); } else if (cudf::ast::detail::ast_operator_arity(op) == 1) { - _operators.emplace_back(op, new_operands.front()); + _stats_expr.push(ast::operation{op, new_operands.front()}); } } - _stats_expr = std::reference_wrapper(_operators.back()); - return std::reference_wrapper(_operators.back()); + return _stats_expr.back(); } /** @@ -369,7 +368,7 @@ class stats_expression_converter : public ast::detail::expression_transformer { */ [[nodiscard]] std::reference_wrapper get_stats_expr() const { - return _stats_expr.value().get(); + return _stats_expr.back(); } private: @@ -383,10 +382,8 @@ class stats_expression_converter : public ast::detail::expression_transformer { } return transformed_operands; } - std::optional> _stats_expr; + ast::tree _stats_expr; size_type _num_columns; - std::list _col_ref; - std::list _operators; }; } // namespace