Skip to content

Commit

Permalink
Remove mad and tshift when Pandas >= 2.0 as they are removed there. (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
caneff authored Sep 7, 2023
1 parent 5afb54e commit 7ba4fa9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
24 changes: 18 additions & 6 deletions sdks/python/apache_beam/dataframe/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def wrapper(self, *args, **kwargs):
'quantile',
'describe',
'sem',
'mad',
'skew',
'kurt',
'kurtosis',
Expand All @@ -126,6 +125,10 @@ def wrapper(self, *args, **kwargs):
'cov',
'nunique',
]
# mad was removed in Pandas 2.0.
if PD_VERSION < (2, 0):
UNLIFTABLE_AGGREGATIONS.append('mad')

ALL_AGGREGATIONS = (
LIFTABLE_AGGREGATIONS + LIFTABLE_WITH_SUM_AGGREGATIONS +
UNLIFTABLE_AGGREGATIONS)
Expand Down Expand Up @@ -2092,7 +2095,9 @@ def axes(self):
sum = _agg_method(pd.Series, 'sum')
median = _agg_method(pd.Series, 'median')
sem = _agg_method(pd.Series, 'sem')
mad = _agg_method(pd.Series, 'mad')
# mad was removed in Pandas 2.0.
if PD_VERSION < (2, 0):
mad = _agg_method(pd.Series, 'mad')

argmax = frame_base.wont_implement_method(
pd.Series, 'argmax', reason='order-sensitive')
Expand Down Expand Up @@ -3914,10 +3919,12 @@ def pivot_helper(df):
std = _agg_method(pd.DataFrame, 'std')
var = _agg_method(pd.DataFrame, 'var')
sem = _agg_method(pd.DataFrame, 'sem')
mad = _agg_method(pd.DataFrame, 'mad')
skew = _agg_method(pd.DataFrame, 'skew')
kurt = _agg_method(pd.DataFrame, 'kurt')
kurtosis = _agg_method(pd.DataFrame, 'kurtosis')
# mad was removed in Pandas 2.0.
if PD_VERSION < (2, 0):
mad = _agg_method(pd.DataFrame, 'mad')

take = frame_base.wont_implement_method(pd.DataFrame, 'take',
reason='deprecated')
Expand Down Expand Up @@ -4670,7 +4677,10 @@ def _is_unliftable(agg_func):
return _check_str_or_np_builtin(agg_func, UNLIFTABLE_AGGREGATIONS)

NUMERIC_AGGREGATIONS = ['max', 'min', 'prod', 'sum', 'mean', 'median', 'std',
'var', 'sem', 'mad', 'skew', 'kurt', 'kurtosis']
'var', 'sem', 'skew', 'kurt', 'kurtosis']
# mad was removed in Pandas 2.0.
if PD_VERSION < (2, 0):
NUMERIC_AGGREGATIONS.append('mad')

def _is_numeric(agg_func):
return _check_str_or_np_builtin(agg_func, NUMERIC_AGGREGATIONS)
Expand Down Expand Up @@ -4698,7 +4708,6 @@ class _DeferredGroupByCols(frame_base.DeferredFrame):
idxmax = frame_base._elementwise_method('idxmax', base=DataFrameGroupBy)
idxmin = frame_base._elementwise_method('idxmin', base=DataFrameGroupBy)
last = frame_base._elementwise_method('last', base=DataFrameGroupBy)
mad = frame_base._elementwise_method('mad', base=DataFrameGroupBy)
max = frame_base._elementwise_method('max', base=DataFrameGroupBy)
mean = frame_base._elementwise_method('mean', base=DataFrameGroupBy)
median = frame_base._elementwise_method('median', base=DataFrameGroupBy)
Expand All @@ -4717,8 +4726,11 @@ class _DeferredGroupByCols(frame_base.DeferredFrame):
DataFrameGroupBy, 'tail', explanation=_PEEK_METHOD_EXPLANATION)
take = frame_base.wont_implement_method(
DataFrameGroupBy, 'take', reason='deprecated')
tshift = frame_base._elementwise_method('tshift', base=DataFrameGroupBy)
var = frame_base._elementwise_method('var', base=DataFrameGroupBy)
# These already-deprecated methods were removed in Pandas 2.0.
if PD_VERSION < (2, 0):
mad = frame_base._elementwise_method('mad', base=DataFrameGroupBy)
tshift = frame_base._elementwise_method('tshift', base=DataFrameGroupBy)

@property # type: ignore
@frame_base.with_docs_from(DataFrameGroupBy)
Expand Down
14 changes: 10 additions & 4 deletions sdks/python/apache_beam/dataframe/frames_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,6 +1946,12 @@ def test_groupby_multiindex_keep_nans(self):
lambda df: df.groupby(['foo', 'bar'], dropna=False).sum(), GROUPBY_DF)


NONPARALLEL_METHODS = ['quantile', 'describe', 'median', 'sem']
# mad was removed in Pandas 2.0.
if PD_VERSION < (2, 0):
NONPARALLEL_METHODS.append('mad')


class AggregationTest(_AbstractFrameTest):
"""Tests for global aggregation methods on DataFrame/Series."""

Expand All @@ -1955,7 +1961,7 @@ class AggregationTest(_AbstractFrameTest):
def test_series_agg(self, agg_method):
s = pd.Series(list(range(16)))

nonparallel = agg_method in ('quantile', 'describe', 'median', 'sem', 'mad')
nonparallel = agg_method in NONPARALLEL_METHODS

# TODO(https://github.com/apache/beam/issues/20926): max and min produce
# the wrong proxy
Expand All @@ -1974,7 +1980,7 @@ def test_series_agg(self, agg_method):
def test_series_agg_method(self, agg_method):
s = pd.Series(list(range(16)))

nonparallel = agg_method in ('quantile', 'describe', 'median', 'sem', 'mad')
nonparallel = agg_method in NONPARALLEL_METHODS

# TODO(https://github.com/apache/beam/issues/20926): max and min produce
# the wrong proxy
Expand All @@ -1990,7 +1996,7 @@ def test_series_agg_method(self, agg_method):
def test_dataframe_agg(self, agg_method):
df = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [2, 3, 5, 7]})

nonparallel = agg_method in ('quantile', 'describe', 'median', 'sem', 'mad')
nonparallel = agg_method in NONPARALLEL_METHODS

# TODO(https://github.com/apache/beam/issues/20926): max and min produce
# the wrong proxy
Expand All @@ -2007,7 +2013,7 @@ def test_dataframe_agg(self, agg_method):
def test_dataframe_agg_method(self, agg_method):
df = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [2, 3, 5, 7]})

nonparallel = agg_method in ('quantile', 'describe', 'median', 'sem', 'mad')
nonparallel = agg_method in NONPARALLEL_METHODS

# TODO(https://github.com/apache/beam/issues/20926): max and min produce
# the wrong proxy
Expand Down

0 comments on commit 7ba4fa9

Please sign in to comment.