From b3aa5cdd2d3b9001850a6126e61b52bf81441e93 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 27 Jan 2025 11:03:43 -0800 Subject: [PATCH] Adjust search_sorted output return --- python/cudf/cudf/core/_internals/search.py | 20 +++++++++--------- python/cudf/cudf/core/column/column.py | 14 +++++++------ python/cudf/cudf/core/frame.py | 14 +++++++------ python/cudf/cudf/core/index.py | 24 +++++++++++++--------- 4 files changed, 40 insertions(+), 32 deletions(-) diff --git a/python/cudf/cudf/core/_internals/search.py b/python/cudf/cudf/core/_internals/search.py index aa410c36575..bee198800e7 100644 --- a/python/cudf/cudf/core/_internals/search.py +++ b/python/cudf/cudf/core/_internals/search.py @@ -1,12 +1,14 @@ # Copyright (c) 2020-2025, NVIDIA CORPORATION. from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING, Literal import pylibcudf as plc from cudf.core.buffer import acquire_spill_lock -from cudf.core.column import ColumnBase + +if TYPE_CHECKING: + from cudf.core.column import ColumnBase @acquire_spill_lock() @@ -16,7 +18,7 @@ def search_sorted( side: Literal["left", "right"], ascending: bool = True, na_position: Literal["first", "last"] = "last", -) -> ColumnBase: +) -> plc.Column: """Find indices where elements should be inserted to maintain order Parameters @@ -43,11 +45,9 @@ def search_sorted( plc.search, "lower_bound" if side == "left" else "upper_bound", ) - return ColumnBase.from_pylibcudf( - func( - plc.Table([col.to_pylibcudf(mode="read") for col in source]), - plc.Table([col.to_pylibcudf(mode="read") for col in values]), - column_order, - null_precedence, - ) + return func( + plc.Table([col.to_pylibcudf(mode="read") for col in source]), + plc.Table([col.to_pylibcudf(mode="read") for col in values]), + column_order, + null_precedence, ) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 99c9b1133ae..d895a87054f 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -1812,12 +1812,14 @@ def searchsorted( raise ValueError( "Column searchsorted expects values to be column of same dtype" ) - return search.search_sorted( # type: ignore[return-value] - [self], - [value], - side=side, - ascending=ascending, - na_position=na_position, + return ColumnBase.from_pylibcudf( + search.search_sorted( # type: ignore[return-value] + [self], + [value], + side=side, + ascending=ascending, + na_position=na_position, + ) ) def unique(self) -> Self: diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index acea8991f47..b9d5b0403da 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1348,12 +1348,14 @@ def searchsorted( for val, common_dtype in zip(values, common_dtype_list) ] - outcol = search.search_sorted( - sources, - values, - side, - ascending=ascending, - na_position=na_position, + outcol = ColumnBase.from_pylibcudf( + search.search_sorted( + sources, + values, + side, + ascending=ascending, + na_position=na_position, + ) ) # Return result as cupy array if the values is non-scalar diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 278c8b24e86..e883569a047 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -121,17 +121,21 @@ def _lexsorted_equal_range( else: sort_inds = None sort_vals = idx - lower_bound = search.search_sorted( - list(sort_vals._columns), - keys, - side="left", - ascending=sort_vals.is_monotonic_increasing, + lower_bound = ColumnBase.from_pylibcudf( + search.search_sorted( + list(sort_vals._columns), + keys, + side="left", + ascending=sort_vals.is_monotonic_increasing, + ) ).element_indexing(0) - upper_bound = search.search_sorted( - list(sort_vals._columns), - keys, - side="right", - ascending=sort_vals.is_monotonic_increasing, + upper_bound = ColumnBase.from_pylibcudf( + search.search_sorted( + list(sort_vals._columns), + keys, + side="right", + ascending=sort_vals.is_monotonic_increasing, + ) ).element_indexing(0) return lower_bound, upper_bound, sort_inds