Skip to content

Commit

Permalink
Make self.request a property
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Jan 13, 2025
1 parent 54a9cd6 commit faf42d0
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


class Agg(Expr):
__slots__ = ("name", "options", "request")
__slots__ = ("_request", "name", "options")
_non_child = ("dtype", "name", "options")

def __init__(
Expand All @@ -50,7 +50,8 @@ def __init__(
_, quantile = self.children
if not isinstance(quantile, Literal):
raise NotImplementedError("Only support literal quantile values")
self.request = None

self._request = None

_SUPPORTED: ClassVar[frozenset[str]] = frozenset(
[
Expand All @@ -77,8 +78,10 @@ def __init__(
"linear": plc.types.Interpolation.LINEAR,
}

def _fill_request(self):
if self.request is None:
@property
def request(self):
"""Return the aggregation request."""
if self._request is None:
# TODO: nan handling in groupby case
if self.name == "min":
req = plc.aggregation.min()
Expand Down Expand Up @@ -115,7 +118,9 @@ def _fill_request(self):
raise NotImplementedError(
f"Unreachable, {self.name=} is incorrectly listed in _SUPPORTED"
) # pragma: no cover
self.request = req
self._request = req

return self._request

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
Expand All @@ -127,7 +132,6 @@ def collect_agg(self, *, depth: int) -> AggInfo:
raise NotImplementedError("Nan propagation in groupby for min/max")
(child,) = self.children
((expr, _, _),) = child.collect_agg(depth=depth + 1).requests
self._fill_request()
request = self.request
# These are handled specially here because we don't set up the
# request for the whole-frame agg because we can avoid a
Expand Down Expand Up @@ -234,8 +238,6 @@ def do_evaluate(
f"Agg in context {context}"
) # pragma: no cover; unreachable

self._fill_request()

op = getattr(self, f"_{self.name}", None)
if op is None:
op = partial(self._reduce, request=self.request)
Expand Down

0 comments on commit faf42d0

Please sign in to comment.