Skip to content

Commit

Permalink
Always reindex=True for all numpy inputs (#228)
Browse files: browse the repository at this point in the history
  • Loading branch information
dcherian authored Mar 26, 2023
1 parent 13d1062 · commit 24dc7fd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 11 deletions.
17 changes: 14 additions & 3 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,9 +1519,15 @@ def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:


def _validate_reindex(
reindex: bool | None, func, method: T_Method, expected_groups, any_by_dask: bool
reindex: bool | None,
func,
method: T_Method,
expected_groups,
any_by_dask: bool,
is_dask_array: bool,
) -> bool:
if reindex is True:
all_numpy = not is_dask_array and not any_by_dask
if reindex is True and not all_numpy:
if _is_arg_reduction(func):
raise NotImplementedError
if method in ["blockwise", "cohorts"]:
Expand All @@ -1530,6 +1536,9 @@ def _validate_reindex(
)

if reindex is None:
if all_numpy:
return True

if method == "blockwise" or _is_arg_reduction(func):
reindex = False

Expand Down Expand Up @@ -1796,7 +1805,9 @@ def groupby_reduce(
if method == "split-reduce":
method = "cohorts"

reindex = _validate_reindex(reindex, func, method, expected_groups, any_by_dask)
reindex = _validate_reindex(
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
)

if not is_duck_array(array):
array = np.asarray(array)
Expand Down
37 changes: 29 additions & 8 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ def test_subset_block_2d(flatblocks, expectidx):


@pytest.mark.parametrize(
"expected, reindex, func, expected_groups, any_by_dask",
"dask_expected, reindex, func, expected_groups, any_by_dask",
[
# argmax only False
[False, None, "argmax", None, False],
Expand All @@ -1252,22 +1252,43 @@ def test_subset_block_2d(flatblocks, expectidx):
[True, None, "sum", ([1], None), True],
],
)
def test_validate_reindex_map_reduce(expected, reindex, func, expected_groups, any_by_dask):
actual = _validate_reindex(reindex, func, "map-reduce", expected_groups, any_by_dask)
assert actual == expected
def test_validate_reindex_map_reduce(
dask_expected, reindex, func, expected_groups, any_by_dask
) -> None:
actual = _validate_reindex(
reindex, func, "map-reduce", expected_groups, any_by_dask, is_dask_array=True
)
assert actual is dask_expected

# always reindex with all numpy inputs
actual = _validate_reindex(
reindex, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
)
assert actual

actual = _validate_reindex(
True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
)
assert actual

def test_validate_reindex():

def test_validate_reindex() -> None:
for method in ["map-reduce", "cohorts"]:
with pytest.raises(NotImplementedError):
_validate_reindex(True, "argmax", method, expected_groups=None, any_by_dask=False)
_validate_reindex(
True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True
)

for method in ["blockwise", "cohorts"]:
with pytest.raises(ValueError):
_validate_reindex(True, "sum", method, expected_groups=None, any_by_dask=False)
_validate_reindex(
True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True
)

for func in ["sum", "argmax"]:
actual = _validate_reindex(None, func, method, expected_groups=None, any_by_dask=False)
actual = _validate_reindex(
None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True
)
assert actual is False


Expand Down

0 comments on commit 24dc7fd

Please sign in to comment.