Skip to content

Commit

Permalink
address review
Browse files (browse the repository at this point in the history)
  • Loading branch information
Matt711 committed Jan 14, 2025
1 parent f699acc commit 9ca07c2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
15 changes: 7 additions & 8 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8080,7 +8080,7 @@ def value_counts(
return result

@_performance_tracking
def to_pylibcudf(self, copy=False):
def to_pylibcudf(self, copy: bool = False) -> tuple[plc.Table, dict]:
"""
Convert this DataFrame to a pylibcudf.Table.
Expand Down Expand Up @@ -8114,16 +8114,14 @@ def to_pylibcudf(self, copy=False):
"""
if copy:
raise NotImplementedError("copy=True is not supported")
metadata = {}
metadata["index"] = self.index
metadata["column_names"] = self.columns
metadata = {"index": self.index, "columns": self.columns}
return plc.Table(
[col.to_pylibcudf(mode="write") for col in self._columns]
), metadata

@classmethod
@_performance_tracking
def from_pylibcudf(cls, table: plc.Table, metadata: dict):
def from_pylibcudf(cls, table: plc.Table, metadata: dict) -> Self:
"""
Create a DataFrame from a pylibcudf.Table.
Expand All @@ -8148,19 +8146,20 @@ def from_pylibcudf(cls, table: plc.Table, metadata: dict):
the data and mask buffers of the pylibcudf columns, so the newly created
object is not tied to the lifetime of the original pylibcudf.Table.
"""
if metadata is None:
if not isinstance(metadata, dict):
raise ValueError("Must at least pass metadata with column names")
columns = table.columns()
df = cls._from_data(
{
name: cudf.core.column.ColumnBase.from_pylibcudf(
col, data_ptr_exposed=True
)
for name, col in zip(metadata["column_names"], columns)
for name, col in zip(metadata["columns"], columns)
}
)
for key in metadata:
setattr(df, key, metadata[key])
if key in {"index", "columns"}:
setattr(df, key, metadata[key])
return df


Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3845,7 +3845,7 @@ def to_pylibcudf(self, copy=False) -> tuple[plc.Column, dict]:

@classmethod
@_performance_tracking
def from_pylibcudf(cls, col: plc.Column, metadata=None):
def from_pylibcudf(cls, col: plc.Column, metadata=None) -> Self:
"""
Create a Series from a pylibcudf.Column.
Expand All @@ -3871,7 +3871,7 @@ def from_pylibcudf(cls, col: plc.Column, metadata=None):
col, data_ptr_exposed=True
)
)
if metadata:
if isinstance(metadata, dict):
for key in metadata:
setattr(cudf_col, key, metadata[key])
return cudf_col
Expand Down
3 changes: 1 addition & 2 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11258,6 +11258,5 @@ def test_dataframe_multiindex_column_names(names):
@pytest.mark.parametrize("pdf", _dataframe_na_data())
def test_roundtrip_dataframe_plc_table(pdf):
expect = cudf.DataFrame.from_pandas(pdf)
plc_table, metadata = expect.to_pylibcudf()
actual = cudf.DataFrame.from_pylibcudf(plc_table, metadata=metadata)
actual = cudf.DataFrame.from_pylibcudf(*expect.to_pylibcudf())
assert_eq(expect, actual)
3 changes: 1 addition & 2 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3008,6 +3008,5 @@ def test_dtype_dtypes_equal():
@pytest.mark.parametrize("ps", _series_na_data())
def test_roundtrip_series_plc_column(ps):
expect = cudf.Series(ps)
plc_col, metadata = expect.to_pylibcudf()
actual = cudf.Series.from_pylibcudf(plc_col, metadata=metadata)
actual = cudf.Series.from_pylibcudf(*expect.to_pylibcudf())
assert_eq(expect, actual)

0 comments on commit 9ca07c2

Please sign in to comment.