Skip to content

Commit

Permalink
WIP ENH: setdiff1d for Dask and jax.jit
Browse files Browse the repository at this point in the history
f
  • Loading branch information
crusaderky committed Jan 26, 2025
1 parent b4d7b2c commit cd644fe
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 118 deletions.
99 changes: 92 additions & 7 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
)
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -539,6 +544,7 @@ def setdiff1d(
/,
*,
assume_unique: bool = False,
fill_value: object | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Expand All @@ -555,6 +561,11 @@ def setdiff1d(
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
fill_value : object, optional
Pad the output array with this value.
This is exclusively used for JAX arrays when running inside ``jax.jit``,
where all array shapes need to be known in advance.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.
Expand All @@ -578,12 +589,86 @@ def setdiff1d(
if xp is None:
xp = array_namespace(x1, x2)

if assume_unique:
x1 = xp.reshape(x1, (-1,))
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
x1 = xp.reshape(x1, (-1,))
x2 = xp.reshape(x2, (-1,))
if x1.shape == (0,) or x2.shape == (0,):
return x1

def _x1_not_in_x2(x1: Array, x2: Array) -> Array:  # numpydoc ignore=PR01,RT01
    """
    Build a boolean mask over x1 flagging the elements that do not occur in x2.

    The mask has one entry per element of x1: True where the element is
    absent from x2, False where it is present.
    """
    # assume_unique says nothing about ordering, so x2 is always sorted here
    # before it is used as the haystack of the binary search.
    haystack = xp.sort(x2)
    pos = xp.searchsorted(haystack, x1)

    # An insertion point equal to len(haystack) would be out of bounds for
    # take(); redirect it to index 0. The comparison below still yields the
    # right answer for those elements.
    # FIXME at() is faster but needs JAX jit support for bool mask
    past_end = pos == haystack.shape[0]
    pos = xp.where(past_end, xp.zeros_like(pos), pos)

    candidates = xp.take(haystack, pos, axis=0)
    return candidates != x1

def _generic_impl(x1: Array, x2: Array) -> Array:  # numpydoc ignore=PR01,RT01
    """
    Backend-agnostic implementation, also used for eager JAX.

    Filters x1 with a boolean mask, so the output length is data-dependent.
    """
    if assume_unique:
        # Inputs are taken at face value: no de-duplication, no final sort.
        return x1[_x1_not_in_x2(x1, x2)]

    # De-duplicate up front to speed up the membership test.
    # Note: there is no provision in the Array API for xp.unique_values to
    # sort, hence the explicit sort of the result.
    uniq1 = xp.unique_values(x1)
    uniq2 = xp.unique_values(x2)
    remaining = uniq1[_x1_not_in_x2(uniq1, uniq2)]
    return xp.sort(remaining)

def _dask_impl(x1: Array, x2: Array) -> Array:  # numpydoc ignore=PR01,RT01
    """
    Implementation for Dask arrays.

    Works around unique_values returning arrays of unknown shape by
    de-duplicating only after the membership test has been applied.
    """
    # unique_values is deliberately postponed: calling it first would make
    # the array shapes unknown.
    remaining = x1[_x1_not_in_x2(x1, x2)]
    if assume_unique:
        return remaining
    # Note: da.unique_values sorts
    return xp.unique_values(remaining)

def _jax_jit_impl(
    x1: Array, x2: Array, fill_value: object | None
) -> Array:  # numpydoc ignore=PR01,RT01
    """
    Implementation for JAX arrays traced inside jax.jit.

    Works around unique_values requiring a size= parameter and around
    boolean-mask filtering not being available. The result always has as many
    elements as x1; elements of x1 that are found in x2 are replaced with
    fill_value before de-duplication, and fill_value is also used as padding.
    """
    # unique_values inside jax.jit is not supported unless it's got a fixed
    # size, so a mask is computed here instead of filtering x1.
    keep = _x1_not_in_x2(x1, x2)

    if fill_value is not None:
        pad = xp.asarray(fill_value, dtype=x1.dtype)
        if cast(Array, pad).ndim != 0:
            msg = "`fill_value` must be a scalar."
            raise ValueError(msg)
    else:
        # Default padding: a 0-d zero of the same dtype as x1.
        pad = xp.zeros((), dtype=x1.dtype)

    masked = xp.where(keep, x1, pad)
    # Note: jnp.unique_values sorts
    return xp.unique_values(masked, size=x1.size, fill_value=pad)

if is_dask_namespace(xp):
return _dask_impl(x1, x2)

if is_jax_namespace(xp):
import jax

try:
return _generic_impl(x1, x2) # eager mode
except (
jax.errors.ConcretizationTypeError,
jax.errors.NonConcreteBooleanIndexError,
):
return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit

return _generic_impl(x1, x2)


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down
61 changes: 1 addition & 60 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,66 +8,7 @@
from . import _compat
from ._typing import Array

__all__ = ["in1d", "mean"]


def in1d(
    x1: Array,
    x2: Array,
    /,
    *,
    assume_unique: bool = False,
    invert: bool = False,
    xp: ModuleType | None = None,
) -> Array:  # numpydoc ignore=PR01,RT01
    """
    Test each element of `x1` for membership in `x2`.

    The result is a boolean array the same length as `x1`: True where the
    element of `x1` is in `x2` and False otherwise. With ``invert=True`` the
    comparisons are negated.

    This function has been adapted using the original implementation
    present in numpy:
    https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
    """
    if xp is None:
        xp = _compat.array_namespace(x1, x2)

    # Shortcut for a small x2: compare x1 against every element of x2 in
    # turn, which is significantly faster than the sort-based path below.
    if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
        length = x1.shape[0]
        if not invert:
            mask = xp.zeros(length, dtype=xp.bool, device=_compat.device(x1))
            for candidate in x2:
                mask |= x1 == candidate
            return mask
        mask = xp.ones(length, dtype=xp.bool, device=_compat.device(x1))
        for candidate in x2:
            mask &= x1 != candidate
        return mask

    inverse = xp.empty(0)  # placeholder
    if not assume_unique:
        x1, inverse = xp.unique_inverse(x1)
        x2 = xp.unique_values(x2)

    combined = xp.concat((x1, x2))
    device_ = _compat.device(combined)
    # The sort must be stable.
    sort_idx = xp.argsort(combined, stable=True)
    unsort_idx = xp.argsort(sort_idx, stable=True)
    ordered = xp.take(combined, sort_idx, axis=0)
    ar_size = _compat.size(ordered)
    assert ar_size is not None, "xp.unique*() on lazy backends raises"

    # Compare each sorted element with its right-hand neighbour.
    if ar_size < 1:
        adjacent = xp.asarray([not invert])
    elif invert:
        adjacent = ordered[1:] != ordered[:-1]
    else:
        adjacent = ordered[1:] == ordered[:-1]

    # The last sorted element has no right-hand neighbour; its flag is `invert`.
    flags = xp.concat((adjacent, xp.asarray([invert], device=device_)))
    # Undo the sort so that the flags line up with `combined` again.
    result = xp.take(flags, unsort_idx, axis=0)

    if assume_unique:
        return result[: x1.shape[0]]
    # Expand from the unique values of x1 back to the original elements.
    return xp.take(result, inverse, axis=0)
__all__ = ["mean"]


def mean(
Expand Down
6 changes: 2 additions & 4 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@
lazy_xp_function(kron, static_argnames="xp")
lazy_xp_function(nunique, static_argnames="xp")
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
# FIXME calls in1d which calls xp.unique_values without size
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
lazy_xp_function(setdiff1d, static_argnames=("assume_unique", "xp"))
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")

Expand Down Expand Up @@ -547,8 +546,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
assert padded.shape == (4, 4)


@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no sort")
class TestSetDiff1D:
@pytest.mark.skip_xp_backend(
Backend.TORCH, reason="index_select not implemented for uint32"
Expand Down
47 changes: 0 additions & 47 deletions tests/test_utils.py

This file was deleted.

0 comments on commit cd644fe

Please sign in to comment.