Skip to content

Commit

Permalink
Simplify StatsAST expression transformer with ast::tree
Browse files Browse the repository at this point in the history
  • Loading branch information
mhaseeb123 committed Dec 13, 2024
1 parent 92652be commit bf3c5a4
Showing 1 changed file with 24 additions and 27 deletions.
51 changes: 24 additions & 27 deletions cpp/src/io/parquet/predicate_pushdown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ class stats_expression_converter : public ast::detail::expression_transformer {
*/
std::reference_wrapper<ast::expression const> visit(ast::literal const& expr) override
{
_stats_expr = std::reference_wrapper<ast::expression const>(expr);
return expr;
}

Expand All @@ -278,7 +277,6 @@ class stats_expression_converter : public ast::detail::expression_transformer {
"Statistics AST supports only left table");
CUDF_EXPECTS(expr.get_column_index() < _num_columns,
"Column index cannot be more than number of columns in the table");
_stats_expr = std::reference_wrapper<ast::expression const>(expr);
return expr;
}

Expand Down Expand Up @@ -307,6 +305,9 @@ class stats_expression_converter : public ast::detail::expression_transformer {
CUDF_EXPECTS(dynamic_cast<ast::literal const*>(&operands[1].get()) != nullptr,
"Second operand of binary operation with column reference must be a literal");
v->accept(*this);
// Push literal into the ast::tree
auto const& literal =
_stats_expr.push(*dynamic_cast<ast::literal const*>(&operands[1].get()));
auto const col_index = v->get_column_index();
switch (op) {
/* transform to stats conditions. op(col, literal)
Expand All @@ -318,48 +319,46 @@ class stats_expression_converter : public ast::detail::expression_transformer {
col1 <= val --> vmin <= val
*/
case ast_operator::EQUAL: {
auto const& vmin = _col_ref.emplace_back(col_index * 2);
auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1);
auto const& op1 =
_operators.emplace_back(ast_operator::LESS_EQUAL, vmin, operands[1].get());
auto const& op2 =
_operators.emplace_back(ast_operator::GREATER_EQUAL, vmax, operands[1].get());
_operators.emplace_back(ast::ast_operator::LOGICAL_AND, op1, op2);
auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2});
auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1});
_stats_expr.push(ast::operation{
ast::ast_operator::LOGICAL_AND,
_stats_expr.push(ast::operation{ast_operator::GREATER_EQUAL, vmax, literal}),
_stats_expr.push(ast::operation{ast_operator::LESS_EQUAL, vmin, literal})});
break;
}
case ast_operator::NOT_EQUAL: {
auto const& vmin = _col_ref.emplace_back(col_index * 2);
auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1);
auto const& op1 = _operators.emplace_back(ast_operator::NOT_EQUAL, vmin, vmax);
auto const& op2 =
_operators.emplace_back(ast_operator::NOT_EQUAL, vmax, operands[1].get());
_operators.emplace_back(ast_operator::LOGICAL_OR, op1, op2);
auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2});
auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1});
_stats_expr.push(ast::operation{
ast_operator::LOGICAL_OR,
_stats_expr.push(ast::operation{ast_operator::NOT_EQUAL, vmin, vmax}),
_stats_expr.push(ast::operation{ast_operator::NOT_EQUAL, vmax, literal})});
break;
}
case ast_operator::LESS: [[fallthrough]];
case ast_operator::LESS_EQUAL: {
auto const& vmin = _col_ref.emplace_back(col_index * 2);
_operators.emplace_back(op, vmin, operands[1].get());
auto const& vmin = _stats_expr.push(ast::column_reference{col_index * 2});
_stats_expr.push(ast::operation{op, vmin, literal});
break;
}
case ast_operator::GREATER: [[fallthrough]];
case ast_operator::GREATER_EQUAL: {
auto const& vmax = _col_ref.emplace_back(col_index * 2 + 1);
_operators.emplace_back(op, vmax, operands[1].get());
auto const& vmax = _stats_expr.push(ast::column_reference{col_index * 2 + 1});
_stats_expr.push(ast::operation{op, vmax, literal});
break;
}
default: CUDF_FAIL("Unsupported operation in Statistics AST");
};
} else {
auto new_operands = visit_operands(operands);
if (cudf::ast::detail::ast_operator_arity(op) == 2) {
_operators.emplace_back(op, new_operands.front(), new_operands.back());
_stats_expr.push(ast::operation{op, new_operands.front(), new_operands.back()});
} else if (cudf::ast::detail::ast_operator_arity(op) == 1) {
_operators.emplace_back(op, new_operands.front());
_stats_expr.push(ast::operation{op, new_operands.front()});
}
}
_stats_expr = std::reference_wrapper<ast::expression const>(_operators.back());
return std::reference_wrapper<ast::expression const>(_operators.back());
return _stats_expr.back();
}

/**
Expand All @@ -369,7 +368,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {
*/
[[nodiscard]] std::reference_wrapper<ast::expression const> get_stats_expr() const
{
return _stats_expr.value().get();
return _stats_expr.back();
}

private:
Expand All @@ -383,10 +382,8 @@ class stats_expression_converter : public ast::detail::expression_transformer {
}
return transformed_operands;
}
std::optional<std::reference_wrapper<ast::expression const>> _stats_expr;
ast::tree _stats_expr;
size_type _num_columns;
std::list<ast::column_reference> _col_ref;
std::list<ast::operation> _operators;
};
} // namespace

Expand Down

0 comments on commit bf3c5a4

Please sign in to comment.