From e786e88fd8b734af010a2f3de2590de21d1e5858 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Fri, 2 Feb 2024 15:41:14 +0530 Subject: [PATCH] feat: added the erfinv function to ivy's experimental API (#28159) --- .../array/experimental/elementwise.py | 31 +++++++ .../container/experimental/elementwise.py | 91 +++++++++++++++++++ .../backends/jax/experimental/elementwise.py | 9 ++ .../numpy/experimental/elementwise.py | 12 +++ .../paddle/experimental/elementwise.py | 10 ++ .../tensorflow/experimental/elementwise.py | 10 ++ .../torch/experimental/elementwise.py | 13 +++ .../ivy/experimental/elementwise.py | 36 ++++++++ .../test_core/test_elementwise.py | 38 +++++++- 9 files changed, 249 insertions(+), 1 deletion(-) diff --git a/ivy/data_classes/array/experimental/elementwise.py b/ivy/data_classes/array/experimental/elementwise.py index f48d1b303fc55..80eb526cb7093 100644 --- a/ivy/data_classes/array/experimental/elementwise.py +++ b/ivy/data_classes/array/experimental/elementwise.py @@ -1191,3 +1191,34 @@ def erfc( ivy.array([1.00000000e+00, 1.84270084e+00, 2.80259693e-45]) """ return ivy.erfc(self._data, out=out) + + def erfinv( + self: ivy.Array, + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ivy.Array instance method variant of ivy.erfinv. This method simply + wraps the function, and so the docstring for ivy.erfinv also applies to + this method with minimal changes. + + Parameters + ---------- + self + Input array with real or complex valued argument. + out + Alternate output array in which to place the result. + The default is None. + + Returns + ------- + ret + Values of the inverse error function. + + Examples + -------- + >>> x = ivy.array([0, -1., 10.]) + >>> x.erfinv() + ivy.array([1.00000000e+00, 1.84270084e+00, 2.80259693e-45]) + """ + return ivy.erfinv(self._data, out=out) diff --git a/ivy/data_classes/container/experimental/elementwise.py b/ivy/data_classes/container/experimental/elementwise.py index 939c32874fc50..402e0fa1fac1e 100644 --- a/ivy/data_classes/container/experimental/elementwise.py +++ b/ivy/data_classes/container/experimental/elementwise.py @@ -3491,3 +3491,94 @@ def erfc( } """ return self.static_erfc(self, out=out) + + @staticmethod + def static_erfinv( + x: Union[ivy.Array, ivy.NativeArray, ivy.Container], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ivy.Container static method variant of ivy.erfinv. This method + simply wraps the function, and so the docstring for ivy.erfinv also + applies to this method with minimal changes. + + Parameters + ---------- + x + The container whose array contains real or complex valued argument. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. + + Returns + ------- + ret + container with values of the inverse error function. + + Examples + -------- + >>> x = ivy.Container(a=ivy.array([1., 2.]), b=ivy.array([-3., -4.])) + >>> ivy.Container.static_erfinv(x) + { + a: ivy.array([0.15729921, 0.00467773]), + b: ivy.array([1.99997795, 2.]) + } + """ + return ContainerBase.cont_multi_map_in_function( + "erfinv", + x, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def erfinv( + self: ivy.Container, + /, + *, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ivy.Container instance method variant of ivy.erfinv. This method + simply wraps the function, and so the docstring for ivy.erfinv also + applies to this method with minimal changes. + + Parameters + ---------- + self + The container whose array contains real or complex valued argument. + out + optional output container, for writing the result to. + + Returns + ------- + ret + container with values of the inverse error function. + + Examples + -------- + With one :class:`ivy.Container` input: + >>> x = ivy.Container(a=ivy.array([1., 2., 3.]), b=ivy.array([-1., -2., -3.])) + >>> x.erfinv() + { + a: ivy.array([1.57299206e-01, 4.67773480e-03, 2.20904985e-05]), + b: ivy.array([1.84270084, 1.99532223, 1.99997795]) + } + """ + return self.static_erfinv(self, out=out) diff --git a/ivy/functional/backends/jax/experimental/elementwise.py b/ivy/functional/backends/jax/experimental/elementwise.py index 1880fd049e412..363bc6a658213 100644 --- a/ivy/functional/backends/jax/experimental/elementwise.py +++ b/ivy/functional/backends/jax/experimental/elementwise.py @@ -499,3 +499,12 @@ def erfc( out: Optional[JaxArray] = None, ) -> JaxArray: return js.special.erfc(x) + + +def erfinv( + x: JaxArray, + /, + *, + out: Optional[JaxArray] = None, +) -> JaxArray: + return js.special.erfinv(x) diff --git a/ivy/functional/backends/numpy/experimental/elementwise.py b/ivy/functional/backends/numpy/experimental/elementwise.py index 895a0daf0f271..ab35f7d972aa7 100644 --- a/ivy/functional/backends/numpy/experimental/elementwise.py +++ b/ivy/functional/backends/numpy/experimental/elementwise.py @@ -602,3 +602,15 @@ def is_pos_inf(op): return np.where(underflow, result_underflow, result_no_underflow).astype( input_dtype ) + + +# TODO: Remove this once native function is available. +# Compute an approximation of the error function complement (1 - erf(x)). +def erfinv( + x: np.ndarray, + /, + *, + out: Optional[np.ndarray] = None, +) -> np.ndarray: + with ivy.ArrayMode(False): + return np.sqrt(2) * erfc(x) diff --git a/ivy/functional/backends/paddle/experimental/elementwise.py b/ivy/functional/backends/paddle/experimental/elementwise.py index 7fabf0b78ef52..e0c53010b1cf2 100644 --- a/ivy/functional/backends/paddle/experimental/elementwise.py +++ b/ivy/functional/backends/paddle/experimental/elementwise.py @@ -815,3 +815,13 @@ def is_pos_inf(op): result = paddle.squeeze(result, axis=-1) return result + + +@with_supported_dtypes( + {"2.6.0 and below": ("float32", "float64")}, + backend_version, +) +def erfinv( + x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + return paddle.erfinv(x) diff --git a/ivy/functional/backends/tensorflow/experimental/elementwise.py b/ivy/functional/backends/tensorflow/experimental/elementwise.py index 8f5b53c24a1dd..977e0d0584c87 100644 --- a/ivy/functional/backends/tensorflow/experimental/elementwise.py +++ b/ivy/functional/backends/tensorflow/experimental/elementwise.py @@ -566,3 +566,13 @@ def erfc( out: Optional[Union[tf.Tensor, tf.Variable]] = None, ) -> Union[tf.Tensor, tf.Variable]: return tf.math.erfc(x) + + +@with_supported_dtypes({"2.15.0 and below": ("float",)}, backend_version) +def erfinv( + x: Union[tf.Tensor, tf.Variable], + /, + *, + out: Optional[Union[tf.Tensor, tf.Variable]] = None, +) -> Union[tf.Tensor, tf.Variable]: + return tf.math.erfinv(x) diff --git a/ivy/functional/backends/torch/experimental/elementwise.py b/ivy/functional/backends/torch/experimental/elementwise.py index 309139a7bbf62..0b7b23d89f67b 100644 --- a/ivy/functional/backends/torch/experimental/elementwise.py +++ b/ivy/functional/backends/torch/experimental/elementwise.py @@ -442,3 +442,16 @@ def erfc( out: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.special.erfc(x) + + +@with_unsupported_dtypes({"2.1.2 and below": ("float16",)}, backend_version) +def erfinv( + x: torch.Tensor, + /, + *, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.special.erfinv(x, out=out) + + +erfinv.support_native_out = True diff --git a/ivy/functional/ivy/experimental/elementwise.py b/ivy/functional/ivy/experimental/elementwise.py index 83025fd5dd56f..c2680ceb4fe06 100644 --- a/ivy/functional/ivy/experimental/elementwise.py +++ b/ivy/functional/ivy/experimental/elementwise.py @@ -1637,3 +1637,39 @@ def erfc( ivy.array([0.00467773, 1.84270084, 1. ]) """ return ivy.current_backend(x).erfc(x, out=out) + + +@handle_exceptions +@handle_nestable +@handle_array_like_without_promotion +@handle_out_argument +@to_native_arrays_and_back +@handle_device +def erfinv( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + out: Optional[ivy.Array] = None, +): + """Compute the inverse error function. + + Parameters + ---------- + x + Input array of real or complex valued argument. + out + optional output array, for writing the result to. + It must have a shape that the inputs broadcast to. + + Returns + ------- + ret + Values of the inverse error function. + + Examples + -------- + >>> x = ivy.array([0, 0.5, -1.]) + >>> ivy.erfinv(x) + ivy.array([0.0000, 0.4769, -inf]) + """ + return ivy.current_backend(x).erfinv(x, out=out) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py index 12e3bd938b063..03ca8d09b42dc 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py @@ -1,5 +1,5 @@ # global -from hypothesis import strategies as st +from hypothesis import assume, strategies as st # local import ivy @@ -510,6 +510,42 @@ def test_erfc( ) +# erfinv +@handle_test( + fn_tree="functional.ivy.experimental.erfinv", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1, + max_value=1, + abs_smallest_val=1e-05, + ), +) +def test_erfinv( + *, + dtype_and_x, + backend_fw, + test_flags, + fn_name, + on_device, +): + input_dtype, x = dtype_and_x + if on_device == "cpu": + assume("float16" not in input_dtype and "bfloat16" not in input_dtype) + test_values = True + if backend_fw == "numpy": + # the numpy backend requires an approximation which doesn't pass the value tests + test_values = False + helpers.test_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + test_values=test_values, + x=x[0], + ) + + # fix @handle_test( fn_tree="functional.ivy.experimental.fix",