Skip to content

Commit

Permalink
rework incompatible_dtype test
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 25, 2025
1 parent eb6b721 commit a990846
Showing 1 changed file with 43 additions and 48 deletions.
91 changes: 43 additions & 48 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from array_api_extra._lib._utils._typing import Array, Index
from array_api_extra.testing import lazy_xp_function

pytestmark = [
pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
]


def at_op( # type: ignore[no-any-explicit]
x: Array,
Expand Down Expand Up @@ -71,9 +77,6 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.parametrize(
("kwargs", "expect_copy"),
[
Expand Down Expand Up @@ -170,78 +173,70 @@ def test_alternate_index_syntax():
at(a, 0)[0].set(4)


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.parametrize("copy", [True, None])
@pytest.mark.parametrize(
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
)
def test_iops_incompatible_dtype(
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None
@pytest.mark.parametrize("bool_mask", [False, True])
@pytest.mark.parametrize("op", list(_AtOp))
def test_incompatible_dtype(
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool
):
"""Test that at() replicates the backend's behaviour for
in-place operations with incompatible dtypes.
Note:
Behavior is backend-specific, but only two behaviors are allowed:
1. raise an exception, or
2. return the same dtype as x, disregarding y.dtype (no broadcasting).
Note that __i<op>__ and __<op>__ behave differently, and we want to
replicate the behavior of __i<op>__:
>>> a = np.asarray([1, 2, 3])
>>> a / 1.5
array([0. , 0.66666667, 1.33333333])
>>> a /= 1.5
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
to dtype('int64') with casting rule 'same_kind'
See Also
--------
"""
x = xp.asarray([2, 4])

if library is Backend.DASK:
z = at_op(x, slice(None), op, 1.1, copy=copy)
assert z.dtype == x.dtype

elif library is Backend.JAX:
with pytest.warns(FutureWarning, match="cannot safely cast"):
z = at_op(x, slice(None), op, 1.1, copy=copy)
assert z.dtype == x.dtype

else:
idx = xp.asarray([True, False]) if bool_mask else slice(None)
z = None

if library is Backend.JAX:
if bool_mask:
z = at_op(x, idx, op, 1.1, copy=copy)
else:
with pytest.warns(FutureWarning, match="cannot safely cast"):
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.DASK:
if op in (_AtOp.MIN, _AtOp.MAX):
pytest.xfail(reason="need array-api-compat 1.11")
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
with pytest.raises(Exception, match=r"cast|promote|dtype"):
at_op(x, slice(None), op, 1.1, copy=copy)
at_op(x, idx, op, 1.1, copy=copy)


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.parametrize(
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
)
def test_bool_mask_incompatible_dtype(xp: ModuleType, library: Backend, op: _AtOp):
"""
When xp.where(idx, y, x) would promote the dtype of the output
to y.dtype, at(x, idx).set(y) must retain x.dtype instead
"""
x = xp.asarray([1, 2])
idx = xp.asarray([True, False])
if library in (Backend.DASK, Backend.JAX):
z = at_op(x, idx, op, 1.1)
assert z.dtype == x.dtype
elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX):
# There is no __i<op>__ version of these operations
z = at_op(x, idx, op, 1.1, copy=copy)

else:
with pytest.raises(Exception, match=r"cast|promote|dtype"):
at_op(x, idx, op, 1.1)
at_op(x, idx, op, 1.1, copy=copy)

assert z is None or z.dtype == x.dtype


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
def test_bool_mask_nd(xp: ModuleType):
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
idx = xp.asarray([[True, False, False], [False, True, True]])
z = at_op(x, idx, _AtOp.SET, 0)
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere")
@pytest.mark.parametrize("bool_mask", [False, True])
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
Expand Down

0 comments on commit a990846

Please sign in to comment.