Skip to content

Commit

Permalink
feat(frontends): Add tensor_scatter_nd_add and test(#26361)
Browse files Browse the repository at this point in the history
Co-authored-by: ivy-branch <[email protected]>
Co-authored-by: Zoe Caballero <[email protected]>
  • Loading branch information
3 people authored Oct 10, 2023
1 parent 814389e commit 406c50d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ivy/functional/frontends/tensorflow/general_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,13 @@ def strided_slice(
return ret


@to_ivy_arrays_and_back
def tensor_scatter_nd_add(tensor, indices, updates, name=None):
zero_tensor = ivy.zeros_like(tensor)
scatter_tensor = ivy.scatter_nd(indices, updates, zero_tensor.shape)
return ivy.add(tensor, scatter_tensor)


@with_unsupported_dtypes({"2.14.0 and below": ("uint16",)}, "tensorflow")
@to_ivy_arrays_and_back
def tile(input, multiples, name=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2036,6 +2036,50 @@ def test_tensorflow_strided_slice(
raise e


# tensor_scatter_nd_add
@handle_frontend_test(
fn_tree="tensorflow.tensor_scatter_nd_add",
all_arguments=_multiple_shape_helper(),
tensor=helpers.array_values(
dtype=helpers.get_dtypes("numeric"), shape=(8,), min_value=2, max_value=49
),
indices=helpers.array_values(
dtype=helpers.get_dtypes("integer"), shape=(4, 1), min_value=0, max_value=7
),
updates=helpers.array_values(
dtype=helpers.get_dtypes("integer"),
shape=(4,),
min_value=9,
max_value=12,
),
)
def test_tensorflow_tensor_scatter_nd_add(
*,
all_arguments,
tensor,
indices,
updates,
frontend,
test_flags,
fn_tree,
on_device,
backend_fw,
):
input_dtype, input_matrix, dt_and_multiples = all_arguments
dt_mul, multiples = dt_and_multiples
helpers.test_frontend_function(
input_dtypes=input_dtype + dt_mul,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
tensor=tensor[0],
indices=indices[0],
updates=updates[0],
)


@handle_frontend_test(fn_tree="tensorflow.tile", all_arguments=_multiple_shape_helper())
def test_tensorflow_tile(
*,
Expand Down

0 comments on commit 406c50d

Please sign in to comment.