From 789a02eef9524ddb154443cc11ce229ac08085f2 Mon Sep 17 00:00:00 2001
From: Devin Petersohn
Date: Mon, 7 Nov 2022 10:24:45 -0600
Subject: [PATCH] FIX-#5200: Use squeeze parameter instead of SeriesGroupby

Signed-off-by: Devin Petersohn
---
 modin/pandas/groupby.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/modin/pandas/groupby.py b/modin/pandas/groupby.py
index dcd990f13b2..8d6f6c2b628 100644
--- a/modin/pandas/groupby.py
+++ b/modin/pandas/groupby.py
@@ -425,13 +425,15 @@ def __getitem__(self, key):
         if is_list_like(key):
             make_dataframe = True
         else:
+            key = [key]
             if self._as_index:
                 make_dataframe = False
             else:
                 make_dataframe = True
-                key = [key]
+        internal_by = frozenset(self._internal_by)
+        cols_to_grab = internal_by.union(key)
+        key = [col for col in self._df.columns if col in cols_to_grab]
         if make_dataframe:
-            internal_by = frozenset(self._internal_by)
             if len(internal_by.intersection(key)) != 0:
                 ErrorMessage.missmatch_with_pandas(
                     operation="GroupBy.__getitem__",
@@ -443,8 +445,6 @@ def __getitem__(self, key):
                         + "df.groupby(df['by_column'].copy())['by_column']"
                     ),
                 )
-            cols_to_grab = internal_by.union(key)
-            key = [col for col in self._df.columns if col in cols_to_grab]
             return DataFrameGroupBy(
                 self._df[key],
                 drop=self._drop,
@@ -459,7 +459,8 @@ def __getitem__(self, key):
                 "Column lookups on GroupBy with arbitrary Series in by"
                 + " is not yet supported."
             )
-        return SeriesGroupBy(
+        kwargs["squeeze"] = True
+        return DataFrameGroupBy(
             self._df[key],
             drop=False,
             **kwargs,