Skip to content

Commit

Permalink
add a test, clean up reorder gather maps
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Jan 9, 2025
1 parent 6a10590 commit 14e2508
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 19 deletions.
6 changes: 0 additions & 6 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,9 +1139,6 @@ def __init__(
self.options = options
self.children = (left, right)
self._non_child_args = (self.left_on, self.right_on, self.options)
# # TODO: Implement maintain_order
# if options[5] != "none":
# raise NotImplementedError("maintain_order not implemented yet")
if any(
isinstance(e.value, expr.Literal)
for e in itertools.chain(self.left_on, self.right_on)
Expand Down Expand Up @@ -1224,7 +1221,6 @@ def _reorder_maps(
dt = plc.interop.to_arrow(plc.types.SIZE_TYPE)
init = plc.interop.from_arrow(pa.scalar(0, type=dt))
step = plc.interop.from_arrow(pa.scalar(1, type=dt))

if maintain_order in {"none", "left_right", "right_left"}:
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]),
Expand Down Expand Up @@ -1256,8 +1252,6 @@ def _reorder_maps(
sort_keys = left_order.columns() + right_order.columns()
elif maintain_order == "right_left":
sort_keys = right_order.columns() + left_order.columns()
else:
sort_keys = []
return plc.sorting.stable_sort_by_key(
plc.Table([lg, rg]),
plc.Table(sort_keys),
Expand Down
17 changes: 4 additions & 13 deletions python/cudf_polars/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,19 @@ def right():
)


@pytest.mark.parametrize(
"maintain_order", ["left", "left_right", "right_left", "right"]
)
def test_join_maintain_order_param_unsupported(left, right, maintain_order):
q = left.join(right, on=pl.col("a"), how="inner", maintain_order=maintain_order)

assert_ir_translation_raises(q, NotImplementedError)


@pytest.mark.parametrize(
"join_expr",
[
pl.col("a"),
pl.col("a") * 2,
[pl.col("a"), pl.col("c") + 1],
["c", "a"],
],
)
@pytest.mark.parametrize(
"maintain_order", ["none", "left", "right", "left_right", "right_left"]
"maintain_order", ["left", "left_right", "right_left", "right"]
)
def test_join_preserving_different_orderings(
left, right, how, join_expr, maintain_order
):
def test_order_preserving_joins(left, right, how, join_expr, maintain_order):
query = left.join(right, on=join_expr, how=how, maintain_order=maintain_order)
assert_gpu_result_equal(query)

Expand Down

0 comments on commit 14e2508

Please sign in to comment.