From f0c30d2c87824e296a4c3203e5f0a3c464a27329 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 13 Jan 2025 15:22:04 +0000 Subject: [PATCH 1/2] TST/BUG: run all tests on all backends; fix backend-specific bugs --- .pre-commit-config.yaml | 2 +- docs/api-reference.md | 12 ++ pixi.lock | 2 +- pyproject.toml | 3 +- src/array_api_extra/_funcs.py | 16 +- src/array_api_extra/_lib/_compat.py | 15 ++ src/array_api_extra/_lib/_compat.pyi | 5 + src/array_api_extra/_lib/_utils.py | 4 +- src/array_api_extra/testing.py | 140 +++++++++++++ tests/__init__.py | 1 + tests/conftest.py | 86 ++++++++ tests/test_at.py | 54 ++--- tests/test_funcs.py | 287 ++++++++++++++------------- tests/test_testing.py | 68 +++++++ tests/test_utils.py | 33 +-- vendor_tests/test_vendor.py | 10 + 16 files changed, 537 insertions(+), 201 deletions(-) create mode 100644 src/array_api_extra/testing.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_testing.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 42e2206..655be69 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,9 +44,9 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.8.2" hooks: + - id: ruff-format - id: ruff args: ["--fix", "--show-fixes"] - - id: ruff-format - repo: https://github.com/codespell-project/codespell rev: "v2.3.0" diff --git a/docs/api-reference.md b/docs/api-reference.md index b43c960..a6133d3 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -15,3 +15,15 @@ setdiff1d sinc ``` + +## Test tools + +```{eval-rst} +.. currentmodule:: array_api_extra.testing +.. autosummary:: + :nosignatures: + :toctree: generated + + xp_assert_equal + xp_assert_close +``` diff --git a/pixi.lock b/pixi.lock index 5d3612f..2790b20 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2469,7 +2469,7 @@ packages: - pypi: . name: array-api-extra version: 0.5.1.dev0 - sha256: d8083ec4ee363a390f2afd622df56756078ce3ba5f1f67e88867111a2d306b57 + sha256: 8b4533cc75534abb69425a1e5c9f6a4ab96949562d2e90d41ea0e22187a02c1b requires_dist: - array-api-compat>=1.10.0,<2 - furo>=2023.8.17 ; extra == 'docs' diff --git a/pyproject.toml b/pyproject.toml index dadc311..4f5ddac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -180,7 +180,7 @@ xfail_strict = true filterwarnings = ["error"] log_cli_level = "INFO" testpaths = ["tests"] - +markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"] # Coverage @@ -315,6 +315,7 @@ checks = [ exclude = [ # don't report on objects that match any of these regex '.*test_at.*', '.*test_funcs.*', + '.*test_testing.*', '.*test_utils.*', '.*test_version.*', '.*test_vendor.*', diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 5efbe31..7502561 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -214,8 +214,12 @@ def create_diagonal( raise ValueError(err_msg) n = x.shape[0] + abs(offset) diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x)) - i = offset if offset >= 0 else abs(offset) * n - diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x + + start = offset if offset >= 0 else abs(offset) * n + stop = min(n * (n - offset), diag.shape[0]) + step = n + 1 + diag = at(diag)[start:stop:step].set(x) + return xp.reshape(diag, (n, n)) @@ -407,9 +411,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: result = xp.multiply(a_arr, b_arr) # Reshape back and return - a_shape = xp.asarray(a_shape) - b_shape = xp.asarray(b_shape) - return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape))) + res_shape = tuple(a_s * b_s for a_s, b_s in zip(a_shape, b_shape, strict=True)) + return xp.reshape(result, res_shape) def setdiff1d( @@ -632,8 +635,7 @@ def pad( dtype=x.dtype, device=_compat.device(x), ) - padded[tuple(slices)] = x - return padded + return at(padded, tuple(slices)).set(x) class _AtOp(Enum): diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index a24175d..7d843f1 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -6,20 +6,35 @@ from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] array_namespace, device, + is_cupy_namespace, is_jax_array, + is_jax_namespace, + is_pydata_sparse_namespace, + is_torch_namespace, is_writeable_array, + size, ) except ImportError: from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] array_namespace, device, + is_cupy_namespace, is_jax_array, + is_jax_namespace, + is_pydata_sparse_namespace, + is_torch_namespace, is_writeable_array, + size, ) __all__ = [ "array_namespace", "device", + "is_cupy_namespace", "is_jax_array", + "is_jax_namespace", + "is_pydata_sparse_namespace", + "is_torch_namespace", "is_writeable_array", + "size", ] diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_compat.pyi index 4d06a7f..8532584 100644 --- a/src/array_api_extra/_lib/_compat.pyi +++ b/src/array_api_extra/_lib/_compat.pyi @@ -18,5 +18,10 @@ def array_namespace( use_compat: bool | None = None, ) -> ArrayModule: ... def device(x: Array, /) -> Device: ... +def is_cupy_namespace(x: object, /) -> bool: ... def is_jax_array(x: object, /) -> bool: ... +def is_jax_namespace(x: object, /) -> bool: ... +def is_pydata_sparse_namespace(x: object, /) -> bool: ... +def is_torch_namespace(x: object, /) -> bool: ... def is_writeable_array(x: object, /) -> bool: ... +def size(x: Array, /) -> int | None: ... diff --git a/src/array_api_extra/_lib/_utils.py b/src/array_api_extra/_lib/_utils.py index 1191b4f..afb4cfc 100644 --- a/src/array_api_extra/_lib/_utils.py +++ b/src/array_api_extra/_lib/_utils.py @@ -54,7 +54,9 @@ def in1d( order = xp.argsort(ar, stable=True) reverse_order = xp.argsort(order, stable=True) sar = xp.take(ar, order, axis=0) - if sar.size >= 1: + ar_size = _compat.size(sar) + assert ar_size is not None, "xp.unique*() on lazy backends raises" + if ar_size >= 1: bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1] else: bool_ar = xp.asarray([False]) if invert else xp.asarray([True]) diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py new file mode 100644 index 0000000..ac73ba2 --- /dev/null +++ b/src/array_api_extra/testing.py @@ -0,0 +1,140 @@ +"""Testing utilities.""" + +from ._lib._compat import ( + array_namespace, + is_cupy_namespace, + is_pydata_sparse_namespace, + is_torch_namespace, +) +from ._lib._typing import Array, ModuleType + +__all__ = ["xp_assert_close", "xp_assert_equal"] + + +def _check_ns_shape_dtype( + actual: Array, desired: Array +) -> ModuleType: # numpydoc ignore=RT03 + """ + Assert that namespace, shape and dtype of the two arrays match. + + Parameters + ---------- + actual : Array + The array produced by the tested function. + desired : Array + The expected array (typically hardcoded). + + Returns + ------- + Arrays namespace. + """ + actual_xp = array_namespace(actual) # Raises on scalars and lists + desired_xp = array_namespace(desired) + + msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" + assert actual_xp == desired_xp, msg + + msg = f"shapes do not match: {actual.shape} != f{desired.shape}" + assert actual.shape == desired.shape, msg + + msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}" + assert actual.dtype == desired.dtype, msg + + return desired_xp + + +def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: + """ + Array-API compatible version of `np.testing.assert_array_equal`. + + Parameters + ---------- + actual : Array + The array produced by the tested function. + desired : Array + The expected array (typically hardcoded). + err_msg : str, optional + Error message to display on failure. + """ + xp = _check_ns_shape_dtype(actual, desired) + + if is_cupy_namespace(xp): + xp.testing.assert_array_equal(actual, desired, err_msg=err_msg) + elif is_torch_namespace(xp): + # PyTorch recommends using `rtol=0, atol=0` like this + # to test for exact equality + xp.testing.assert_close( + actual, + desired, + rtol=0, + atol=0, + equal_nan=True, + check_dtype=False, + msg=err_msg or None, + ) + else: + import numpy as np # pylint: disable=import-outside-toplevel + + if is_pydata_sparse_namespace(xp): + actual = actual.todense() + desired = desired.todense() + + # JAX uses `np.testing` + np.testing.assert_array_equal(actual, desired, err_msg=err_msg) + + +def xp_assert_close( + actual: Array, + desired: Array, + *, + rtol: float | None = None, + atol: float = 0, + err_msg: str = "", +) -> None: + """ + Array-API compatible version of `np.testing.assert_allclose`. + + Parameters + ---------- + actual : Array + The array produced by the tested function. + desired : Array + The expected array (typically hardcoded). + rtol : float, optional + Relative tolerance. Default: dtype-dependent. + atol : float, optional + Absolute tolerance. Default: 0. + err_msg : str, optional + Error message to display on failure. + """ + xp = _check_ns_shape_dtype(actual, desired) + + floating = xp.isdtype(actual.dtype, ("real floating", "complex floating")) + if rtol is None and floating: + # multiplier of 4 is used as for `np.float64` this puts the default `rtol` + # roughly half way between sqrt(eps) and the default for + # `numpy.testing.assert_allclose`, 1e-7 + rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 + elif rtol is None: + rtol = 1e-7 + + if is_cupy_namespace(xp): + xp.testing.assert_allclose( + actual, desired, rtol=rtol, atol=atol, err_msg=err_msg + ) + elif is_torch_namespace(xp): + xp.testing.assert_close( + actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None + ) + else: + import numpy as np # pylint: disable=import-outside-toplevel + + if is_pydata_sparse_namespace(xp): + actual = actual.to_dense() + desired = desired.to_dense() + + # JAX uses `np.testing` + assert isinstance(rtol, float) + np.testing.assert_allclose( + actual, desired, rtol=rtol, atol=atol, err_msg=err_msg + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..3ccaf52 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Needed to import .conftest from the test modules.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0bf3114 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,86 @@ +"""Pytest fixtures.""" + +from enum import Enum +from typing import cast + +import pytest + +from array_api_extra._lib._compat import array_namespace +from array_api_extra._lib._compat import device as get_device +from array_api_extra._lib._typing import Device, ModuleType + + +class Library(Enum): + """All array libraries explicitly tested by array-api-extra.""" + + ARRAY_API_STRICT = "array_api_strict" + NUMPY = "numpy" + NUMPY_READONLY = "numpy_readonly" + CUPY = "cupy" + TORCH = "torch" + DASK_ARRAY = "dask.array" + SPARSE = "sparse" + JAX_NUMPY = "jax.numpy" + + def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01 + """Pretty-print parameterized test names.""" + return self.value + + +@pytest.fixture(params=tuple(Library)) +def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01,RT03 + """ + Parameterized fixture that iterates on all libraries. + + Returns + ------- + The current Library enum. + """ + elem = cast(Library, request.param) + + for marker in request.node.iter_markers("skip_xp_backend"): + skip_library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage] + if not isinstance(skip_library, Library): + msg = "argument of skip_xp_backend must be a Library enum" + raise TypeError(msg) + if skip_library == elem: + reason = cast(str, marker.kwargs.get("reason", "skip_xp_backend")) + pytest.skip(reason=reason) + + return elem + + +@pytest.fixture +def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03 + """ + Parameterized fixture that iterates on all libraries. + + Returns + ------- + The current array namespace. + """ + name = "numpy" if library == Library.NUMPY_READONLY else library.value + xp = pytest.importorskip(name) + if library == Library.JAX_NUMPY: + import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] + + jax.config.update("jax_enable_x64", True) + + # Possibly wrap module with array_api_compat + return array_namespace(xp.empty(0)) + + +@pytest.fixture +def device( + library: Library, xp: ModuleType +) -> Device: # numpydoc ignore=PR01,RT01,RT03 + """ + Return a valid device for the backend. + + Where possible, return a device that is not the default one. + """ + if library == Library.ARRAY_API_STRICT: + d = xp.Device("device1") + assert get_device(xp.empty(0)) != d + return d + return get_device(xp.empty(0)) diff --git a/tests/test_at.py b/tests/test_at.py index 24d356c..09c46dc 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -1,58 +1,31 @@ from collections.abc import Callable, Generator from contextlib import contextmanager -from enum import Enum from typing import cast import numpy as np import pytest from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] array_namespace, - is_dask_array, is_pydata_sparse_array, is_writeable_array, ) from array_api_extra import at from array_api_extra._funcs import _AtOp -from array_api_extra._lib._typing import Array +from array_api_extra._lib._typing import Array, ModuleType +from array_api_extra.testing import xp_assert_equal +from .conftest import Library -class Library(Enum): - ARRAY_API_STRICT = "array_api_strict" - NUMPY = "numpy" - NUMPY_READONLY = "numpy_readonly" - CUPY = "cupy" - TORCH = "torch" - DASK_ARRAY = "dask.array" - SPARSE = "sparse" - JAX_NUMPY = "jax.numpy" - # @override from Python 3.12 - def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] - return self.value - - -@pytest.fixture(params=tuple(Library)) -def array(request: pytest.FixtureRequest) -> Array: - library = request.param - if library is Library.NUMPY_READONLY: - x = np.asarray([10.0, 20.0, 30.0]) +@pytest.fixture +def array(library: Library, xp: ModuleType) -> Array: + x = xp.asarray([10.0, 20.0, 30.0]) + if library == Library.NUMPY_READONLY: x.flags.writeable = False - else: - xp = pytest.importorskip(library.value) - x = xp.asarray([10.0, 20.0, 30.0]) return x -def assert_array_equal(a: Array, b: Array) -> None: - xp = array_namespace(a) - b = xp.asarray(b) - eq = xp.all(a == b) - if is_dask_array(a): - eq = eq.compute() - assert eq - - @contextmanager def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: if copy is False and not is_writeable_array(array): @@ -66,7 +39,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: if copy is None: copy = not is_writeable_array(array) - assert_array_equal(xp.all(array == array_orig), copy) + xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy)) @pytest.mark.parametrize( @@ -92,6 +65,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: ], ) def test_update_ops( + xp: ModuleType, array: Array, kwargs: dict[str, bool | None], expect_copy: bool | None, @@ -106,7 +80,7 @@ def test_update_ops( func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit] y = func(arg, **kwargs) assert isinstance(y, type(array)) - assert_array_equal(y, expect) + xp_assert_equal(y, xp.asarray(expect)) def test_copy_invalid(): @@ -129,12 +103,12 @@ def test_xp(): def test_alternate_index_syntax(): a = np.asarray([1, 2, 3]) - assert_array_equal(at(a, 0).set(4, copy=True), [4, 2, 3]) - assert_array_equal(at(a)[0].set(4, copy=True), [4, 2, 3]) + xp_assert_equal(at(a, 0).set(4, copy=True), np.asarray([4, 2, 3])) + xp_assert_equal(at(a)[0].set(4, copy=True), np.asarray([4, 2, 3])) a_at = at(a) - assert_array_equal(a_at[0].add(1, copy=True), [2, 2, 3]) - assert_array_equal(a_at[1].add(2, copy=True), [1, 4, 3]) + xp_assert_equal(a_at[0].add(1, copy=True), np.asarray([2, 2, 3])) + xp_assert_equal(a_at[1].add(2, copy=True), np.asarray([1, 4, 3])) with pytest.raises(ValueError, match="Index"): at(a).set(4) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 2fc9041..059fd13 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1,13 +1,11 @@ import contextlib import warnings -# data-apis/array-api-strict#6 -import array_api_strict as xp # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_equal from array_api_extra import ( + at, atleast_nd, cov, create_diagonal, @@ -17,175 +15,184 @@ setdiff1d, sinc, ) -from array_api_extra._lib._typing import Array +from array_api_extra._lib._compat import device as get_device +from array_api_extra._lib._typing import Array, Device, ModuleType +from array_api_extra.testing import xp_assert_close, xp_assert_equal +from .conftest import Library +# mypy: disable-error-code=no-untyped-usage + + +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no expand_dims") class TestAtLeastND: - def test_0D(self): - x = xp.asarray(1) + def test_0D(self, xp: ModuleType): + x = xp.asarray(1.0) y = atleast_nd(x, ndim=0) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=1) - assert_array_equal(y, xp.ones((1,))) + xp_assert_equal(y, xp.ones((1,))) y = atleast_nd(x, ndim=5) - assert_array_equal(y, xp.ones((1, 1, 1, 1, 1))) + xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1))) - def test_1D(self): + def test_1D(self, xp: ModuleType): x = xp.asarray([0, 1]) y = atleast_nd(x, ndim=0) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=1) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=2) - assert_array_equal(y, xp.asarray([[0, 1]])) + xp_assert_equal(y, xp.asarray([[0, 1]])) y = atleast_nd(x, ndim=5) - assert_array_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2))) + xp_assert_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2))) - def test_2D(self): - x = xp.asarray([[3]]) + def test_2D(self, xp: ModuleType): + x = xp.asarray([[3.0]]) y = atleast_nd(x, ndim=0) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=2) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=3) - assert_array_equal(y, 3 * xp.ones((1, 1, 1))) + xp_assert_equal(y, 3 * xp.ones((1, 1, 1))) y = atleast_nd(x, ndim=5) - assert_array_equal(y, 3 * xp.ones((1, 1, 1, 1, 1))) + xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1))) - def test_5D(self): + def test_5D(self, xp: ModuleType): x = xp.ones((1, 1, 1, 1, 1)) y = atleast_nd(x, ndim=0) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=4) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=5) - assert_array_equal(y, x) + xp_assert_equal(y, x) y = atleast_nd(x, ndim=6) - assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1))) + xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1))) y = atleast_nd(x, ndim=9) - assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) + xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1))) - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) - assert atleast_nd(x, ndim=2).device == device + assert get_device(atleast_nd(x, ndim=2)) == device - def test_xp(self): + def test_xp(self, xp: ModuleType): x = xp.asarray(1) y = atleast_nd(x, ndim=0, xp=xp) - assert_array_equal(y, x) + xp_assert_equal(y, x) +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype") class TestCov: - def test_basic(self): - assert_allclose( + def test_basic(self, xp: ModuleType): + xp_assert_close( cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T), - xp.asarray([[1.0, -1.0], [-1.0, 1.0]]), + xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64), ) - def test_complex(self): - x = xp.asarray([[1, 2, 3], [1j, 2j, 3j]]) - res = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]]) - assert_allclose(cov(x), res) + def test_complex(self, xp: ModuleType): + actual = cov(xp.asarray([[1, 2, 3], [1j, 2j, 3j]])) + expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128) + xp_assert_close(actual, expect) - def test_empty(self): + def test_empty(self, xp: ModuleType): with warnings.catch_warnings(record=True): warnings.simplefilter("always", RuntimeWarning) - assert_array_equal(cov(xp.asarray([])), xp.nan) - assert_array_equal( + xp_assert_equal(cov(xp.asarray([])), xp.asarray(xp.nan, dtype=xp.float64)) + xp_assert_equal( cov(xp.reshape(xp.asarray([]), (0, 2))), - xp.reshape(xp.asarray([]), (0, 0)), + xp.reshape(xp.asarray([], dtype=xp.float64), (0, 0)), ) - assert_array_equal( + xp_assert_equal( cov(xp.reshape(xp.asarray([]), (2, 0))), - xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]]), + xp.asarray([[xp.nan, xp.nan], [xp.nan, xp.nan]], dtype=xp.float64), ) - def test_combination(self): + def test_combination(self, xp: ModuleType): x = xp.asarray([-2.1, -1, 4.3]) y = xp.asarray([3, 1.1, 0.12]) X = xp.stack((x, y), axis=0) - desired = xp.asarray([[11.71, -4.286], [-4.286, 2.144133]]) - assert_allclose(cov(X), desired, rtol=1e-6) - assert_allclose(cov(x), xp.asarray(11.71)) - assert_allclose(cov(y), xp.asarray(2.144133), rtol=1e-6) + desired = xp.asarray([[11.71, -4.286], [-4.286, 2.144133]], dtype=xp.float64) + xp_assert_close(cov(X), desired, rtol=1e-6) + xp_assert_close(cov(x), xp.asarray(11.71, dtype=xp.float64)) + xp_assert_close(cov(y), xp.asarray(2.144133, dtype=xp.float64), rtol=1e-6) - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) - assert cov(x).device == device + assert get_device(cov(x)) == device - def test_xp(self): - assert_allclose( - cov(xp.asarray([[0, 2], [1, 1], [2, 0]]).T, xp=xp), - xp.asarray([[1.0, -1.0], [-1.0, 1.0]]), + def test_xp(self, xp: ModuleType): + xp_assert_close( + cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp), + xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64), ) +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no device") class TestCreateDiagonal: - def test_1d(self): + def test_1d(self, xp: ModuleType): # from np.diag tests vals = 100 * xp.arange(5, dtype=xp.float64) - b = xp.zeros((5, 5)) + b = xp.zeros((5, 5), dtype=xp.float64) for k in range(5): - b[k, k] = vals[k] - assert_array_equal(create_diagonal(vals), b) - b = xp.zeros((7, 7)) + b = at(b)[k, k].set(vals[k]) + xp_assert_equal(create_diagonal(vals), b) + b = xp.zeros((7, 7), dtype=xp.float64) c = xp.asarray(b, copy=True) for k in range(5): - b[k, k + 2] = vals[k] - c[k + 2, k] = vals[k] - assert_array_equal(create_diagonal(vals, offset=2), b) - assert_array_equal(create_diagonal(vals, offset=-2), c) + b = at(b)[k, k + 2].set(vals[k]) + c = at(c)[k + 2, k].set(vals[k]) + xp_assert_equal(create_diagonal(vals, offset=2), b) + xp_assert_equal(create_diagonal(vals, offset=-2), c) @pytest.mark.parametrize("n", range(1, 10)) @pytest.mark.parametrize("offset", range(1, 10)) - def test_create_diagonal(self, n: int, offset: int): + def test_create_diagonal(self, xp: ModuleType, n: int, offset: int): # from scipy._lib tests rng = np.random.default_rng(2347823) one = xp.asarray(1.0) x = rng.random(n) A = create_diagonal(xp.asarray(x, dtype=one.dtype), offset=offset) B = xp.asarray(np.diag(x, offset), dtype=one.dtype) - assert_array_equal(A, B) + xp_assert_equal(A, B) - def test_0d(self): + def test_0d(self, xp: ModuleType): with pytest.raises(ValueError, match="1-dimensional"): create_diagonal(xp.asarray(1)) - def test_2d(self): + def test_2d(self, xp: ModuleType): with pytest.raises(ValueError, match="1-dimensional"): create_diagonal(xp.asarray([[1]])) - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) - assert create_diagonal(x).device == device + assert get_device(create_diagonal(x)) == device - def test_xp(self): + def test_xp(self, xp: ModuleType): x = xp.asarray([1, 2]) y = create_diagonal(x, xp=xp) - assert_array_equal(y, xp.asarray([[1, 0], [0, 2]])) + xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]])) +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no sparse.expand_dims") class TestExpandDims: - def test_functionality(self): + @pytest.mark.skip_xp_backend(Library.DASK_ARRAY, reason="tuple index out of range") + @pytest.mark.skip_xp_backend(Library.TORCH, reason="tuple index out of range") + def test_functionality(self, xp: ModuleType): def _squeeze_all(b: Array) -> Array: """Mimics `np.squeeze(b)`. `xpx.squeeze`?""" for axis in range(b.ndim): @@ -200,14 +207,14 @@ def _squeeze_all(b: Array) -> Array: assert b.shape[axis] == 1 assert _squeeze_all(b).shape == s - def test_axis_tuple(self): + def test_axis_tuple(self, xp: ModuleType): a = xp.empty((3, 3, 3)) assert expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3) assert expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1) assert expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1) assert expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3) - def test_axis_out_of_range(self): + def test_axis_out_of_range(self, xp: ModuleType): s = (2, 3, 4, 5) a = xp.empty(s) with pytest.raises(IndexError, match="out of bounds"): @@ -221,64 +228,64 @@ def test_axis_out_of_range(self): with pytest.raises(IndexError, match="out of bounds"): expand_dims(a, axis=(0, 5)) - def test_repeated_axis(self): + def test_repeated_axis(self, xp: ModuleType): a = xp.empty((3, 3, 3)) with pytest.raises(ValueError, match="Duplicate dimensions"): expand_dims(a, axis=(1, 1)) - def test_positive_negative_repeated(self): + def test_positive_negative_repeated(self, xp: ModuleType): # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817 a = xp.empty((2, 3, 4, 5)) with pytest.raises(ValueError, match="Duplicate dimensions"): expand_dims(a, axis=(3, -3)) - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) - assert expand_dims(x, axis=0).device == device + assert get_device(expand_dims(x, axis=0)) == device - def test_xp(self): + def test_xp(self, xp: ModuleType): x = xp.asarray([1, 2, 3]) y = expand_dims(x, axis=(0, 1, 2), xp=xp) assert y.shape == (1, 1, 1, 3) +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no sparse.expand_dims") class TestKron: - def test_basic(self): + def test_basic(self, xp: ModuleType): # Using 0-dimensional array a = xp.asarray(1) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[1, 2], [3, 4]]) - assert_array_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), k) a = xp.asarray([[1, 2], [3, 4]]) b = xp.asarray(1) - assert_array_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), k) # Using 1-dimensional array a = xp.asarray([3]) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[3, 6], [9, 12]]) - assert_array_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), k) a = xp.asarray([[1, 2], [3, 4]]) b = xp.asarray([3]) - assert_array_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), k) # Using 3-dimensional array a = xp.asarray([[[1]], [[2]]]) b = xp.asarray([[1, 2], [3, 4]]) k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) - assert_array_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), k) a = xp.asarray([[1, 2], [3, 4]]) b = xp.asarray([[[1]], [[2]]]) k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]]) - assert_array_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), k) - def test_kron_smoke(self): + def test_kron_smoke(self, xp: ModuleType): a = xp.ones((3, 3)) b = xp.ones((3, 3)) k = xp.ones((9, 9)) - assert_array_equal(kron(a, b), k) + xp_assert_equal(kron(a, b), k) @pytest.mark.parametrize( ("shape_a", "shape_b"), @@ -291,7 +298,9 @@ def test_kron_smoke(self): ((2, 0, 0, 2), (2, 0, 2)), ], ) - def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]): + def test_kron_shape( + self, xp: ModuleType, shape_a: tuple[int, ...], shape_b: tuple[int, ...] + ): a = xp.ones(shape_a) b = xp.ones(shape_b) normalised_shape_a = xp.asarray( @@ -305,119 +314,123 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]): ) k = kron(a, b) - assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron") + assert k.shape == expected_shape - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([1, 2, 3], device=device) x2 = xp.asarray([4, 5], device=device) - assert kron(x1, x2).device == device + assert get_device(kron(x1, x2)) == device - def test_xp(self): + def test_xp(self, xp: ModuleType): a = xp.ones((3, 3)) b = xp.ones((3, 3)) k = xp.ones((9, 9)) - assert_array_equal(kron(a, b, xp=xp), k) + xp_assert_equal(kron(a, b, xp=xp), k) +@pytest.mark.skip_xp_backend(Library.DASK_ARRAY, reason="no argsort") +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no device") class TestSetDiff1D: - def test_setdiff1d(self): + @pytest.mark.skip_xp_backend( + Library.TORCH, reason="index_select not implemented for uint32" + ) + def test_setdiff1d(self, xp: ModuleType): x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4]) x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5]) expected = xp.asarray([6, 7]) actual = setdiff1d(x1, x2) - assert_array_equal(actual, expected) + xp_assert_equal(actual, expected) x1 = xp.arange(21) x2 = xp.arange(19) expected = xp.asarray([19, 20]) actual = setdiff1d(x1, x2) - assert_array_equal(actual, expected) + xp_assert_equal(actual, expected) - assert_array_equal(setdiff1d(xp.empty(0), xp.empty(0)), xp.empty(0)) + xp_assert_equal(setdiff1d(xp.empty(0), xp.empty(0)), xp.empty(0)) x1 = xp.empty(0, dtype=xp.uint32) x2 = x1 - assert_equal(setdiff1d(x1, x2).dtype, xp.uint32) + assert xp.isdtype(setdiff1d(x1, x2).dtype, xp.uint32) - def test_assume_unique(self): + def test_assume_unique(self, xp: ModuleType): x1 = xp.asarray([3, 2, 1]) x2 = xp.asarray([7, 5, 2]) expected = xp.asarray([3, 1]) actual = setdiff1d(x1, x2, assume_unique=True) - assert_array_equal(actual, expected) + xp_assert_equal(actual, expected) - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device) - assert setdiff1d(x1, x2).device == device + assert get_device(setdiff1d(x1, x2)) == device - def test_xp(self): + def test_xp(self, xp: ModuleType): x1 = xp.asarray([3, 8, 20]) x2 = xp.asarray([2, 3, 4]) expected = xp.asarray([8, 20]) actual = setdiff1d(x1, x2, xp=xp) - assert_array_equal(actual, expected) + xp_assert_equal(actual, expected) +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype") class TestSinc: - def test_simple(self): - assert_array_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0)) + def test_simple(self, xp: ModuleType): + xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0)) w = sinc(xp.linspace(-1, 1, 100)) # check symmetry - assert_allclose(w, xp.flip(w, axis=0)) + xp_assert_close(w, xp.flip(w, axis=0)) @pytest.mark.parametrize("x", [0, 1 + 3j]) - def test_dtype(self, x: int | complex): + def test_dtype(self, xp: ModuleType, x: int | complex): with pytest.raises(ValueError, match="real floating data type"): sinc(xp.asarray(x)) - def test_3d(self): + def test_3d(self, xp: ModuleType): x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2)) - expected = xp.zeros((3, 3, 2)) - expected[0, 0, 0] = 1.0 - assert_allclose(sinc(x), expected, atol=1e-15) + expected = xp.zeros((3, 3, 2), dtype=xp.float64) + expected = at(expected)[0, 0, 0].set(1.0) + xp_assert_close(sinc(x), expected, atol=1e-15) - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): x = xp.asarray(0.0, device=device) - assert sinc(x).device == device + assert get_device(sinc(x)) == device - def test_xp(self): - assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + def test_xp(self, xp: ModuleType): + xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no arange, no device") class TestPad: - def test_simple(self): + def test_simple(self, xp: ModuleType): a = xp.arange(1, 4) padded = pad(a, 2) - assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0])) + xp_assert_equal(padded, xp.asarray([0, 0, 1, 2, 3, 0, 0])) - def test_fill_value(self): + def test_fill_value(self, xp: ModuleType): a = xp.arange(1, 4) padded = pad(a, 2, constant_values=42) - assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42])) + xp_assert_equal(padded, xp.asarray([42, 42, 1, 2, 3, 42, 42])) - def test_ndim(self): + def test_ndim(self, xp: ModuleType): a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4)) padded = pad(a, 2) assert padded.shape == (6, 7, 8) - def test_mode_not_implemented(self): + def test_mode_not_implemented(self, xp: ModuleType): a = xp.arange(3) with pytest.raises(NotImplementedError, match="Only `'constant'`"): pad(a, 2, mode="edge") - def test_device(self): - device = xp.Device("device1") + def test_device(self, xp: ModuleType, device: Device): a = xp.asarray(0.0, device=device) - assert pad(a, 2).device == device + assert get_device(pad(a, 2)) == device - def test_xp(self): - assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3)) + def test_xp(self, xp: ModuleType): + padded = pad(xp.asarray(0), 1, xp=xp) + xp_assert_equal(padded, xp.asarray(0)) - def test_tuple_width(self): + def test_tuple_width(self, xp: ModuleType): a = xp.reshape(xp.arange(12), (3, 4)) padded = pad(a, (1, 0)) assert padded.shape == (4, 5) @@ -428,7 +441,7 @@ def test_tuple_width(self): with pytest.raises(ValueError, match="expect a 2-tuple"): pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] - def test_list_of_tuples_width(self): + def test_list_of_tuples_width(self, xp: ModuleType): a = xp.reshape(xp.arange(12), (3, 4)) padded = pad(a, [(1, 0), (0, 2)]) assert padded.shape == (4, 6) diff --git a/tests/test_testing.py b/tests/test_testing.py new file mode 100644 index 0000000..c47fecd --- /dev/null +++ b/tests/test_testing.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest + +from array_api_extra.testing import xp_assert_close, xp_assert_equal + +from .conftest import Library + +# mypy: disable-error-code=no-any-decorated +# pyright: reportUnknownParameterType=false,reportMissingParameterType=false + + +@pytest.mark.parametrize( + "func", + [ + xp_assert_equal, + pytest.param( + xp_assert_close, + marks=pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype"), + ), + ], +) +def test_assert_close_equal_basic(xp, func): + func(xp.asarray(0), xp.asarray(0)) + func(xp.asarray([1, 2]), xp.asarray([1, 2])) + + with pytest.raises(AssertionError, match="shapes do not match"): + func(xp.asarray([0]), xp.asarray([[0]])) + + with pytest.raises(AssertionError, match="dtypes do not match"): + func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64)) + + with pytest.raises(AssertionError): + func(xp.asarray([1, 2]), xp.asarray([1, 3])) + + with pytest.raises(AssertionError, match="hello"): + func(xp.asarray([1, 2]), xp.asarray([1, 3]), err_msg="hello") + + +@pytest.mark.skip_xp_backend(Library.NUMPY) +@pytest.mark.skip_xp_backend(Library.NUMPY_READONLY) +@pytest.mark.parametrize( + "func", + [ + xp_assert_equal, + pytest.param( + xp_assert_close, + marks=pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype"), + ), + ], +) +def test_assert_close_equal_namespace(xp, func): + with pytest.raises(AssertionError): + func(xp.asarray(0), np.asarray(0)) + with pytest.raises(TypeError): + func(xp.asarray(0), 0) + with pytest.raises(TypeError): + func(xp.asarray([0]), [0]) + + +@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype") +def test_assert_close_tolerance(xp): + xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03) + with pytest.raises(AssertionError): + xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.01) + + xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3) + with pytest.raises(AssertionError): + xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1807627..4e41254 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,31 +1,38 @@ -# data-apis/array-api-strict#6 -import array_api_strict as xp # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] +import numpy as np import pytest -from numpy.testing import assert_array_equal -from array_api_extra._lib._typing import Array +from array_api_extra._lib._compat import device as get_device +from array_api_extra._lib._typing import Array, Device, ModuleType from array_api_extra._lib._utils import in1d +from array_api_extra.testing import xp_assert_equal + +from .conftest import Library + +# mypy: disable-error-code=no-untyped-usage -# some test coverage already provided by TestSetDiff1D class TestIn1D: + @pytest.mark.skip_xp_backend(Library.DASK_ARRAY, reason="no argsort") + @pytest.mark.skip_xp_backend(Library.SPARSE, reason="no unique_inverse, no device") # cover both code paths - @pytest.mark.parametrize("x2", [xp.arange(9), xp.arange(15)]) - def test_no_invert_assume_unique(self, x2: Array): + @pytest.mark.parametrize("x2", [np.arange(9), np.arange(15)]) + def test_no_invert_assume_unique(self, xp: ModuleType, x2: Array): x1 = xp.asarray([3, 8, 20]) + x2 = xp.asarray(x2) expected = xp.asarray([True, True, False]) actual = in1d(x1, x2) - assert_array_equal(actual, expected) + xp_assert_equal(actual, expected) - def test_device(self): - device = xp.Device("device1") + @pytest.mark.skip_xp_backend(Library.SPARSE, reason="no device") + def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([3, 8, 20], device=device) x2 = xp.asarray([2, 3, 4], device=device) - assert in1d(x1, x2).device == device + assert get_device(in1d(x1, x2)) == device - def test_xp(self): + @pytest.mark.skip_xp_backend(Library.SPARSE, reason="no arange, no device") + def test_xp(self, xp: ModuleType): x1 = xp.asarray([1, 6]) x2 = xp.arange(5) expected = xp.asarray([True, False]) actual = in1d(x1, x2, xp=xp) - assert_array_equal(actual, expected) + xp_assert_equal(actual, expected) diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index c2e6570..e320280 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -6,8 +6,13 @@ def test_vendor_compat(): from ._array_api_compat_vendor import ( # type: ignore[attr-defined] array_namespace, device, + is_cupy_namespace, is_jax_array, + is_jax_namespace, + is_pydata_sparse_namespace, + is_torch_namespace, is_writeable_array, + size, ) x = xp.asarray([1, 2, 3]) @@ -15,6 +20,11 @@ def test_vendor_compat(): device(x) assert not is_jax_array(x) assert is_writeable_array(x) + assert not is_cupy_namespace(xp) + assert not is_jax_namespace(xp) + assert not is_pydata_sparse_namespace(xp) + assert not is_torch_namespace(xp) + assert size(x) == 3 def test_vendor_extra(): From 744591784fde027bcbbf1b4987d7854d74fb8c18 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 14 Jan 2025 14:30:46 +0000 Subject: [PATCH 2/2] Make test functions private --- docs/api-reference.md | 12 ------------ src/array_api_extra/{testing.py => _lib/_testing.py} | 10 +++++++--- tests/test_at.py | 2 +- tests/test_funcs.py | 2 +- tests/test_testing.py | 2 +- tests/test_utils.py | 2 +- 6 files changed, 11 insertions(+), 19 deletions(-) rename src/array_api_extra/{testing.py => _lib/_testing.py} (96%) diff --git a/docs/api-reference.md b/docs/api-reference.md index a6133d3..b43c960 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -15,15 +15,3 @@ setdiff1d sinc ``` - -## Test tools - -```{eval-rst} -.. currentmodule:: array_api_extra.testing -.. autosummary:: - :nosignatures: - :toctree: generated - - xp_assert_equal - xp_assert_close -``` diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/_lib/_testing.py similarity index 96% rename from src/array_api_extra/testing.py rename to src/array_api_extra/_lib/_testing.py index ac73ba2..b866012 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -1,12 +1,16 @@ -"""Testing utilities.""" +""" +Testing utilities. -from ._lib._compat import ( +Note that this is private API; don't expect it to be stable. +""" + +from ._compat import ( array_namespace, is_cupy_namespace, is_pydata_sparse_namespace, is_torch_namespace, ) -from ._lib._typing import Array, ModuleType +from ._typing import Array, ModuleType __all__ = ["xp_assert_close", "xp_assert_equal"] diff --git a/tests/test_at.py b/tests/test_at.py index 09c46dc..ed56f61 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -12,8 +12,8 @@ from array_api_extra import at from array_api_extra._funcs import _AtOp +from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._typing import Array, ModuleType -from array_api_extra.testing import xp_assert_equal from .conftest import Library diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 059fd13..5f18ef6 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -16,8 +16,8 @@ sinc, ) from array_api_extra._lib._compat import device as get_device +from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._typing import Array, Device, ModuleType -from array_api_extra.testing import xp_assert_close, xp_assert_equal from .conftest import Library diff --git a/tests/test_testing.py b/tests/test_testing.py index c47fecd..28b37d0 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from array_api_extra.testing import xp_assert_close, xp_assert_equal +from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from .conftest import Library diff --git a/tests/test_utils.py b/tests/test_utils.py index 4e41254..8cf49c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,9 +2,9 @@ import pytest from array_api_extra._lib._compat import device as get_device +from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._typing import Array, Device, ModuleType from array_api_extra._lib._utils import in1d -from array_api_extra.testing import xp_assert_equal from .conftest import Library