Skip to content

Commit

Permalink
Fix scalar element indexing with new scalar return
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Jan 10, 2025
1 parent 76ca634 commit 99f171c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 34 deletions.
7 changes: 7 additions & 0 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ def memory_usage(self):
)
return n

def element_indexing(self, index: int) -> list:
result = super().element_indexing(index)
if isinstance(result, list):
return self.dtype._recursively_replace_fields(result)
else:
return result

def __setitem__(self, key, value):
if isinstance(value, list):
value = cudf.Scalar(value)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def memory_usage(self) -> int:

def element_indexing(self, index: int) -> dict:
result = super().element_indexing(index)
return dict(zip(self.dtype.fields, result.values()))
return self.dtype._recursively_replace_fields(result)

def __setitem__(self, key, value):
if isinstance(value, dict):
Expand Down
42 changes: 42 additions & 0 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,28 @@ def deserialize(cls, header: dict, frames: list):
def itemsize(self):
return self.element_type.itemsize

def _recursively_replace_fields(self, result: list) -> list:
"""
Return a new list result but with the keys of dict element by the keys in StructDtype.fields.keys().
Intended when result comes from pylibcudf without preserved nested field names.
"""
if isinstance(self.element_type, StructDtype):
return [
self.element_type._recursively_replace_fields(res)
if isinstance(res, dict)
else res
for res in result
]
elif isinstance(self.element_type, ListDtype):
return [
self.element_type._recursively_replace_fields(res)
if isinstance(res, list)
else res
for res in result
]
return result


class StructDtype(_BaseDtype):
"""
Expand Down Expand Up @@ -677,6 +699,26 @@ def itemsize(self):
for field in self._typ
)

def _recursively_replace_fields(self, result: dict) -> dict:
"""
Return a new dict result but with the keys replaced by the keys in self.fields.keys().
Intended when result comes from pylibcudf without preserved nested field names.
"""
new_result = {}
for (new_field, field_dtype), result_value in zip(
self.fields.items(), result.values()
):
if isinstance(field_dtype, StructDtype) and isinstance(
result_value, dict
):
new_result[new_field] = (
field_dtype._recursively_replace_fields(result_value)
)
else:
new_result[new_field] = result_value
return new_result


decimal_dtype_template = textwrap.dedent(
"""
Expand Down
33 changes: 0 additions & 33 deletions python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,37 +203,6 @@ def _to_plc_scalar(value: ScalarLike, dtype: Dtype) -> plc.Scalar:
return plc_scalar


def gather_metadata(
dtypes: dict[str, Any],
) -> list[plc.interop.ColumnMetadata]:
"""Convert a dict of dtypes to a list of ColumnMetadata objects.
The metadata is constructed recursively so that nested types are
represented as nested ColumnMetadata objects.
Parameters
----------
dtypes : dict
A dict mapping column names to dtypes.
Returns
-------
List[ColumnMetadata]
A list of ColumnMetadata objects.
"""
out = []
for name, dtype in dtypes.items():
v = plc.interop.ColumnMetadata(name)
if isinstance(dtype, cudf.StructDtype):
v.children_meta = gather_metadata(dtype.fields)
elif isinstance(dtype, cudf.ListDtype):
# Offsets column is unnamed and has no children
v.children_meta.append(plc.interop.ColumnMetadata(""))
v.children_meta.extend(gather_metadata({"": dtype.element_type}))
out.append(v)
return out


# Note that the metaclass below can easily be generalized for use with
# other classes, if needed in the future. Simply replace the arguments
# of the `__call__` method with `*args` and `**kwargs`. This will
Expand Down Expand Up @@ -393,8 +362,6 @@ def is_valid(self) -> bool:
return not cudf.utils.utils._is_null_host_scalar(self._host_value)

def _device_value_to_host(self) -> None:
# metadata = [gather_metadata]({"": self.dtype})[0]
# ps = plc.interop.to_arrow(self._device_value, metadata)
ps = plc.interop.to_arrow(self._device_value)
is_datetime = pa.types.is_timestamp(ps.type)
is_timedelta = pa.types.is_duration(ps.type)
Expand Down

0 comments on commit 99f171c

Please sign in to comment.