From 60c28043087141bb666c46152960f570894bfd6c Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 28 Sep 2023 14:37:05 +0200 Subject: [PATCH] Fix PTNNCFCollectorTensorProcessor for cuda + cuda tests for reducers/aggregators --- nncf/common/tensor_statistics/collectors.py | 21 ++++++---- nncf/onnx/statistics/collectors.py | 27 ++++++++---- nncf/openvino/statistics/collectors.py | 27 ++++++++---- .../tensor_statistics/collectors.py | 21 ++++++---- nncf/torch/tensor_statistics/collectors.py | 42 ++++++++++++------- .../ptq/test_reducers_and_aggregators.py | 30 ++++++++++--- 6 files changed, 115 insertions(+), 53 deletions(-) diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py index a3295f58632..ae9c2e536b2 100644 --- a/nncf/common/tensor_statistics/collectors.py +++ b/nncf/common/tensor_statistics/collectors.py @@ -116,7 +116,7 @@ class NNCFCollectorTensorProcessor(ABC): @staticmethod @abstractmethod - def reduce_min(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_min(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: """ Computes minimum of elements across dimensions of NNCFTensor. @@ -129,7 +129,7 @@ def reduce_min(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = Fa @staticmethod @abstractmethod - def reduce_max(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_max(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: """ Computes maximum of elements across dimensions of NNCFTensor. @@ -174,7 +174,7 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor: @staticmethod @abstractmethod - def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor: + def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: """ Computes the mean of elements across given dimensions of NNCFTensor. @@ -187,7 +187,7 @@ def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTe @staticmethod @abstractmethod - def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor: + def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: """ Computes the median of elements across given dimensions of NNCFTensor. @@ -200,7 +200,9 @@ def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCF @classmethod @abstractmethod - def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor: + def masked_mean( + cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False + ) -> NNCFTensor: """ Computes the masked mean of elements across given dimensions of NNCFTensor. @@ -216,7 +218,7 @@ def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTen @classmethod @abstractmethod def masked_median( - cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False + cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False ) -> NNCFTensor: """ Computes the masked median of elements across given dimensions of NNCFTensor. @@ -275,7 +277,10 @@ def sum(tensor: NNCFTensor) -> TensorElementsType: @staticmethod @abstractmethod def quantile( - tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False + tensor: NNCFTensor, + quantile: Union[float, List[float]], + axis: Union[int, Tuple[int, ...], List[int]], + keepdims: bool = False, ) -> List[TensorElementsType]: """ Compute the quantile(s) of the data along the specified axis. @@ -295,7 +300,7 @@ def percentile( cls, tensor: NNCFTensor, percentile: Union[float, List[float]], - axis: Union[int, tuple, list], + axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False, ) -> List[TensorElementsType]: """ diff --git a/nncf/onnx/statistics/collectors.py b/nncf/onnx/statistics/collectors.py index 7af2792f003..2ce98915edc 100644 --- a/nncf/onnx/statistics/collectors.py +++ b/nncf/onnx/statistics/collectors.py @@ -33,11 +33,11 @@ class ONNXNNCFCollectorTensorProcessor(NNCFCollectorTensorProcessor): """ @staticmethod - def reduce_min(x: NNCFTensor, axis: Union[int, Tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_min(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return ONNXNNCFTensor(np.amin(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def reduce_max(x: NNCFTensor, axis: Union[int, Tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_max(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return ONNXNNCFTensor(np.amax(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod @@ -53,16 +53,20 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor: return ONNXNNCFTensor(np.maximum(x1.tensor, x2.tensor)) @staticmethod - def mean(x: NNCFTensor, axis: Union[int, Tuple, list], keepdims=False) -> NNCFTensor: + def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: return ONNXNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def median(x: NNCFTensor, axis: Union[int, Tuple, list], keepdims=False) -> NNCFTensor: + def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: return ONNXNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims)) @classmethod def masked_mean( - cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False + cls, + x: NNCFTensor, + axis: Optional[Union[int, Tuple[int, ...], List[int]]], + mask: Optional[NNCFTensor], + keepdims: bool = False, ) -> NNCFTensor: if mask is None: return cls.mean(x, axis=axis, keepdims=keepdims) @@ -71,7 +75,11 @@ def masked_mean( @classmethod def masked_median( - cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False + cls, + x: NNCFTensor, + axis: Optional[Union[int, Tuple[int, ...], List[int]]], + mask: Optional[NNCFTensor], + keepdims: bool = False, ) -> NNCFTensor: if mask is None: return cls.median(x, axis=axis, keepdims=keepdims) @@ -105,7 +113,10 @@ def sum(tensor: NNCFTensor) -> TensorElementsType: @staticmethod def quantile( - tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, Tuple, list], keepdims: bool = False + tensor: NNCFTensor, + quantile: Union[float, List[float]], + axis: Union[int, Tuple[int, ...], List[int]], + keepdims: bool = False, ) -> List[TensorElementsType]: result = np.quantile(tensor.tensor, quantile, axis, keepdims=keepdims) return [ONNXNNCFTensor(x) for x in result] @@ -115,7 +126,7 @@ def percentile( cls, tensor: NNCFTensor, percentile: Union[float, List[float]], - axis: Union[int, Tuple, list], + axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False, ) -> List[TensorElementsType]: raise NotImplementedError() diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index b9253974b28..4672541d86b 100644 --- a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -50,11 +50,11 @@ class OVNNCFCollectorTensorProcessor(NNCFCollectorTensorProcessor): """ @staticmethod - def reduce_min(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = True) -> NNCFTensor: + def reduce_min(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = True) -> NNCFTensor: return OVNNCFTensor(np.amin(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def reduce_max(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = True) -> NNCFTensor: + def reduce_max(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = True) -> NNCFTensor: return OVNNCFTensor(np.amax(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod @@ -70,16 +70,20 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor: return OVNNCFTensor(np.maximum(x1.tensor, x2.tensor)) @staticmethod - def mean(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = False) -> NNCFTensor: + def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return OVNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def median(x: NNCFTensor, axis: Union[int, Tuple, list], keepdims: bool = False) -> NNCFTensor: + def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return OVNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims)) @classmethod def masked_mean( - cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False + cls, + x: NNCFTensor, + axis: Optional[Union[int, Tuple[int, ...], List[int]]], + mask: Optional[NNCFTensor], + keepdims: bool = False, ) -> NNCFTensor: if mask is None: return cls.mean(x, axis=axis, keepdims=keepdims) @@ -91,7 +95,11 @@ def masked_mean( @classmethod def masked_median( - cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False + cls, + x: NNCFTensor, + axis: Optional[Union[int, Tuple[int, ...], List[int]]], + mask: Optional[NNCFTensor], + keepdims: bool = False, ) -> NNCFTensor: if mask is None: return cls.median(x, axis=axis, keepdims=keepdims) @@ -140,7 +148,10 @@ def sum(tensor: NNCFTensor) -> TensorElementsType: @staticmethod def quantile( - tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, Tuple, list], keepdims: bool = False + tensor: NNCFTensor, + quantile: Union[float, List[float]], + axis: Union[int, Tuple[int, ...], List[int]], + keepdims: bool = False, ) -> List[NNCFTensor]: result = np.quantile(tensor.tensor, quantile, axis, keepdims=keepdims) return [OVNNCFTensor(x) for x in result] @@ -150,7 +161,7 @@ def percentile( cls, tensor: NNCFTensor, percentile: Union[float, List[float]], - axis: Union[int, tuple, list], + axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False, ) -> List[TensorElementsType]: quantile = np.true_divide(percentile, 100) diff --git a/nncf/tensorflow/tensor_statistics/collectors.py b/nncf/tensorflow/tensor_statistics/collectors.py index d3dd952e9cf..a8c3da70a8b 100644 --- a/nncf/tensorflow/tensor_statistics/collectors.py +++ b/nncf/tensorflow/tensor_statistics/collectors.py @@ -37,11 +37,11 @@ class TFNNCFCollectorTensorProcessor(NNCFCollectorTensorProcessor): """ @staticmethod - def reduce_min(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_min(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return TFNNCFTensor(tf.reduce_min(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def reduce_max(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_max(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return TFNNCFTensor(tf.reduce_max(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod @@ -57,20 +57,22 @@ def max(x1: tf.Tensor, x2: tf.Tensor) -> NNCFTensor: return TFNNCFTensor(tf.math.maximum(x1.tensor, x2.tensor)) @staticmethod - def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor: + def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: return TFNNCFTensor(tf.math.reduce_mean(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor: + def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: raise NotImplementedError() @classmethod - def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor: + def masked_mean( + cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False + ) -> NNCFTensor: raise NotImplementedError() @classmethod def masked_median( - cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False + cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False ) -> NNCFTensor: raise NotImplementedError() @@ -105,7 +107,10 @@ def sum(tensor: NNCFTensor) -> TensorElementsType: @staticmethod def quantile( - tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False + tensor: NNCFTensor, + quantile: Union[float, List[float]], + axis: Union[int, Tuple[int, ...], List[int]], + keepdims: bool = False, ) -> List[NNCFTensor]: raise NotImplementedError() @@ -114,7 +119,7 @@ def percentile( cls, tensor: NNCFTensor, percentile: Union[float, List[float]], - axis: Union[int, tuple, list], + axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False, ) -> List[TensorElementsType]: raise NotImplementedError() diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index f1934f63b58..4089fb77f2e 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -50,11 +50,11 @@ class PTNNCFCollectorTensorProcessor(NNCFCollectorTensorProcessor): """ @staticmethod - def reduce_min(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_min(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return PTNNCFTensor(torch.amin(x.tensor, dim=axis, keepdim=keepdims)) @staticmethod - def reduce_max(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor: + def reduce_max(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor: return PTNNCFTensor(torch.amax(x.tensor, dim=axis, keepdim=keepdims)) @staticmethod @@ -72,38 +72,44 @@ def max(cls, *args) -> NNCFTensor: return cls.reduce_max(stacked, axis=0, keepdims=False) @staticmethod - def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor: + def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: return PTNNCFTensor(x.tensor.mean(dim=axis, keepdim=keepdims)) @staticmethod - def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor: + def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor: # See https://github.com/pytorch/pytorch/issues/61582 if not isinstance(axis, int): - return PTNNCFTensor(torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims))) + device = x.tensor.device + result = torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims)) + return PTNNCFTensor(result.type(x.tensor.dtype).to(device)) return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims).values) @classmethod - def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple], mask: NNCFTensor, keepdims=False) -> NNCFTensor: + def masked_mean( + cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False + ) -> NNCFTensor: if mask is None: return cls.mean(x, axis=axis, keepdims=keepdims) - masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor) + device = x.tensor.device + masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy()) result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype) if isinstance(result, np.ma.MaskedArray): - return PTNNCFTensor(torch.tensor(result.data)) - return PTNNCFTensor(torch.tensor(result)) + result = result.data + return PTNNCFTensor(torch.tensor(result).to(device=device)) @classmethod def masked_median( - cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False + cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False ) -> NNCFTensor: # Implemented in numy as torch.masked.median is not implemented yet if mask is None: return cls.median(x, axis=axis, keepdims=keepdims) + device = x.tensor.device masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy()) result = np.ma.median(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype) if isinstance(result, np.ma.MaskedArray): - return PTNNCFTensor(torch.tensor(result.data)) - return PTNNCFTensor(torch.tensor(result)) + result = result.data + return PTNNCFTensor(torch.tensor(result).to(device=device)) @staticmethod def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: @@ -148,8 +154,12 @@ def sum(tensor: NNCFTensor) -> TensorElementsType: @staticmethod def quantile( - tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False + tensor: NNCFTensor, + quantile: Union[float, List[float], np.ndarray], + axis: Union[int, Tuple[int, ...], List[int]], + keepdims: bool = False, ) -> List[NNCFTensor]: + device = tensor.device # See https://github.com/pytorch/pytorch/issues/61582 if not isinstance(axis, int): result = torch.tensor( @@ -157,15 +167,15 @@ def quantile( ) else: result = torch.quantile(tensor.tensor, torch.tensor(quantile).type(tensor.tensor.dtype), axis, keepdims) - result = result.type(tensor.tensor.dtype) + result = result.type(tensor.tensor.dtype).to(device) return [PTNNCFTensor(x) for x in result] @classmethod def percentile( cls, tensor: NNCFTensor, - percentile: Union[float, List[float]], - axis: Union[int, tuple, list], + percentile: Union[float, List[float], np.ndarray], + axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False, ) -> List[TensorElementsType]: quantile = np.true_divide(percentile, 100) diff --git a/tests/torch/ptq/test_reducers_and_aggregators.py b/tests/torch/ptq/test_reducers_and_aggregators.py index 8bf86713700..5f63a9da861 100644 --- a/tests/torch/ptq/test_reducers_and_aggregators.py +++ b/tests/torch/ptq/test_reducers_and_aggregators.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC from typing import Any, List, Optional, Tuple import numpy as np @@ -30,18 +31,18 @@ from tests.common.experimental.test_reducers_and_aggregators import TemplateTestReducersAggreagtors -class TestReducersAggregators(TemplateTestReducersAggreagtors): +class BaseTestReducersAggregators(TemplateTestReducersAggreagtors, ABC): @pytest.fixture def tensor_processor(self): return PTNNCFCollectorTensorProcessor - def get_nncf_tensor(self, x: np.ndarray, dtype: Optional[Dtype] = None): + def _get_torch_tensor(self, x: np.ndarray, dtype: Optional[Dtype] = None): torch_tensor = torch.tensor(x) if dtype == Dtype.FLOAT: torch_tensor = torch_tensor.float() elif dtype == Dtype.INTEGER: torch_tensor = torch_tensor.int() - return PTNNCFTensor(torch_tensor) + return torch_tensor @pytest.fixture(scope="module") def reducers(self): @@ -58,8 +59,8 @@ def reducers(self): } def all_close(self, val, ref) -> bool: - val_ = torch.tensor(val) - ref_ = torch.tensor(ref) + val_ = val + ref_ = torch.tensor(ref).to(val_.device) return torch.allclose(val_, ref_) and val_.shape == ref_.shape def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = None): @@ -81,3 +82,22 @@ def expand_dims(self, tensor, dims: Tuple[int, ...]): for dim in dims: shape.insert(dim, 1) return tensor_.view(shape) + + +class TestCPUReducersAggregators(BaseTestReducersAggregators): + def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None): + return PTNNCFTensor(self._get_torch_tensor(x, dtype=dtype).cpu()) + + def all_close(self, val: torch.Tensor, ref) -> bool: + assert not val.is_cuda + return super().all_close(val, ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available in current environment") +class TestCudaReducersAggregators(BaseTestReducersAggregators): + def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None): + return PTNNCFTensor(self._get_torch_tensor(x, dtype=dtype).cuda()) + + def all_close(self, val: torch.Tensor, ref) -> bool: + assert val.is_cuda + return super().all_close(val, ref)