Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST/BUG: run all tests on all backends; fix backend-specific bugs #88

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.8.2"
hooks:
- id: ruff-format
- id: ruff
args: ["--fix", "--show-fixes"]
- id: ruff-format

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ xfail_strict = true
filterwarnings = ["error"]
log_cli_level = "INFO"
testpaths = ["tests"]

markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]

# Coverage

Expand Down Expand Up @@ -315,6 +315,7 @@ checks = [
exclude = [ # don't report on objects that match any of these regex
'.*test_at.*',
'.*test_funcs.*',
'.*test_testing.*',
'.*test_utils.*',
'.*test_version.*',
'.*test_vendor.*',
Expand Down
16 changes: 9 additions & 7 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,12 @@ def create_diagonal(
raise ValueError(err_msg)
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
i = offset if offset >= 0 else abs(offset) * n
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x

start = offset if offset >= 0 else abs(offset) * n
stop = min(n * (n - offset), diag.shape[0])
step = n + 1
diag = at(diag)[start:stop:step].set(x)

return xp.reshape(diag, (n, n))


Expand Down Expand Up @@ -407,9 +411,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
result = xp.multiply(a_arr, b_arr)

# Reshape back and return
a_shape = xp.asarray(a_shape)
b_shape = xp.asarray(b_shape)
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
res_shape = tuple(a_s * b_s for a_s, b_s in zip(a_shape, b_shape, strict=True))
return xp.reshape(result, res_shape)


def setdiff1d(
Expand Down Expand Up @@ -632,8 +635,7 @@ def pad(
dtype=x.dtype,
device=_compat.device(x),
)
padded[tuple(slices)] = x
return padded
return at(padded, tuple(slices)).set(x)


class _AtOp(Enum):
Expand Down
15 changes: 15 additions & 0 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,35 @@
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace,
device,
is_cupy_namespace,
is_jax_array,
is_jax_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
is_writeable_array,
size,
)
except ImportError:
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace,
device,
is_cupy_namespace,
is_jax_array,
is_jax_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
is_writeable_array,
size,
)

__all__ = [
"array_namespace",
"device",
"is_cupy_namespace",
"is_jax_array",
"is_jax_namespace",
"is_pydata_sparse_namespace",
"is_torch_namespace",
"is_writeable_array",
"size",
]
5 changes: 5 additions & 0 deletions src/array_api_extra/_lib/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,10 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ...
def device(x: Array, /) -> Device: ...
def is_cupy_namespace(x: object, /) -> bool: ...
def is_jax_array(x: object, /) -> bool: ...
def is_jax_namespace(x: object, /) -> bool: ...
def is_pydata_sparse_namespace(x: object, /) -> bool: ...
def is_torch_namespace(x: object, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
def size(x: Array, /) -> int | None: ...
144 changes: 144 additions & 0 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Testing utilities.

Note that this is private API; don't expect it to be stable.
"""

from ._compat import (
array_namespace,
is_cupy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._typing import Array, ModuleType

__all__ = ["xp_assert_close", "xp_assert_equal"]


def _check_ns_shape_dtype(
actual: Array, desired: Array
) -> ModuleType: # numpydoc ignore=RT03
"""
Assert that namespace, shape and dtype of the two arrays match.

Parameters
----------
actual : Array
The array produced by the tested function.
desired : Array
The expected array (typically hardcoded).

Returns
-------
Arrays namespace.
"""
actual_xp = array_namespace(actual) # Raises on scalars and lists
desired_xp = array_namespace(desired)

msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg

msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
assert actual.shape == desired.shape, msg

msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
assert actual.dtype == desired.dtype, msg

return desired_xp


def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
"""
Array-API compatible version of `np.testing.assert_array_equal`.

Parameters
----------
actual : Array
The array produced by the tested function.
desired : Array
The expected array (typically hardcoded).
err_msg : str, optional
Error message to display on failure.
"""
xp = _check_ns_shape_dtype(actual, desired)

if is_cupy_namespace(xp):
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)

Check warning on line 66 in src/array_api_extra/_lib/_testing.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_testing.py#L66

Added line #L66 was not covered by tests
elif is_torch_namespace(xp):
# PyTorch recommends using `rtol=0, atol=0` like this
# to test for exact equality
xp.testing.assert_close(
actual,
desired,
rtol=0,
atol=0,
equal_nan=True,
check_dtype=False,
msg=err_msg or None,
)
else:
import numpy as np # pylint: disable=import-outside-toplevel

if is_pydata_sparse_namespace(xp):
actual = actual.todense()
desired = desired.todense()

# JAX uses `np.testing`
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)


def xp_assert_close(
actual: Array,
desired: Array,
*,
rtol: float | None = None,
atol: float = 0,
err_msg: str = "",
) -> None:
"""
Array-API compatible version of `np.testing.assert_allclose`.

Parameters
----------
actual : Array
The array produced by the tested function.
desired : Array
The expected array (typically hardcoded).
rtol : float, optional
Relative tolerance. Default: dtype-dependent.
atol : float, optional
Absolute tolerance. Default: 0.
err_msg : str, optional
Error message to display on failure.
"""
xp = _check_ns_shape_dtype(actual, desired)

floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
if rtol is None and floating:
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
# roughly half way between sqrt(eps) and the default for
# `numpy.testing.assert_allclose`, 1e-7
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
elif rtol is None:
rtol = 1e-7

if is_cupy_namespace(xp):
xp.testing.assert_allclose(

Check warning on line 126 in src/array_api_extra/_lib/_testing.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_testing.py#L126

Added line #L126 was not covered by tests
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
)
elif is_torch_namespace(xp):
xp.testing.assert_close(
actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
)
else:
import numpy as np # pylint: disable=import-outside-toplevel

if is_pydata_sparse_namespace(xp):
actual = actual.to_dense()
desired = desired.to_dense()

Check warning on line 138 in src/array_api_extra/_lib/_testing.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_testing.py#L137-L138

Added lines #L137 - L138 were not covered by tests

# JAX uses `np.testing`
assert isinstance(rtol, float)
np.testing.assert_allclose(
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
)
4 changes: 3 additions & 1 deletion src/array_api_extra/_lib/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def in1d(
order = xp.argsort(ar, stable=True)
reverse_order = xp.argsort(order, stable=True)
sar = xp.take(ar, order, axis=0)
if sar.size >= 1:
ar_size = _compat.size(sar)
assert ar_size is not None, "xp.unique*() on lazy backends raises"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caveat: this is missing data-apis/array-api-compat#231

if ar_size >= 1:
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
else:
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Needed to import .conftest from the test modules."""
86 changes: 86 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Pytest fixtures."""

from enum import Enum
from typing import cast

import pytest

from array_api_extra._lib._compat import array_namespace
from array_api_extra._lib._compat import device as get_device
from array_api_extra._lib._typing import Device, ModuleType


class Library(Enum):
"""All array libraries explicitly tested by array-api-extra."""

ARRAY_API_STRICT = "array_api_strict"
NUMPY = "numpy"
NUMPY_READONLY = "numpy_readonly"
lucascolley marked this conversation as resolved.
Show resolved Hide resolved
CUPY = "cupy"
TORCH = "torch"
DASK_ARRAY = "dask.array"
SPARSE = "sparse"
lucascolley marked this conversation as resolved.
Show resolved Hide resolved
JAX_NUMPY = "jax.numpy"

def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
"""Pretty-print parameterized test names."""
return self.value


@pytest.fixture(params=tuple(Library))
def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01,RT03
"""
Parameterized fixture that iterates on all libraries.

Returns
-------
The current Library enum.
"""
elem = cast(Library, request.param)

for marker in request.node.iter_markers("skip_xp_backend"):
skip_library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage]
if not isinstance(skip_library, Library):
msg = "argument of skip_xp_backend must be a Library enum"
raise TypeError(msg)
if skip_library == elem:
reason = cast(str, marker.kwargs.get("reason", "skip_xp_backend"))
pytest.skip(reason=reason)

return elem


@pytest.fixture
def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03
"""
Parameterized fixture that iterates on all libraries.

Returns
-------
The current array namespace.
"""
name = "numpy" if library == Library.NUMPY_READONLY else library.value
xp = pytest.importorskip(name)
if library == Library.JAX_NUMPY:
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]

jax.config.update("jax_enable_x64", True)

# Possibly wrap module with array_api_compat
return array_namespace(xp.empty(0))


@pytest.fixture
def device(
library: Library, xp: ModuleType
) -> Device: # numpydoc ignore=PR01,RT01,RT03
"""
Return a valid device for the backend.

Where possible, return a device that is not the default one.
"""
if library == Library.ARRAY_API_STRICT:
d = xp.Device("device1")
assert get_device(xp.empty(0)) != d
return d
return get_device(xp.empty(0))
Loading
Loading