Skip to content

Commit

Permalink
Design 2->4
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 26, 2025
1 parent cd644fe commit 028441c
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ def setdiff1d(
/,
*,
assume_unique: bool = False,
size: int | None = None,
fill_value: object | None = None,
xp: ModuleType | None = None,
) -> Array:
Expand All @@ -561,11 +562,16 @@ def setdiff1d(
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
fill_value : object, optional
Pad the output array with this value.
size : int, optional
The size of the output array. This is exclusively used inside the JAX JIT, and
only for as long as JAX does not support arrays of unknown size inside it. In
all other cases, it is disregarded.
Returned elements will be clipped if they are more than size, and padded with
`fill_value` if they are less. Default: raise if inside ``jax.jit``.
This is exclusively used for JAX arrays when running inside ``jax.jit``,
where all array shapes need to be known in advance.
fill_value : object, optional
Pad the output array with this value. This is exclusively used for JAX arrays
when running inside ``jax.jit``. Default: 0.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.
Expand Down Expand Up @@ -630,7 +636,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
return x1 if assume_unique else xp.unique_values(x1)

def _jax_jit_impl(
x1: Array, x2: Array, fill_value: object | None
x1: Array, x2: Array, size: int | None, fill_value: object | None
) -> Array: # numpydoc ignore=PR01,RT01
"""
JAX implementation inside jax.jit.
Expand All @@ -639,9 +645,9 @@ def _jax_jit_impl(
and not being able to filter by a boolean mask.
Returns array the same size as x1, padded with fill_value.
"""
# unique_values inside jax.jit is not supported unless it's got a fixed size
mask = _x1_not_in_x2(x1, x2)

if size is None:
msg = "`size` is mandatory when running inside `jax.jit`."
raise ValueError(msg)
if fill_value is None:
fill_value = xp.zeros((), dtype=x1.dtype)
else:
Expand All @@ -650,9 +656,13 @@ def _jax_jit_impl(
msg = "`fill_value` must be a scalar."
raise ValueError(msg)

# unique_values inside jax.jit is not supported unless it's got a fixed size
mask = _x1_not_in_x2(x1, x2)
x1 = xp.where(mask, x1, fill_value)
# Note: jnp.unique_values sorts
return xp.unique_values(x1, size=x1.size, fill_value=fill_value)
# Move fill_value to the right
x1 = xp.take(x1, xp.argsort(~mask, stable=True))
x1 = x1[:size]
x1 = xp.unique_values(x1, size=size, fill_value=fill_value)

if is_dask_namespace(xp):
return _dask_impl(x1, x2)
Expand All @@ -666,7 +676,7 @@ def _jax_jit_impl(
jax.errors.ConcretizationTypeError,
jax.errors.NonConcreteBooleanIndexError,
):
return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit
return _jax_jit_impl(x1, x2, size, fill_value) # inside jax.jit

return _generic_impl(x1, x2)

Expand Down

0 comments on commit 028441c

Please sign in to comment.