From 6629e75e34149480ba501cef4c29cfe74b4efce2 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Fri, 15 Mar 2024 18:37:53 +0200 Subject: [PATCH 1/5] tensor dispatcher --- nncf/experimental/tensor/README.md | 27 +-- .../tensor/functions/dispatcher.py | 159 ++++++++++--- nncf/experimental/tensor/functions/linalg.py | 7 +- nncf/experimental/tensor/functions/numeric.py | 215 ++++-------------- .../tensor/functions/numpy_linalg.py | 3 +- .../tensor/functions/numpy_numeric.py | 101 ++++---- .../tensor/functions/torch_linalg.py | 2 +- .../tensor/functions/torch_numeric.py | 98 ++++---- nncf/experimental/tensor/tensor.py | 9 +- 9 files changed, 276 insertions(+), 345 deletions(-) diff --git a/nncf/experimental/tensor/README.md b/nncf/experimental/tensor/README.md index 13d75081e57..a9d614b6f0c 100644 --- a/nncf/experimental/tensor/README.md +++ b/nncf/experimental/tensor/README.md @@ -108,37 +108,32 @@ tensor_a[0:2] # Tensor(array([[1],[2]])) 2. Add function to functions module ```python - @functools.singledispatch - def foo(a: TTensor, arg1: Type) -> TTensor: + @tensor_dispatch + def foo(a: Tensor, arg1: Type) -> Tensor: """ __description__ - :param a: The input tensor. + :param a: __description__ :param arg1: __description__ :return: __description__ """ - if isinstance(a, tensor.Tensor): - return tensor.Tensor(foo(a.data, axis)) - return NotImplemented(f"Function `foo` is not implemented for {type(a)}") ``` - **NOTE** For the case when the first argument has type `List[Tensor]`, use the `_dispatch_list` function. This function dispatches function by first element in the first argument. + **NOTE** Type of wrapper function selected by type hint of function, supported signatures of functions: ```python - @functools.singledispatch - def foo(x: List[Tensor], axis: int = 0) -> Tensor: - if isinstance(x, List): - unwrapped_x = [i.data for i in x] - return Tensor(_dispatch_list(foo, unwrapped_x, axis=axis)) - raise NotImplementedError(f"Function `foo` is not implemented for {type(x)}") + def foo(a: Tensor, *args) -> Tensor: + def foo(a: Tensor, *args) -> Any: + def foo(a: Tensor, *args) -> List[Tensor]: + def foo(a: List[Tensor], *args) -> Tensor: ``` -3. Add backend specific implementation of method to correcponding module: +3. Add backend specific implementation of method to corresponding module: - `functions/numpy_*.py` ```python - @_register_numpy_types(fns.foo) + @fns.foo.register def _(a: TType, arg1: Type) -> np.ndarray: return np.foo(a, arg1) ``` @@ -146,7 +141,7 @@ tensor_a[0:2] # Tensor(array([[1],[2]])) - `functions/torch_*.py` ```python - @fns.foo.register(torch.Tensor) + @fns.foo.register def _(a: torch.Tensor, arg1: Type) -> torch.Tensor: return torch.foo(a, arg1) ``` diff --git a/nncf/experimental/tensor/functions/dispatcher.py b/nncf/experimental/tensor/functions/dispatcher.py index 525910bd94a..402fd9068e5 100644 --- a/nncf/experimental/tensor/functions/dispatcher.py +++ b/nncf/experimental/tensor/functions/dispatcher.py @@ -8,50 +8,147 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
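# `tensor_dispatch` below replaces the `functools.singledispatch` + `tensor_guard`
# pair with a single decorator that keeps its own registry. A minimal usage
# sketch (`foo` and `arg1` are illustrative names, not part of this patch):
#
#     @tensor_dispatch
#     def foo(a: Tensor, arg1: int) -> Tensor:
#         """Dispatch stub: only the signature, type hints and docstring matter."""
#
#     @foo.register
#     def _(a: np.ndarray, arg1: int) -> np.ndarray:
#         return np.add(a, arg1)
#
# Calling `foo(Tensor(np.zeros(3)), 1)` unwraps the data, dispatches on
# `np.ndarray`, and wraps the result back into a `Tensor`.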
-import functools -from typing import List - -import numpy as np +from functools import _find_impl +from inspect import getfullargspec +from inspect import isclass +from inspect import isfunction +from types import MappingProxyType +from typing import List, _GenericAlias, _UnionGenericAlias, get_type_hints from nncf.experimental.tensor import Tensor -def tensor_guard(func: callable): +def _get_target_types(type_alias): + if isclass(type_alias): + return [type_alias] + if isinstance(type_alias, (_UnionGenericAlias, _GenericAlias)): + ret = [] + for t in type_alias.__args__: + ret.extend(_get_target_types(t)) + return ret + + +def tensor_dispatch(func): """ - A decorator that ensures that the first argument to the decorated function is a Tensor. + This decorator creates a registry of functions for different types and provides a wrapper + that calls the appropriate function based on the type of the first argument. + It's particularly designed to handle Tensor inputs and outputs effectively. + + :param func: The function to be decorated. + :return: The decorated function with type-based dispatching functionality. """ - @functools.wraps(func) - def wrapper(*args, **kwargs): - if isinstance(args[0], Tensor): - return func(*args, **kwargs) - raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {type(args[0])}") + registry = {} - return wrapper + def dispatch(cls): + """ + Retrieves the registered function for a given type. + :param cls: The type to retrieve the function for. + :return: The registered function for the given type, or a function that raises a NotImplementedError + if no function is registered for type. + """ + try: + return registry[cls] + except KeyError: + return _find_impl(cls, registry) -def dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs): - """ - Dispatches the function to the type of the wrapped data of the first element in tensor_list. + def register(rfunc): + """Registers a function for a specific type or types. - :param fn: A function wrapped by `functools.singledispatch`. - :param tensor_list: List of Tensors. - :return: The result value of the function call. - """ - unwrapped_list = [i.data for i in tensor_list] - return fn.dispatch(type(unwrapped_list[0]))(unwrapped_list, *args, **kwargs) + :param rfunc: The function to register. + :return: The registered function. + """ + assert isfunction(rfunc), "Register object should be a function." + assert getfullargspec(func)[0] == getfullargspec(rfunc)[0], "Differ names of arguments of function" + target_type_hint = get_type_hints(rfunc).get(getfullargspec(rfunc)[0][0]) + assert target_type_hint is not None, "No type hint for first argument of function" -def register_numpy_types(singledispatch_fn): - """ - Decorator to register function to singledispatch for numpy classes. + types_to_registry = set(_get_target_types(target_type_hint)) - :param singledispatch_fn: singledispatch function. - """ + for t in types_to_registry: + assert t not in registry, f"{t} already registered for function" + registry[t] = rfunc + return rfunc - def inner(func): - singledispatch_fn.register(np.ndarray)(func) - singledispatch_fn.register(np.generic)(func) - return func + def wrapper_tensor_to_tensor(*args, **kw): + """ + Wrapper for functions that take and return a Tensor. + This wrapper unwraps Tensor arguments and wraps the returned value in a Tensor if necessary. 
+ """ + is_wrapped = any(isinstance(x, Tensor) for x in args) + args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + ret = dispatch(args[0].__class__)(*args, **kw) + return Tensor(ret) if is_wrapped else ret - return inner + def wrapper_tensor_to_any(*args, **kw): + """ + Wrapper for functions that take a Tensor and return any type. + This wrapper unwraps Tensor arguments but doesn't specifically wrap the returned value. + """ + args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + return dispatch(args[0].__class__)(*args, **kw) + + def wrapper_tensor_to_list(*args, **kw): + """ + Wrapper for functions that take a Tensor and return a list. + This wrapper unwraps Tensor arguments and wraps the list elements as Tensors if necessary. + """ + is_wrapped = any(isinstance(x, Tensor) for x in args) + args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + ret = dispatch(args[0].__class__)(*args, **kw) + if is_wrapped: + return [Tensor(x) for x in ret] + return ret + + def wrapper_list_to_tensor(list_of_tensors: List[Tensor], *args, **kw): + """ + Wrapper for functions that take a list of Tensors and return a Tensor. + This wrapper handles lists containing Tensors appropriately. + """ + if any(isinstance(x, Tensor) for x in list_of_tensors): + list_of_tensors = [x.data if isinstance(x, Tensor) else x for x in list_of_tensors] + return Tensor(dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw)) + return dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw) + + def raise_not_implemented(*args, **kw): + """ + Raises a NotImplementedError for types that are not registered. + """ + if isinstance(args[0], list): + arg_type = type(args[0][0].data) if isinstance(args[0][0], Tensor) else type(args[0][0]) + else: + arg_type = type(args[0].data) if isinstance(args[0], Tensor) else type(args[0]) + + raise NotImplementedError(f"Function `{func.__name__}` is not implemented for {arg_type}") + + # Select wrapper by signature of function + type_hints = get_type_hints(func) + first_type_hint = type_hints.get(getfullargspec(func)[0][0]) + return_type_hint = type_hints.get("return") + wrapper = None + if first_type_hint is Tensor: + if return_type_hint is Tensor: + wrapper = wrapper_tensor_to_tensor + elif isinstance(return_type_hint, _GenericAlias) and not isinstance(return_type_hint, _UnionGenericAlias): + wrapper = wrapper_tensor_to_list + else: + wrapper = wrapper_tensor_to_any + elif isinstance(first_type_hint, _GenericAlias) and return_type_hint is Tensor: + wrapper = wrapper_list_to_tensor + + assert wrapper is not None, ( + "Not supported signature of dispatch function, supported:\n" + " def foo(a: Tensor, ...) -> Tensor\n" + " def foo(a: Tensor, ...) -> Any\n" + " def foo(a: Tensor, ...) -> List[Tensor]\n" + " def foo(a: List[Tensor], ...) -> Tensor\n" + ) + + registry[object] = raise_not_implemented + wrapper.register = register + wrapper.dispatch = dispatch + wrapper.registry = MappingProxyType(registry) + + return wrapper diff --git a/nncf/experimental/tensor/functions/linalg.py b/nncf/experimental/tensor/functions/linalg.py index 167c045b194..0ecb0aa96bf 100644 --- a/nncf/experimental/tensor/functions/linalg.py +++ b/nncf/experimental/tensor/functions/linalg.py @@ -9,15 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import functools from typing import Optional, Tuple, Union from nncf.experimental.tensor import Tensor -from nncf.experimental.tensor.functions.dispatcher import tensor_guard +from nncf.experimental.tensor.functions.dispatcher import tensor_dispatch -@functools.singledispatch -@tensor_guard +@tensor_dispatch def norm( a: Tensor, ord: Optional[Union[str, float, int]] = None, @@ -61,4 +59,3 @@ def norm( as dimensions with size one. Default: False. :return: Norm of the matrix or vector. """ - return Tensor(norm(a.data, ord, axis, keepdims)) diff --git a/nncf/experimental/tensor/functions/numeric.py b/nncf/experimental/tensor/functions/numeric.py index 3dd2c0d8815..95e07a82a38 100644 --- a/nncf/experimental/tensor/functions/numeric.py +++ b/nncf/experimental/tensor/functions/numeric.py @@ -9,20 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from nncf.experimental.tensor.definitions import TensorDataType from nncf.experimental.tensor.definitions import TensorDeviceType from nncf.experimental.tensor.definitions import TypeInfo -from nncf.experimental.tensor.functions.dispatcher import dispatch_list -from nncf.experimental.tensor.functions.dispatcher import tensor_guard +from nncf.experimental.tensor.functions.dispatcher import tensor_dispatch from nncf.experimental.tensor.tensor import Tensor -from nncf.experimental.tensor.tensor import unwrap_tensor_data -@functools.singledispatch -@tensor_guard +@tensor_dispatch def device(a: Tensor) -> TensorDeviceType: """ Return the device of the tensor. @@ -30,11 +26,9 @@ def device(a: Tensor) -> TensorDeviceType: :param a: The input tensor. :return: The device of the tensor. """ - return device(a.data) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Remove axes of length one from a. @@ -45,11 +39,9 @@ def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Te This is always a itself or a view into a. Note that if all axes are squeezed, the result is a 0d array and not a scalar. """ - return Tensor(squeeze(a.data, axis=axis)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def flatten(a: Tensor) -> Tensor: """ Return a copy of the tensor collapsed into one dimension. @@ -57,11 +49,9 @@ def flatten(a: Tensor) -> Tensor: :param a: The input tensor. :return: A copy of the input tensor, flattened to one dimension. """ - return Tensor(flatten(a.data)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def max(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: """ Return the maximum of an array or maximum along an axis. @@ -72,11 +62,9 @@ def max(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: one. With this option, the result will broadcast correctly against the input array. False, by default. :return: Maximum of a. """ - return Tensor(max(a.data, axis, keepdims)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def min(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: """ Return the minimum of an array or minimum along an axis. @@ -87,11 +75,9 @@ def min(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: one. 
With this option, the result will broadcast correctly against the input array. False, by default. :return: Minimum of a. """ - return Tensor(min(a.data, axis, keepdims)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def abs(a: Tensor) -> Tensor: """ Calculate the absolute value element-wise. @@ -99,12 +85,10 @@ def abs(a: Tensor) -> Tensor: :param a: The input tensor. :return: A tensor containing the absolute value of each element in x. """ - return Tensor(abs(a.data)) -@functools.singledispatch -@tensor_guard -def astype(a: Tensor, data_type: TensorDataType) -> Tensor: +@tensor_dispatch +def astype(a: Tensor, dtype: TensorDataType) -> Tensor: """ Copy of the tensor, cast to a specified type. @@ -113,11 +97,9 @@ def astype(a: Tensor, data_type: TensorDataType) -> Tensor: :return: Copy of the tensor in specified type. """ - return Tensor(astype(a.data, data_type)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def dtype(a: Tensor) -> TensorDataType: """ Return data type of the tensor. @@ -125,11 +107,9 @@ def dtype(a: Tensor) -> TensorDataType: :param a: The input tensor. :return: The data type of the tensor. """ - return dtype(a.data) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor: """ Gives a new shape to a tensor without changing its data. @@ -138,11 +118,9 @@ def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor: :param shape: The new shape should be compatible with the original shape. :return: Reshaped tensor. """ - return Tensor(reshape(a.data, shape)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Test whether all tensor elements along a given axis evaluate to True. @@ -151,11 +129,9 @@ def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor :param axis: Axis or axes along which a logical AND reduction is performed. :return: A new tensor. """ - return Tensor(all(a.data, axis=axis)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def allclose( a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> bool: @@ -171,11 +147,9 @@ def allclose( Defaults to False. :return: True if the two arrays are equal within the given tolerance, otherwise False. """ - return allclose(a.data, unwrap_tensor_data(b), rtol=rtol, atol=atol, equal_nan=equal_nan) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Test whether any tensor elements along a given axis evaluate to True. @@ -184,11 +158,9 @@ def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor :param axis: Axis or axes along which a logical OR reduction is performed. :return: A new tensor. """ - return Tensor(any(a.data, axis)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Counts the number of non-zero values in the tensor input. @@ -198,11 +170,9 @@ def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) :return: Number of non-zero values in the tensor along a given axis. Otherwise, the total number of non-zero values in the tensor is returned. 
""" - return Tensor(count_nonzero(a.data, axis)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def isempty(a: Tensor) -> bool: """ Return True if input tensor is empty. @@ -210,11 +180,9 @@ def isempty(a: Tensor) -> bool: :param a: The input tensor. :return: True if tensor is empty, otherwise False. """ - return isempty(a.data) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def isclose( a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> Tensor: @@ -230,19 +198,9 @@ def isclose( Defaults to False. :return: Returns a boolean tensor of where a and b are equal within the given tolerance. """ - return Tensor( - isclose( - a.data, - unwrap_tensor_data(b), - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - ) - ) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def maximum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise maximum of tensor elements. @@ -251,11 +209,9 @@ def maximum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: :param x2: The second input tensor. :return: Output tensor. """ - return Tensor(maximum(x1.data, unwrap_tensor_data(x2))) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def minimum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise minimum of tensor elements. @@ -264,11 +220,9 @@ def minimum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: :param x2: The second input tensor. :return: Output tensor. """ - return Tensor(minimum(x1.data, unwrap_tensor_data(x2))) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def ones_like(a: Tensor) -> Tensor: """ Return a tensor of ones with the same shape and type as a given tensor. @@ -276,11 +230,9 @@ def ones_like(a: Tensor) -> Tensor: :param a: The shape and data-type of a define these same attributes of the returned tensor. :return: Tensor of ones with the same shape and type as a. """ - return Tensor(ones_like(a.data)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) -> Tensor: """ Return elements chosen from x or y depending on condition. @@ -290,17 +242,9 @@ def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) - :param y: Value at indices where condition is False. :return: A tensor with elements from x where condition is True, and elements from y elsewhere. """ - return Tensor( - where( - condition.data, - unwrap_tensor_data(x), - unwrap_tensor_data(y), - ) - ) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def zeros_like(a: Tensor) -> Tensor: """ Return an tensor of zeros with the same shape and type as a given tensor. @@ -308,10 +252,9 @@ def zeros_like(a: Tensor) -> Tensor: :param input: The shape and data-type of a define these same attributes of the returned tensor. :return: tensor of zeros with the same shape and type as a. """ - return Tensor(zeros_like(a.data)) -@functools.singledispatch +@tensor_dispatch def stack(x: List[Tensor], axis: int = 0) -> Tensor: """ Stacks a list of Tensors rank-R tensors into one Tensor rank-(R+1) tensor. @@ -320,27 +263,20 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor: :param axis: The axis to stack along. :return: Stacked Tensor. 
""" - if isinstance(x, List): - return Tensor(dispatch_list(stack, x, axis=axis)) - raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}") -@functools.singledispatch -@tensor_guard -def unstack(a: Tensor, axis: int = 0) -> List[Tensor]: +@tensor_dispatch +def unstack(x: Tensor, axis: int = 0) -> List[Tensor]: """ Unstack a Tensor into list. - :param a: Tensor to unstack. + :param x: Tensor to unstack. :param axis: The axis to unstack along. :return: List of Tensor. """ - res = unstack(a.data, axis=axis) - return [Tensor(i) for i in res] -@functools.singledispatch -@tensor_guard +@tensor_dispatch def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> Tensor: """ Move axes of an array to new positions. @@ -350,11 +286,9 @@ def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[ :param destination: Destination positions for each of the original axes. These must also be unique. :return: Array with moved axes. """ - return Tensor(moveaxis(a.data, source, destination)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: """ Compute the arithmetic mean along the specified axis. @@ -364,11 +298,9 @@ def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims :param keepdims: Destination positions for each of the original axes. These must also be unique. :return: Array with moved axes. """ - return Tensor(mean(a.data, axis, keepdims)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def round(a: Tensor, decimals=0) -> Tensor: """ Evenly round to the given number of decimals. @@ -378,11 +310,9 @@ def round(a: Tensor, decimals=0) -> Tensor: it specifies the number of positions to the left of the decimal point. :return: An array of the same type as a, containing the rounded values. """ - return Tensor(round(a.data, decimals)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def power(a: Tensor, exponent: Union[Tensor, float]) -> Tensor: """ Takes the power of each element in input with exponent and returns a tensor with the result. @@ -394,11 +324,9 @@ def power(a: Tensor, exponent: Union[Tensor, float]) -> Tensor: :param exponent: Exponent value. :return: The result of the power of each element in input with given exponent. """ - return Tensor(power(a.data, unwrap_tensor_data(exponent))) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def quantile( a: Tensor, q: Union[float, List[float]], @@ -417,39 +345,9 @@ def quantile( :return: An tensor with quantiles, the first axis of the result corresponds to the quantiles, other axes of the result correspond to the quantiles values. """ - return Tensor(quantile(a.data, q, axis, keepdims)) -@functools.singledispatch -@tensor_guard -def _binary_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: - """ - Applies a binary operation with disable warnings. - - :param a: The first tensor. - :param b: The second tensor. - :param operator_fn: The binary operation function. - :return: The result of the binary operation. - """ - return Tensor(_binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) - - -@functools.singledispatch -@tensor_guard -def _binary_reverse_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: - """ - Applies a binary reverse operation with disable warnings. - - :param a: The first tensor. - :param b: The second tensor. 
- :param operator_fn: The binary operation function. - :return: The result of the binary operation. - """ - return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) - - -@functools.singledispatch -@tensor_guard +@tensor_dispatch def finfo(a: Tensor) -> TypeInfo: """ Returns machine limits for tensor type. @@ -457,11 +355,9 @@ def finfo(a: Tensor) -> TypeInfo: :param a: Tensor. :return: TypeInfo. """ - return finfo(a.data) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def clip(a: Tensor, a_min: Union[Tensor, float], a_max: Union[Tensor, float]) -> Tensor: """ Clips all elements in input into the range [ a_min, a_max ] @@ -472,11 +368,9 @@ def clip(a: Tensor, a_min: Union[Tensor, float], a_max: Union[Tensor, float]) -> :return: A clipped tensor with the elements of a, but where values < a_min are replaced with a_min, and those > a_max with a_max. """ - return Tensor(clip(a.data, unwrap_tensor_data(a_min), unwrap_tensor_data(a_max))) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def as_tensor_like(a: Tensor, data: Any) -> Tensor: """ Converts the data into a tensor with the same data representation and hosted on the same device @@ -487,11 +381,9 @@ def as_tensor_like(a: Tensor, data: Any) -> Tensor: :return: A tensor with the same data representation and hosted on the same device as a, and which has been initialized with data. """ - return Tensor(as_tensor_like(a.data, data)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def item(a: Tensor) -> Union[int, float, bool]: """ Returns the value of this tensor as a standard Python number. This only works for tensors with one element. @@ -499,13 +391,9 @@ def item(a: Tensor) -> Union[int, float, bool]: :param a: Tensor. :return: The value of this tensor as a standard Python number """ - if isinstance(a.data, (int, float, bool)): - return a.data - return item(a.data) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def sum(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: """ Sum of tensor elements over a given axis. @@ -517,11 +405,9 @@ def sum(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: with size one. :return: Returns the sum of all elements in the input tensor in the given axis. """ - return Tensor(sum(a.data, axis, keepdims)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def multiply(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Multiply arguments element-wise. @@ -530,11 +416,9 @@ def multiply(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: :param x2: The second input tensor or number. :return: The product of x1 and x2, element-wise. """ - return Tensor(multiply(x1.data, unwrap_tensor_data(x2))) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def var(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ddof: int = 0) -> Tensor: """ Compute the variance along the specified axis. @@ -548,11 +432,9 @@ def var(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: By default ddof is zero. :return: A new tensor containing the variance. """ - return Tensor(var(a.data, axis, keepdims, ddof)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def size(a: Tensor) -> int: """ Return number of elements in the tensor. @@ -560,11 +442,9 @@ def size(a: Tensor) -> int: :param a: The input tensor :return: The size of the input tensor. 
""" - return size(a.data) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def matmul(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Matrix multiplication. @@ -573,11 +453,9 @@ def matmul(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: :param x2: The second input tensor or number. :return: The product of x1 and x2, matmul. """ - return Tensor(matmul(x1.data, unwrap_tensor_data(x2))) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def unsqueeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Add axes of length one to a. @@ -586,11 +464,9 @@ def unsqueeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> :param axis: Selects a subset of the entries of length one in the shape. :return: The input array, but with expanded shape with len 1 defined in axis. """ - return Tensor(unsqueeze(a.data, axis=axis)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def transpose(a: Tensor, axes: Optional[Tuple[int, ...]] = None) -> Tensor: """ Returns an array with axes transposed. @@ -599,11 +475,9 @@ def transpose(a: Tensor, axes: Optional[Tuple[int, ...]] = None) -> Tensor: :param axes: list of permutations or None. :return: array with permuted axes. """ - return Tensor(transpose(a.data, axes=axes)) -@functools.singledispatch -@tensor_guard +@tensor_dispatch def argsort(a: Tensor, axis: int = -1, descending: bool = False, stable: bool = False) -> Tensor: """ Returns the indices that would sort an array. @@ -615,4 +489,3 @@ def argsort(a: Tensor, axis: int = -1, descending: bool = False, stable: bool = If False, the relative order of values which compare equal is not guaranteed. True is slower. :return: Array of indices that sort a along the specified axis. """ - return Tensor(argsort(a.data, axis=axis)) diff --git a/nncf/experimental/tensor/functions/numpy_linalg.py b/nncf/experimental/tensor/functions/numpy_linalg.py index 6821d92ec8e..33394afb766 100644 --- a/nncf/experimental/tensor/functions/numpy_linalg.py +++ b/nncf/experimental/tensor/functions/numpy_linalg.py @@ -14,10 +14,9 @@ import numpy as np from nncf.experimental.tensor.functions import linalg -from nncf.experimental.tensor.functions.dispatcher import register_numpy_types -@register_numpy_types(linalg.norm) +@linalg.norm.register def _( a: Union[np.ndarray, np.generic], ord: Optional[Union[str, float, int]] = None, diff --git a/nncf/experimental/tensor/functions/numpy_numeric.py b/nncf/experimental/tensor/functions/numpy_numeric.py index 3aef3df1e5f..ad40d513a39 100644 --- a/nncf/experimental/tensor/functions/numpy_numeric.py +++ b/nncf/experimental/tensor/functions/numpy_numeric.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -17,7 +17,6 @@ from nncf.experimental.tensor.definitions import TensorDeviceType from nncf.experimental.tensor.definitions import TypeInfo from nncf.experimental.tensor.functions import numeric as numeric -from nncf.experimental.tensor.functions.dispatcher import register_numpy_types DTYPE_MAP = { TensorDataType.float16: np.dtype(np.float16), @@ -32,63 +31,63 @@ DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} -@register_numpy_types(numeric.device) +@numeric.device.register def _(a: Union[np.ndarray, np.generic]) -> TensorDeviceType: return TensorDeviceType.CPU -@register_numpy_types(numeric.squeeze) +@numeric.squeeze.register def _( a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None ) -> Union[np.ndarray, np.generic]: return np.squeeze(a, axis=axis) -@register_numpy_types(numeric.flatten) +@numeric.flatten.register def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return a.flatten() -@register_numpy_types(numeric.max) +@numeric.max.register def _( a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False ) -> np.ndarray: return np.array(np.max(a, axis=axis, keepdims=keepdims)) -@register_numpy_types(numeric.min) +@numeric.min.register def _( a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False ) -> Union[np.ndarray, np.generic]: return np.array(np.min(a, axis=axis, keepdims=keepdims)) -@register_numpy_types(numeric.abs) +@numeric.abs.register def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]: return np.absolute(a) -@register_numpy_types(numeric.astype) +@numeric.astype.register def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> Union[np.ndarray, np.generic]: return a.astype(DTYPE_MAP[dtype]) -@register_numpy_types(numeric.dtype) +@numeric.dtype.register def _(a: Union[np.ndarray, np.generic]) -> TensorDataType: return DTYPE_MAP_REV[np.dtype(a.dtype)] -@register_numpy_types(numeric.reshape) +@numeric.reshape.register def _(a: Union[np.ndarray, np.generic], shape: Union[int, Tuple[int, ...]]) -> np.ndarray: return a.reshape(shape) -@register_numpy_types(numeric.all) +@numeric.all.register def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: return np.array(np.all(a, axis=axis)) -@register_numpy_types(numeric.allclose) +@numeric.allclose.register def _( a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], @@ -99,22 +98,22 @@ def _( return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@register_numpy_types(numeric.any) +@numeric.any.register def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: return np.array(np.any(a, axis=axis)) -@register_numpy_types(numeric.count_nonzero) +@numeric.count_nonzero.register def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: return np.array(np.count_nonzero(a, axis=axis)) -@register_numpy_types(numeric.isempty) +@numeric.isempty.register def _(a: Union[np.ndarray, np.generic]) -> bool: return a.size == 0 -@register_numpy_types(numeric.isclose) +@numeric.isclose.register def _( a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], @@ -125,22 +124,22 @@ def _( return np.isclose(a, b, rtol=rtol, atol=atol, 
equal_nan=equal_nan) -@register_numpy_types(numeric.maximum) +@numeric.maximum.register def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.maximum(x1, x2) -@register_numpy_types(numeric.minimum) +@numeric.minimum.register def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.minimum(x1, x2) -@register_numpy_types(numeric.ones_like) +@numeric.ones_like.register def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.ones_like(a) -@register_numpy_types(numeric.where) +@numeric.where.register def _( condition: Union[np.ndarray, np.generic], x: Union[np.ndarray, np.generic, float], @@ -149,42 +148,42 @@ def _( return np.where(condition, x, y) -@register_numpy_types(numeric.zeros_like) +@numeric.zeros_like.register def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.zeros_like(a) -@register_numpy_types(numeric.stack) +@numeric.stack.register def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: return np.stack(x, axis=axis) -@register_numpy_types(numeric.unstack) +@numeric.unstack.register def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: return [np.squeeze(e, axis) for e in np.split(x, x.shape[axis], axis=axis)] -@register_numpy_types(numeric.moveaxis) +@numeric.moveaxis.register def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> np.ndarray: return np.moveaxis(a, source, destination) -@register_numpy_types(numeric.mean) +@numeric.mean.register def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray: return np.array(np.mean(a, axis=axis, keepdims=keepdims)) -@register_numpy_types(numeric.round) +@numeric.round.register def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray: return np.round(a, decimals=decimals) -@register_numpy_types(numeric.power) +@numeric.power.register def _(a: Union[np.ndarray, np.generic], exponent: Union[np.ndarray, float]) -> Union[np.ndarray, np.generic]: return np.power(a, exponent) -@register_numpy_types(numeric.quantile) +@numeric.quantile.register def _( a: Union[np.ndarray, np.generic], q: Union[float, List[float]], @@ -194,31 +193,13 @@ def _( return np.array(np.quantile(a, q=q, axis=axis, keepdims=keepdims)) -@register_numpy_types(numeric._binary_op_nowarn) -def _( - a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable -) -> Union[np.ndarray, np.generic]: - # Run operator with disabled warning - with np.errstate(invalid="ignore", divide="ignore"): - return operator_fn(a, b) - - -@register_numpy_types(numeric._binary_reverse_op_nowarn) -def _( - a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable -) -> Union[np.ndarray, np.generic]: - # Run operator with disabled warning - with np.errstate(invalid="ignore", divide="ignore"): - return operator_fn(b, a) - - -@register_numpy_types(numeric.finfo) -def _(a: np.ndarray) -> TypeInfo: +@numeric.finfo.register +def _(a: Union[np.ndarray, np.generic]) -> TypeInfo: ti = np.finfo(a.dtype) return TypeInfo(ti.eps, ti.max, ti.min) -@register_numpy_types(numeric.clip) +@numeric.clip.register def _( a: Union[np.ndarray, np.generic], a_min: Union[np.ndarray, np.generic, float], @@ -227,29 +208,29 @@ def _( return np.clip(a, a_min, a_max) -@register_numpy_types(numeric.as_tensor_like) +@numeric.as_tensor_like.register def _(a: 
Union[np.ndarray, np.generic], data: Any) -> Union[np.ndarray, np.generic]: return np.array(data) -@register_numpy_types(numeric.item) +@numeric.item.register def _(a: Union[np.ndarray, np.generic]) -> Union[int, float, bool]: return a.item() -@register_numpy_types(numeric.sum) +@numeric.sum.register def _( a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False ) -> np.ndarray: return np.array(np.sum(a, axis=axis, keepdims=keepdims)) -@register_numpy_types(numeric.multiply) +@numeric.multiply.register def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.multiply(x1, x2) -@register_numpy_types(numeric.var) +@numeric.var.register def _( a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None, @@ -259,29 +240,29 @@ def _( return np.array(np.var(a, axis=axis, keepdims=keepdims, ddof=ddof)) -@register_numpy_types(numeric.size) +@numeric.size.register def _(a: Union[np.ndarray, np.generic]) -> int: return a.size -@register_numpy_types(numeric.matmul) +@numeric.matmul.register def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.matmul(x1, x2) -@register_numpy_types(numeric.unsqueeze) +@numeric.unsqueeze.register def _( a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None ) -> Union[np.ndarray, np.generic]: return np.expand_dims(a, axis=axis) -@register_numpy_types(numeric.transpose) +@numeric.transpose.register def _(a: Union[np.ndarray, np.generic], axes: Optional[Tuple[int, ...]] = None) -> Union[np.ndarray, np.generic]: return np.transpose(a, axes=axes) -@register_numpy_types(numeric.argsort) +@numeric.argsort.register def _( a: Union[np.ndarray, np.generic], axis: Optional[int] = None, descending=False, stable=False ) -> Union[np.ndarray, np.generic]: diff --git a/nncf/experimental/tensor/functions/torch_linalg.py b/nncf/experimental/tensor/functions/torch_linalg.py index 60afceeb352..da3fc46d45c 100644 --- a/nncf/experimental/tensor/functions/torch_linalg.py +++ b/nncf/experimental/tensor/functions/torch_linalg.py @@ -15,7 +15,7 @@ from nncf.experimental.tensor.functions import linalg -@linalg.norm.register(torch.Tensor) +@linalg.norm.register def _( a: torch.Tensor, ord: Optional[Union[str, float, int]] = None, diff --git a/nncf/experimental/tensor/functions/torch_numeric.py b/nncf/experimental/tensor/functions/torch_numeric.py index 781e1ce49e8..7732cda5101 100644 --- a/nncf/experimental/tensor/functions/torch_numeric.py +++ b/nncf/experimental/tensor/functions/torch_numeric.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
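# The registrations below drop the explicit `.register(torch.Tensor)` argument;
# the target type is now read from the `torch.Tensor` hint on the first
# parameter. Note that `max`/`min` also rename `keepdim` to `keepdims`: the new
# `register` asserts that argument names match the dispatch function exactly,
# so the previous torch-style `keepdim` spelling would fail registration.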
-from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np import torch @@ -32,7 +32,7 @@ DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} -@numeric.device.register(torch.Tensor) +@numeric.device.register def _(a: torch.Tensor) -> TensorDeviceType: DEVICE_MAP = { "cpu": TensorDeviceType.CPU, @@ -41,7 +41,7 @@ def _(a: torch.Tensor) -> TensorDeviceType: return DEVICE_MAP[a.device.type] -@numeric.squeeze.register(torch.Tensor) +@numeric.squeeze.register def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: if axis is None: return a.squeeze() @@ -51,55 +51,55 @@ def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> to return a.squeeze(axis) -@numeric.flatten.register(torch.Tensor) +@numeric.flatten.register def _(a: torch.Tensor) -> torch.Tensor: return a.flatten() -@numeric.max.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> torch.Tensor: +@numeric.max.register +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> torch.Tensor: # Analog of numpy.max is torch.amax if axis is None: return torch.amax(a) - return torch.amax(a, dim=axis, keepdim=keepdim) + return torch.amax(a, dim=axis, keepdim=keepdims) -@numeric.min.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False) -> torch.Tensor: +@numeric.min.register +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> torch.Tensor: # Analog of numpy.min is torch.amin if axis is None: return torch.amin(a) - return torch.amin(a, dim=axis, keepdim=keepdim) + return torch.amin(a, dim=axis, keepdim=keepdims) -@numeric.abs.register(torch.Tensor) +@numeric.abs.register def _(a: torch.Tensor) -> torch.Tensor: return torch.absolute(a) -@numeric.astype.register(torch.Tensor) +@numeric.astype.register def _(a: torch.Tensor, dtype: TensorDataType) -> torch.Tensor: return a.type(DTYPE_MAP[dtype]) -@numeric.dtype.register(torch.Tensor) +@numeric.dtype.register def _(a: torch.Tensor) -> TensorDataType: return DTYPE_MAP_REV[a.dtype] -@numeric.reshape.register(torch.Tensor) +@numeric.reshape.register def _(a: torch.Tensor, shape: Tuple[int, ...]) -> torch.Tensor: return a.reshape(shape) -@numeric.all.register(torch.Tensor) +@numeric.all.register def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: if axis is None: return torch.all(a) return torch.all(a, dim=axis) -@numeric.allclose.register(torch.Tensor) +@numeric.allclose.register def _( a: torch.Tensor, b: Union[torch.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> bool: @@ -108,24 +108,24 @@ def _( return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@numeric.any.register(torch.Tensor) +@numeric.any.register def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: if axis is None: return torch.any(a) return torch.any(a, dim=axis) -@numeric.count_nonzero.register(torch.Tensor) +@numeric.count_nonzero.register def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: return torch.count_nonzero(a, dim=axis) -@numeric.isempty.register(torch.Tensor) +@numeric.isempty.register def _(a: torch.Tensor) -> bool: return a.numel() == 0 -@numeric.isclose.register(torch.Tensor) 
+@numeric.isclose.register def _( a: torch.Tensor, b: Union[torch.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False ) -> torch.Tensor: @@ -134,70 +134,70 @@ def _( return torch.isclose(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan) -@numeric.maximum.register(torch.Tensor) +@numeric.maximum.register def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: if not isinstance(x2, torch.Tensor): x2 = torch.tensor(x2, device=x1.data.device) return torch.maximum(x1, x2) -@numeric.minimum.register(torch.Tensor) +@numeric.minimum.register def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: if not isinstance(x2, torch.Tensor): x2 = torch.tensor(x2, device=x1.data.device) return torch.minimum(x1, x2) -@numeric.ones_like.register(torch.Tensor) +@numeric.ones_like.register def _(a: torch.Tensor) -> torch.Tensor: return torch.ones_like(a) -@numeric.where.register(torch.Tensor) +@numeric.where.register def _( condition: torch.Tensor, x: Union[torch.Tensor, float, bool], y: Union[torch.Tensor, float, bool] ) -> torch.Tensor: return torch.where(condition, x, y) -@numeric.zeros_like.register(torch.Tensor) +@numeric.zeros_like.register def _(a: torch.Tensor) -> torch.Tensor: return torch.zeros_like(a) -@numeric.stack.register(torch.Tensor) +@numeric.stack.register def _(x: List[torch.Tensor], axis: int = 0) -> List[torch.Tensor]: return torch.stack(x, dim=axis) -@numeric.unstack.register(torch.Tensor) +@numeric.unstack.register def _(x: torch.Tensor, axis: int = 0) -> List[torch.Tensor]: if not list(x.shape): x = x.unsqueeze(0) return torch.unbind(x, dim=axis) -@numeric.moveaxis.register(torch.Tensor) +@numeric.moveaxis.register def _(a: torch.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> torch.Tensor: return torch.moveaxis(a, source, destination) -@numeric.mean.register(torch.Tensor) +@numeric.mean.register def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> torch.Tensor: return torch.mean(a, dim=axis, keepdim=keepdims) -@numeric.round.register(torch.Tensor) +@numeric.round.register def _(a: torch.Tensor, decimals=0) -> torch.Tensor: return torch.round(a, decimals=decimals) -@numeric.power.register(torch.Tensor) +@numeric.power.register def _(a: torch.Tensor, exponent: Union[torch.Tensor, float]) -> torch.Tensor: return torch.pow(a, exponent=exponent) -@numeric.quantile.register(torch.Tensor) +@numeric.quantile.register def _( a: torch.Tensor, q: Union[float, List[float]], @@ -217,74 +217,64 @@ def _( return torch.tensor(np.quantile(a.detach().cpu().numpy(), q=q, axis=axis, keepdims=keepdims)).to(device) -@numeric._binary_op_nowarn.register(torch.Tensor) -def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor: - return operator_fn(a, b) - - -@numeric._binary_reverse_op_nowarn.register(torch.Tensor) -def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor: - return operator_fn(b, a) - - -@numeric.clip.register(torch.Tensor) +@numeric.clip.register def _(a: torch.Tensor, a_min: Union[torch.Tensor, float], a_max: Union[torch.Tensor, float]) -> torch.Tensor: return torch.clip(a, a_min, a_max) -@numeric.finfo.register(torch.Tensor) +@numeric.finfo.register def _(a: torch.Tensor) -> TypeInfo: ti = torch.finfo(a.dtype) return TypeInfo(ti.eps, ti.max, ti.min) -@numeric.as_tensor_like.register(torch.Tensor) +@numeric.as_tensor_like.register def _(a: torch.Tensor, data: Any) -> torch.Tensor: 
return torch.as_tensor(data, device=a.device) -@numeric.item.register(torch.Tensor) +@numeric.item.register def _(a: torch.Tensor) -> Union[int, float, bool]: return a.item() -@numeric.sum.register(torch.Tensor) +@numeric.sum.register def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> torch.Tensor: return torch.sum(a, dim=axis, keepdim=keepdims) -@numeric.multiply.register(torch.Tensor) +@numeric.multiply.register def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: return torch.multiply(x1, x2) -@numeric.var.register(torch.Tensor) +@numeric.var.register def _( a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ddof: int = 0 ) -> torch.Tensor: return torch.var(a, dim=axis, keepdim=keepdims, correction=ddof) -@numeric.size.register(torch.Tensor) +@numeric.size.register def _(a: torch.Tensor) -> int: return torch.numel(a) -@numeric.matmul.register(torch.Tensor) +@numeric.matmul.register def _(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return torch.matmul(x1, x2) -@numeric.unsqueeze.register(torch.Tensor) +@numeric.unsqueeze.register def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: return torch.unsqueeze(a, dim=axis) -@numeric.transpose.register(torch.Tensor) +@numeric.transpose.register def _(a: torch.Tensor, axes: Optional[Tuple[int, ...]] = None) -> torch.Tensor: return a.t() -@numeric.argsort.register(torch.Tensor) +@numeric.argsort.register def _(a: torch.Tensor, axis: Optional[int] = None, descending=False, stable=False) -> torch.Tensor: return torch.argsort(a, dim=axis, descending=descending, stable=stable) diff --git a/nncf/experimental/tensor/tensor.py b/nncf/experimental/tensor/tensor.py index c316b5d267f..b012602d7a9 100644 --- a/nncf/experimental/tensor/tensor.py +++ b/nncf/experimental/tensor/tensor.py @@ -10,7 +10,6 @@ # limitations under the License. 
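# With `_binary_op_nowarn` and `_binary_reverse_op_nowarn` removed, the
# division operators below apply `/` and `//` to the underlying data directly.
# One behavioural consequence: the deleted NumPy implementations ran the
# operator inside `np.errstate(invalid="ignore", divide="ignore")`, so after
# this patch an expression such as `Tensor(np.array(1.0)) / 0.0` follows plain
# NumPy warning semantics instead of being silenced.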
from __future__ import annotations -import operator from typing import Any, Optional, Tuple, TypeVar, Union from nncf.experimental.tensor.definitions import TensorDataType @@ -90,16 +89,16 @@ def __pow__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data ** unwrap_tensor_data(other)) def __truediv__(self, other: Union[Tensor, float]) -> Tensor: - return _call_function("_binary_op_nowarn", self, other, operator.truediv) + return Tensor(self.data / unwrap_tensor_data(other)) def __rtruediv__(self, other: Union[Tensor, float]) -> Tensor: - return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv) + return Tensor(unwrap_tensor_data(other) / self.data) def __floordiv__(self, other: Union[Tensor, float]) -> Tensor: - return _call_function("_binary_op_nowarn", self, other, operator.floordiv) + return Tensor(self.data // unwrap_tensor_data(other)) def __rfloordiv__(self, other: Union[Tensor, float]) -> Tensor: - return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv) + return Tensor(unwrap_tensor_data(other) // self.data) def __neg__(self) -> Tensor: return Tensor(-self.data) From df3b0ab5d9eb648c2eefce3651cec4e6b421987c Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Fri, 26 Apr 2024 23:41:38 +0300 Subject: [PATCH 2/5] sync argsort --- .../tensor/functions/dispatcher.py | 3 +-- nncf/experimental/tensor/functions/numeric.py | 5 +--- .../tensor/functions/numpy_numeric.py | 4 +--- .../tensor/functions/torch_numeric.py | 6 +++-- .../template_test_nncf_tensor.py | 23 ++++++++++++++++--- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/nncf/experimental/tensor/functions/dispatcher.py b/nncf/experimental/tensor/functions/dispatcher.py index 402fd9068e5..ef4e72158e4 100644 --- a/nncf/experimental/tensor/functions/dispatcher.py +++ b/nncf/experimental/tensor/functions/dispatcher.py @@ -45,8 +45,7 @@ def dispatch(cls): Retrieves the registered function for a given type. :param cls: The type to retrieve the function for. - :return: The registered function for the given type, or a function that raises a NotImplementedError - if no function is registered for type. + :return: The registered function for the given type. """ try: return registry[cls] diff --git a/nncf/experimental/tensor/functions/numeric.py b/nncf/experimental/tensor/functions/numeric.py index 0f9c31c90a5..efec8c64130 100644 --- a/nncf/experimental/tensor/functions/numeric.py +++ b/nncf/experimental/tensor/functions/numeric.py @@ -481,14 +481,11 @@ def transpose(a: Tensor, axes: Optional[Tuple[int, ...]] = None) -> Tensor: @tensor_dispatch -def argsort(a: Tensor, axis: int = -1, descending: bool = False, stable: bool = False) -> Tensor: +def argsort(a: Tensor, axis: Optional[int] = -1) -> Tensor: """ Returns the indices that would sort an array. :param a: The input tensor. :param axis: Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used. - :param descending: Controls the sorting order (ascending or descending). - :param stable: If True then the sorting routine becomes stable, preserving the order of equivalent elements. - If False, the relative order of values which compare equal is not guaranteed. True is slower. :return: Array of indices that sort a along the specified axis. 
""" diff --git a/nncf/experimental/tensor/functions/numpy_numeric.py b/nncf/experimental/tensor/functions/numpy_numeric.py index c5c4ff6efc3..e14555160c2 100644 --- a/nncf/experimental/tensor/functions/numpy_numeric.py +++ b/nncf/experimental/tensor/functions/numpy_numeric.py @@ -269,7 +269,5 @@ def _(a: Union[np.ndarray, np.generic], axes: Optional[Tuple[int, ...]] = None) @numeric.argsort.register -def _( - a: Union[np.ndarray, np.generic], axis: Optional[int] = None, descending=False, stable=False -) -> Union[np.ndarray, np.generic]: +def _(a: Union[np.ndarray, np.generic], axis: Optional[int] = -1) -> Union[np.ndarray, np.generic]: return np.argsort(a, axis=axis) diff --git a/nncf/experimental/tensor/functions/torch_numeric.py b/nncf/experimental/tensor/functions/torch_numeric.py index ee482f24b3c..dca83961e9e 100644 --- a/nncf/experimental/tensor/functions/torch_numeric.py +++ b/nncf/experimental/tensor/functions/torch_numeric.py @@ -282,5 +282,7 @@ def _(a: torch.Tensor, axes: Optional[Tuple[int, ...]] = None) -> torch.Tensor: @numeric.argsort.register -def _(a: torch.Tensor, axis: Optional[int] = None, descending=False, stable=False) -> torch.Tensor: - return torch.argsort(a, dim=axis, descending=descending, stable=stable) +def _(a: torch.Tensor, axis: Optional[int] = -1) -> torch.Tensor: + if axis is None: + return torch.argsort(a.flatten()) + return torch.argsort(a, dim=axis) diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 382ea2cbef9..3d4278026e6 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -1024,23 +1024,40 @@ def test_fn_transpose(self, x, ref): assert res.device == tensor.device @pytest.mark.parametrize( - "x, ref", + "x, axis, ref", ( ( [1, 2, 3, 4, 5, 6], + -1, [0, 1, 2, 3, 4, 5], ), ( [6, 5, 4, 3, 2, 1], + -1, + [5, 4, 3, 2, 1, 0], + ), + ( + [[6, 5, 4], [3, 2, 1]], + None, [5, 4, 3, 2, 1, 0], ), + ( + [[6, 5, 4], [3, 2, 1]], + -1, + [[2, 1, 0], [2, 1, 0]], + ), + ( + [[6, 5, 4], [3, 2, 1]], + 0, + [[1, 1, 1], [0, 0, 0]], + ), ), ) - def test_fn_argsort(self, x, ref): + def test_fn_argsort(self, x, axis, ref): tensor = Tensor(self.to_tensor(x)) ref_tensor = self.to_tensor(ref) - res = fns.argsort(tensor) + res = fns.argsort(tensor, axis) assert isinstance(res, Tensor) assert fns.allclose(res.data, ref_tensor) From 3f1ddaecfc430ebd5d9bc2fa13ecfe320df87f5d Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Sat, 27 Apr 2024 04:15:06 +0300 Subject: [PATCH 3/5] f --- nncf/experimental/tensor/functions/dispatcher.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/nncf/experimental/tensor/functions/dispatcher.py b/nncf/experimental/tensor/functions/dispatcher.py index ef4e72158e4..ea548ab08da 100644 --- a/nncf/experimental/tensor/functions/dispatcher.py +++ b/nncf/experimental/tensor/functions/dispatcher.py @@ -13,7 +13,7 @@ from inspect import isclass from inspect import isfunction from types import MappingProxyType -from typing import List, _GenericAlias, _UnionGenericAlias, get_type_hints +from typing import List, get_type_hints from nncf.experimental.tensor import Tensor @@ -21,11 +21,10 @@ def _get_target_types(type_alias): if isclass(type_alias): return [type_alias] - if isinstance(type_alias, (_UnionGenericAlias, _GenericAlias)): - ret = [] - for t in type_alias.__args__: - ret.extend(_get_target_types(t)) - return ret + ret = [] + for t in 
type_alias.__args__: + ret.extend(_get_target_types(t)) + return ret def tensor_dispatch(func): @@ -130,11 +129,11 @@ def raise_not_implemented(*args, **kw): if first_type_hint is Tensor: if return_type_hint is Tensor: wrapper = wrapper_tensor_to_tensor - elif isinstance(return_type_hint, _GenericAlias) and not isinstance(return_type_hint, _UnionGenericAlias): + elif not isclass(return_type_hint) and return_type_hint._name == "List": wrapper = wrapper_tensor_to_list else: wrapper = wrapper_tensor_to_any - elif isinstance(first_type_hint, _GenericAlias) and return_type_hint is Tensor: + elif not isclass(first_type_hint) and first_type_hint._name == "List" and return_type_hint is Tensor: wrapper = wrapper_list_to_tensor assert wrapper is not None, ( From a6e07b7eeb04b954293047bfea30b5cd30e867ac Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Sat, 27 Apr 2024 23:21:59 +0300 Subject: [PATCH 4/5] unwrap kwargs --- nncf/experimental/tensor/functions/dispatcher.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nncf/experimental/tensor/functions/dispatcher.py b/nncf/experimental/tensor/functions/dispatcher.py index ea548ab08da..fc43149eab3 100644 --- a/nncf/experimental/tensor/functions/dispatcher.py +++ b/nncf/experimental/tensor/functions/dispatcher.py @@ -77,6 +77,7 @@ def wrapper_tensor_to_tensor(*args, **kw): """ is_wrapped = any(isinstance(x, Tensor) for x in args) args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()} ret = dispatch(args[0].__class__)(*args, **kw) return Tensor(ret) if is_wrapped else ret @@ -86,6 +87,7 @@ def wrapper_tensor_to_any(*args, **kw): This wrapper unwraps Tensor arguments but doesn't specifically wrap the returned value. """ args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()} return dispatch(args[0].__class__)(*args, **kw) def wrapper_tensor_to_list(*args, **kw): @@ -95,6 +97,7 @@ def wrapper_tensor_to_list(*args, **kw): """ is_wrapped = any(isinstance(x, Tensor) for x in args) args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()} ret = dispatch(args[0].__class__)(*args, **kw) if is_wrapped: return [Tensor(x) for x in ret] @@ -106,6 +109,8 @@ def wrapper_list_to_tensor(list_of_tensors: List[Tensor], *args, **kw): This wrapper handles lists containing Tensors appropriately. 
""" if any(isinstance(x, Tensor) for x in list_of_tensors): + args = tuple(x.data if isinstance(x, Tensor) else x for x in args) + kw = {k: v.data if isinstance(v, Tensor) else v for k, v in kw.items()} list_of_tensors = [x.data if isinstance(x, Tensor) else x for x in list_of_tensors] return Tensor(dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw)) return dispatch(list_of_tensors[0].__class__)(list_of_tensors, *args, **kw) From b05fcec8d70415b9350925dc57e6d06ada7512b1 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Wed, 15 May 2024 17:52:37 +0300 Subject: [PATCH 5/5] clean --- nncf/experimental/tensor/README.md | 8 ++++---- tests/shared/test_templates/template_test_nncf_tensor.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nncf/experimental/tensor/README.md b/nncf/experimental/tensor/README.md index a9d614b6f0c..7271df8cd68 100644 --- a/nncf/experimental/tensor/README.md +++ b/nncf/experimental/tensor/README.md @@ -122,10 +122,10 @@ tensor_a[0:2] # Tensor(array([[1],[2]])) **NOTE** Type of wrapper function selected by type hint of function, supported signatures of functions: ```python - def foo(a: Tensor, *args) -> Tensor: - def foo(a: Tensor, *args) -> Any: - def foo(a: Tensor, *args) -> List[Tensor]: - def foo(a: List[Tensor], *args) -> Tensor: + def foo(a: Tensor, ...) -> Tensor: + def foo(a: Tensor, ...) -> Any: + def foo(a: Tensor, ...) -> List[Tensor]: + def foo(a: List[Tensor], ...) -> Tensor: ``` 3. Add backend specific implementation of method to corresponding module: diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 90b8aad4d0f..4e681726d31 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -1063,9 +1063,9 @@ def test_fn_argsort(self, x, axis, descending, stable, ref): ref_tensor = self.to_tensor(ref) res = fns.argsort(tensor, axis, descending, stable) - print(res.data) + assert isinstance(res, Tensor) - assert fns.allclose(res.data, ref_tensor), f"{res.data} != {ref_tensor}" + assert fns.allclose(res.data, ref_tensor) assert res.device == tensor.device zero_ten_range = [x / 100 for x in range(1001)]