diff --git a/ivy/functional/backends/jax/gradients.py b/ivy/functional/backends/jax/gradients.py
index c75a8b751d4e7..5e96bdd28c101 100644
--- a/ivy/functional/backends/jax/gradients.py
+++ b/ivy/functional/backends/jax/gradients.py
@@ -49,7 +49,7 @@ def _forward_fn(
                 ivy.index_nest(xs, grad_idx), ivy.is_array
             )
             for idx in xs_grad_arr_idx:
-                xs_grad_arr_idxs.append(grad_idx + idx)
+                xs_grad_arr_idxs.append(list(grad_idx) + idx)
         ivy.set_nest_at_indices(xs, xs_grad_arr_idxs, x_arr_values)
     elif ivy.is_array(xs):
         xs = x
@@ -75,8 +75,8 @@ def execute_with_gradients(
     /,
     *,
     retain_grads: bool = False,
-    xs_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
-    ret_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
+    xs_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
+    ret_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
 ):
     # Conversion of required arrays to float variables and duplicate index chains
     (
diff --git a/ivy/functional/backends/mxnet/gradients.py b/ivy/functional/backends/mxnet/gradients.py
index a832aa2d0bab8..203640b7df2ef 100644
--- a/ivy/functional/backends/mxnet/gradients.py
+++ b/ivy/functional/backends/mxnet/gradients.py
@@ -2,7 +2,7 @@
 signature."""

 # global
-from typing import Optional, Sequence, Union
+from typing import Sequence, Union
 import mxnet as mx

 # local
@@ -27,8 +27,8 @@ def execute_with_gradients(
     /,
     *,
     retain_grads: bool = False,
-    xs_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
-    ret_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
+    xs_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
+    ret_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
 ):
     raise IvyNotImplementedException()

diff --git a/ivy/functional/backends/numpy/gradients.py b/ivy/functional/backends/numpy/gradients.py
index 9b1cb295b6d6e..d6ba1e9b55bd7 100644
--- a/ivy/functional/backends/numpy/gradients.py
+++ b/ivy/functional/backends/numpy/gradients.py
@@ -3,7 +3,7 @@

 # global
 import logging
-from typing import Optional, Sequence, Union
+from typing import Sequence, Union

 import ivy

@@ -31,8 +31,8 @@ def execute_with_gradients(
     /,
     *,
     retain_grads: bool = False,
-    xs_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
-    ret_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
+    xs_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
+    ret_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
 ):
     logging.warning(
         "NumPy does not support autograd, "
diff --git a/ivy/functional/backends/paddle/gradients.py b/ivy/functional/backends/paddle/gradients.py
index 3f148343b03f1..ff3801646d31d 100644
--- a/ivy/functional/backends/paddle/gradients.py
+++ b/ivy/functional/backends/paddle/gradients.py
@@ -108,7 +108,7 @@ def grad_(x):
     {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version
 )
 def execute_with_gradients(
-    func, xs, /, *, retain_grads=False, xs_grad_idxs=[[0]], ret_grad_idxs=[[0]]
+    func, xs, /, *, retain_grads=False, xs_grad_idxs=((0,),), ret_grad_idxs=((0,),)
 ):
     # Conversion of required arrays to float variables and duplicate index chains
     xs, xs_grad_idxs, xs1, required_duplicate_index_chains, _ = (
diff --git a/ivy/functional/backends/tensorflow/gradients.py b/ivy/functional/backends/tensorflow/gradients.py
index a006020b08d1c..da688a2dd71f5 100644
--- a/ivy/functional/backends/tensorflow/gradients.py
+++ b/ivy/functional/backends/tensorflow/gradients.py
@@ -68,8 +68,8 @@ def execute_with_gradients(
     /,
     *,
     retain_grads: bool = False,
-    xs_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
-    ret_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
+    xs_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
+    ret_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
 ):
     # Conversion of required arrays to float variables and duplicate index chains
     xs, xs_grad_idxs, xs_required, required_duplicate_index_chains, _ = (
diff --git a/ivy/functional/backends/torch/gradients.py b/ivy/functional/backends/torch/gradients.py
index 18b1c139630a6..a1f84509bb8da 100644
--- a/ivy/functional/backends/torch/gradients.py
+++ b/ivy/functional/backends/torch/gradients.py
@@ -99,8 +99,8 @@ def execute_with_gradients(
     /,
     *,
     retain_grads: bool = False,
-    xs_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
-    ret_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
+    xs_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
+    ret_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
 ):
     # Conversion of required arrays to float variables and duplicate index chains
     xs, xs_grad_idxs, xs1, required_duplicate_index_chains, _ = (
diff --git a/ivy/functional/ivy/gradients.py b/ivy/functional/ivy/gradients.py
index 2d2c735ae9afd..8b9d72f383c1c 100644
--- a/ivy/functional/ivy/gradients.py
+++ b/ivy/functional/ivy/gradients.py
@@ -406,8 +406,8 @@ def execute_with_gradients(
     /,
     *,
     retain_grads: bool = False,
-    xs_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
-    ret_grad_idxs: Optional[Sequence[Sequence[Union[str, int]]]] = [[0]],
+    xs_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
+    ret_grad_idxs: Sequence[Sequence[Union[str, int]]] = ((0,),),
 ) -> Tuple[ivy.Array, ivy.Array]:
     """Call function func with input of xs variables, and return the function
     result func_ret and the gradients of each output variable w.r.t each input
diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py
index a90f2df63f698..0a5f653f7fd8f 100644
--- a/ivy/utils/assertions.py
+++ b/ivy/utils/assertions.py
@@ -136,7 +136,7 @@ def check_all_or_any_fn(
     *args,
     fn,
     type="all",
-    limit=[0],
+    limit=(0,),
     message="args must exist according to type and limit given",
     as_array=True,
 ):
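Why the defaults change from [[0]] to ((0,),): Python evaluates default argument values once, at function definition time, so a mutable list default is a single object shared by every call that omits the argument, and any in-place mutation of it leaks into later calls. A tuple default is immutable and carries no such shared state. The related list(grad_idx) + idx change in the JAX backend accounts for an index chain now possibly being a tuple, which cannot be concatenated with a list directly. A minimal illustrative sketch of both points; the collect/collect_safe functions below are hypothetical and not part of the Ivy codebase:

# Hypothetical example: a mutable default is created once and shared across calls.
def collect(value, seen=[[0]]):
    seen.append([value])
    return seen

print(collect(1))  # [[0], [1]]
print(collect(2))  # [[0], [1], [2]] -- state leaked from the first call

# An immutable tuple default, as in the patch, has no shared state between calls.
def collect_safe(value, seen=((0,),)):
    return tuple(seen) + ((value,),)

print(collect_safe(1))  # ((0,), (1,))
print(collect_safe(2))  # ((0,), (2,))

# With tuple defaults an index chain may be a tuple, so it is converted to a
# list before being concatenated with a list of nested indices:
grad_idx, idx = (0,), [1, 2]
print(list(grad_idx) + idx)  # [0, 1, 2]; grad_idx + idx would raise TypeError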