From 14e250864b95c0813c7171d65d6256fbdffedff8 Mon Sep 17 00:00:00 2001 From: Matthew Murray Date: Thu, 9 Jan 2025 07:38:42 -0800 Subject: [PATCH] add a test, clean up reorder gather maps --- python/cudf_polars/cudf_polars/dsl/ir.py | 6 ------ python/cudf_polars/tests/test_join.py | 17 ++++------------- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 6d5684edbc3..525bed5453a 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -1139,9 +1139,6 @@ def __init__( self.options = options self.children = (left, right) self._non_child_args = (self.left_on, self.right_on, self.options) - # # TODO: Implement maintain_order - # if options[5] != "none": - # raise NotImplementedError("maintain_order not implemented yet") if any( isinstance(e.value, expr.Literal) for e in itertools.chain(self.left_on, self.right_on) @@ -1224,7 +1221,6 @@ def _reorder_maps( dt = plc.interop.to_arrow(plc.types.SIZE_TYPE) init = plc.interop.from_arrow(pa.scalar(0, type=dt)) step = plc.interop.from_arrow(pa.scalar(1, type=dt)) - if maintain_order in {"none", "left_right", "right_left"}: left_order = plc.copying.gather( plc.Table([plc.filling.sequence(left_rows, init, step)]), @@ -1256,8 +1252,6 @@ def _reorder_maps( sort_keys = left_order.columns() + right_order.columns() elif maintain_order == "right_left": sort_keys = right_order.columns() + left_order.columns() - else: - sort_keys = [] return plc.sorting.stable_sort_by_key( plc.Table([lg, rg]), plc.Table(sort_keys), diff --git a/python/cudf_polars/tests/test_join.py b/python/cudf_polars/tests/test_join.py index 584ac549ddd..68733853117 100644 --- a/python/cudf_polars/tests/test_join.py +++ b/python/cudf_polars/tests/test_join.py @@ -53,28 +53,19 @@ def right(): ) -@pytest.mark.parametrize( - "maintain_order", ["left", "left_right", "right_left", "right"] -) -def test_join_maintain_order_param_unsupported(left, right, maintain_order): - q = left.join(right, on=pl.col("a"), how="inner", maintain_order=maintain_order) - - assert_ir_translation_raises(q, NotImplementedError) - - @pytest.mark.parametrize( "join_expr", [ pl.col("a"), + pl.col("a") * 2, + [pl.col("a"), pl.col("c") + 1], ["c", "a"], ], ) @pytest.mark.parametrize( - "maintain_order", ["none", "left", "right", "left_right", "right_left"] + "maintain_order", ["left", "left_right", "right_left", "right"] ) -def test_join_preserving_different_orderings( - left, right, how, join_expr, maintain_order -): +def test_order_preserving_joins(left, right, how, join_expr, maintain_order): query = left.join(right, on=join_expr, how=how, maintain_order=maintain_order) assert_gpu_result_equal(query)