Skip to content

Commit

Permalink
Implement __rpow__ for Tensor (#2731)
Browse files Browse the repository at this point in the history
### Changes

Implement  __rpow__ for Tensor

### Related tickets

143943

### Tests

TemplateTestNNCFTensorOperators :: test_operators_int_rev
  • Loading branch information
AlexanderDokuchaev authored Jun 13, 2024
1 parent d06b174 commit b780041
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions nncf/experimental/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __rmul__(self, other: Union[Tensor, float]) -> Tensor:
def __pow__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(self.data ** unwrap_tensor_data(other))

def __rpow__(self, other: Union[Tensor, float]) -> Tensor:
return Tensor(unwrap_tensor_data(other) ** self.data)

def __truediv__(self, other: Union[Tensor, float]) -> Tensor:
return _call_function("_binary_op_nowarn", self, other, operator.truediv)

Expand Down
3 changes: 2 additions & 1 deletion tests/shared/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"floordiv": operator.floordiv,
"neg": lambda a, _: -a,
}
BINARY_OPERATORS = ["add", "sub", "pow", "mul", "truediv", "floordiv"]

COMPARISON_OPERATOR_MAP = {
"lt": operator.lt,
Expand Down Expand Up @@ -121,7 +122,7 @@ def test_operators_int(self, op_name):
assert isinstance(res_nncf, Tensor)
assert res_nncf.device == nncf_tensor_a.device

@pytest.mark.parametrize("op_name", ("add", "sub", "mul", "truediv", "floordiv"))
@pytest.mark.parametrize("op_name", BINARY_OPERATORS)
def test_operators_int_rev(self, op_name):
tensor_a = self.to_tensor([1, 2])
value = 2
Expand Down

0 comments on commit b780041

Please sign in to comment.