Skip to content

Commit

Permalink
Fix to avoid requiring serialization for Agg
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Jan 10, 2025
1 parent bba8b3f commit 2d37c08
Showing 1 changed file with 4 additions and 51 deletions.
55 changes: 4 additions & 51 deletions python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


class Agg(Expr):
__slots__ = ("name", "op", "options", "request")
__slots__ = ("name", "options", "request")
_non_child = ("dtype", "name", "options")

def __init__(
Expand All @@ -46,58 +46,11 @@ def __init__(
raise NotImplementedError(
f"Unsupported aggregation {name=}"
) # pragma: no cover; all valid aggs are supported
# TODO: nan handling in groupby case
if name == "min":
req = plc.aggregation.min()
elif name == "max":
req = plc.aggregation.max()
elif name == "median":
req = plc.aggregation.median()
elif name == "n_unique":
# TODO: datatype of result
req = plc.aggregation.nunique(null_handling=plc.types.NullPolicy.INCLUDE)
elif name == "first" or name == "last":
req = None
elif name == "mean":
req = plc.aggregation.mean()
elif name == "sum":
req = plc.aggregation.sum()
elif name == "std":
# TODO: handle nans
req = plc.aggregation.std(ddof=options)
elif name == "var":
# TODO: handle nans
req = plc.aggregation.variance(ddof=options)
elif name == "count":
req = plc.aggregation.count(
null_handling=plc.types.NullPolicy.EXCLUDE
if not options
else plc.types.NullPolicy.INCLUDE
)
elif name == "quantile":
if name == "quantile":
_, quantile = self.children
if not isinstance(quantile, Literal):
raise NotImplementedError("Only support literal quantile values")
req = plc.aggregation.quantile(
quantiles=[quantile.value.as_py()], interp=Agg.interp_mapping[options]
)
else:
raise NotImplementedError(
f"Unreachable, {name=} is incorrectly listed in _SUPPORTED"
) # pragma: no cover
self.request = req
op = getattr(self, f"_{name}", None)
if op is None:
op = partial(self._reduce, request=req)
elif name in {"min", "max"}:
op = partial(op, propagate_nans=options)
elif name in {"count", "sum", "first", "last"}:
pass
else:
raise NotImplementedError(
f"Unreachable, supported agg {name=} has no implementation"
) # pragma: no cover
self.op = op
self.request = None

_SUPPORTED: ClassVar[frozenset[str]] = frozenset(
[
Expand Down Expand Up @@ -288,7 +241,7 @@ def do_evaluate(
op = partial(self._reduce, request=self.request)
elif self.name in {"min", "max"}:
op = partial(op, propagate_nans=self.options)
elif self.name in {"count", "first", "last"}:
elif self.name in {"count", "sum", "first", "last"}:
pass
else:
raise NotImplementedError(
Expand Down

0 comments on commit 2d37c08

Please sign in to comment.