From 99f171c9227b0594b4abcf4ae6900c2bdc5e3e79 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 10 Jan 2025 15:58:22 -0800 Subject: [PATCH] Fix scalar element indexing with new scalar return --- python/cudf/cudf/core/column/lists.py | 7 +++++ python/cudf/cudf/core/column/struct.py | 2 +- python/cudf/cudf/core/dtypes.py | 42 ++++++++++++++++++++++++++ python/cudf/cudf/core/scalar.py | 33 -------------------- 4 files changed, 50 insertions(+), 34 deletions(-) diff --git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index 04b4003c510..1dbde0015c2 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -110,6 +110,13 @@ def memory_usage(self): ) return n + def element_indexing(self, index: int) -> list: + result = super().element_indexing(index) + if isinstance(result, list): + return self.dtype._recursively_replace_fields(result) + else: + return result + def __setitem__(self, key, value): if isinstance(value, list): value = cudf.Scalar(value) diff --git a/python/cudf/cudf/core/column/struct.py b/python/cudf/cudf/core/column/struct.py index 052a68cec98..2e10166295b 100644 --- a/python/cudf/cudf/core/column/struct.py +++ b/python/cudf/cudf/core/column/struct.py @@ -120,7 +120,7 @@ def memory_usage(self) -> int: def element_indexing(self, index: int) -> dict: result = super().element_indexing(index) - return dict(zip(self.dtype.fields, result.values())) + return self.dtype._recursively_replace_fields(result) def __setitem__(self, key, value): if isinstance(value, dict): diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index ce7fb968069..32e695b32e3 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -518,6 +518,28 @@ def deserialize(cls, header: dict, frames: list): def itemsize(self): return self.element_type.itemsize + def _recursively_replace_fields(self, result: list) -> 
list: + """ + Return a new list result but with the keys of each dict element replaced by the keys in StructDtype.fields.keys(). + + Intended when result comes from pylibcudf without preserved nested field names. + """ + if isinstance(self.element_type, StructDtype): + return [ + self.element_type._recursively_replace_fields(res) + if isinstance(res, dict) + else res + for res in result + ] + elif isinstance(self.element_type, ListDtype): + return [ + self.element_type._recursively_replace_fields(res) + if isinstance(res, list) + else res + for res in result + ] + return result + class StructDtype(_BaseDtype): """ @@ -677,6 +699,26 @@ def itemsize(self): for field in self._typ ) + def _recursively_replace_fields(self, result: dict) -> dict: + """ + Return a new dict result but with the keys replaced by the keys in self.fields.keys(). + + Intended when result comes from pylibcudf without preserved nested field names. + """ + new_result = {} + for (new_field, field_dtype), result_value in zip( + self.fields.items(), result.values() + ): + if isinstance(field_dtype, StructDtype) and isinstance( + result_value, dict + ): + new_result[new_field] = ( + field_dtype._recursively_replace_fields(result_value) + ) + else: + new_result[new_field] = result_value + return new_result + decimal_dtype_template = textwrap.dedent( """ diff --git a/python/cudf/cudf/core/scalar.py b/python/cudf/cudf/core/scalar.py index 522b8a28541..719fb25f85f 100644 --- a/python/cudf/cudf/core/scalar.py +++ b/python/cudf/cudf/core/scalar.py @@ -203,37 +203,6 @@ def _to_plc_scalar(value: ScalarLike, dtype: Dtype) -> plc.Scalar: return plc_scalar -def gather_metadata( - dtypes: dict[str, Any], -) -> list[plc.interop.ColumnMetadata]: - """Convert a dict of dtypes to a list of ColumnMetadata objects. - - The metadata is constructed recursively so that nested types are - represented as nested ColumnMetadata objects. - - Parameters - ---------- - dtypes : dict - A dict mapping column names to dtypes. 
- - Returns - ------- - List[ColumnMetadata] - A list of ColumnMetadata objects. - """ - out = [] - for name, dtype in dtypes.items(): - v = plc.interop.ColumnMetadata(name) - if isinstance(dtype, cudf.StructDtype): - v.children_meta = gather_metadata(dtype.fields) - elif isinstance(dtype, cudf.ListDtype): - # Offsets column is unnamed and has no children - v.children_meta.append(plc.interop.ColumnMetadata("")) - v.children_meta.extend(gather_metadata({"": dtype.element_type})) - out.append(v) - return out - - # Note that the metaclass below can easily be generalized for use with # other classes, if needed in the future. Simply replace the arguments # of the `__call__` method with `*args` and `**kwargs`. This will @@ -393,8 +362,6 @@ def is_valid(self) -> bool: return not cudf.utils.utils._is_null_host_scalar(self._host_value) def _device_value_to_host(self) -> None: - # metadata = [gather_metadata]({"": self.dtype})[0] - # ps = plc.interop.to_arrow(self._device_value, metadata) ps = plc.interop.to_arrow(self._device_value) is_datetime = pa.types.is_timestamp(ps.type) is_timedelta = pa.types.is_duration(ps.type)