Skip to content

Commit

Permalink
Adjust search_sorted output return
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Jan 27, 2025
1 parent 622b4ef commit b3aa5cd
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 32 deletions.
20 changes: 10 additions & 10 deletions python/cudf/cudf/core/_internals/search.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

import pylibcudf as plc

from cudf.core.buffer import acquire_spill_lock
from cudf.core.column import ColumnBase

if TYPE_CHECKING:
from cudf.core.column import ColumnBase


@acquire_spill_lock()
Expand All @@ -16,7 +18,7 @@ def search_sorted(
side: Literal["left", "right"],
ascending: bool = True,
na_position: Literal["first", "last"] = "last",
) -> ColumnBase:
) -> plc.Column:
"""Find indices where elements should be inserted to maintain order
Parameters
Expand All @@ -43,11 +45,9 @@ def search_sorted(
plc.search,
"lower_bound" if side == "left" else "upper_bound",
)
return ColumnBase.from_pylibcudf(
func(
plc.Table([col.to_pylibcudf(mode="read") for col in source]),
plc.Table([col.to_pylibcudf(mode="read") for col in values]),
column_order,
null_precedence,
)
return func(
plc.Table([col.to_pylibcudf(mode="read") for col in source]),
plc.Table([col.to_pylibcudf(mode="read") for col in values]),
column_order,
null_precedence,
)
14 changes: 8 additions & 6 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,12 +1812,14 @@ def searchsorted(
raise ValueError(
"Column searchsorted expects values to be column of same dtype"
)
return search.search_sorted( # type: ignore[return-value]
[self],
[value],
side=side,
ascending=ascending,
na_position=na_position,
return ColumnBase.from_pylibcudf(
search.search_sorted( # type: ignore[return-value]
[self],
[value],
side=side,
ascending=ascending,
na_position=na_position,
)
)

def unique(self) -> Self:
Expand Down
14 changes: 8 additions & 6 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,12 +1348,14 @@ def searchsorted(
for val, common_dtype in zip(values, common_dtype_list)
]

outcol = search.search_sorted(
sources,
values,
side,
ascending=ascending,
na_position=na_position,
outcol = ColumnBase.from_pylibcudf(
search.search_sorted(
sources,
values,
side,
ascending=ascending,
na_position=na_position,
)
)

# Return result as cupy array if the values is non-scalar
Expand Down
24 changes: 14 additions & 10 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,21 @@ def _lexsorted_equal_range(
else:
sort_inds = None
sort_vals = idx
lower_bound = search.search_sorted(
list(sort_vals._columns),
keys,
side="left",
ascending=sort_vals.is_monotonic_increasing,
lower_bound = ColumnBase.from_pylibcudf(
search.search_sorted(
list(sort_vals._columns),
keys,
side="left",
ascending=sort_vals.is_monotonic_increasing,
)
).element_indexing(0)
upper_bound = search.search_sorted(
list(sort_vals._columns),
keys,
side="right",
ascending=sort_vals.is_monotonic_increasing,
upper_bound = ColumnBase.from_pylibcudf(
search.search_sorted(
list(sort_vals._columns),
keys,
side="right",
ascending=sort_vals.is_monotonic_increasing,
)
).element_indexing(0)

return lower_bound, upper_bound, sort_inds
Expand Down

0 comments on commit b3aa5cd

Please sign in to comment.