From 9da200c1fa7a7fe9b9b719e92d3926cddaa3306c Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 21 Aug 2023 16:38:11 +0200 Subject: [PATCH] Weights layout in conv/matmul layer attributes is introduced Refactor smooth quant to use weights layout Tests --- nncf/common/graph/layer_attributes.py | 24 +- nncf/openvino/graph/layer_attributes.py | 122 +++++--- .../algorithms/channel_alignment/algorithm.py | 39 ++- .../algorithms/channel_alignment/backend.py | 21 -- .../channel_alignment/openvino_backend.py | 47 --- .../algorithms/smooth_quant/algorithm.py | 12 +- .../algorithms/smooth_quant/backend.py | 11 - .../smooth_quant/openvino_backend.py | 17 - .../quantization/test_channel_alignment.py | 56 +--- .../openvino/native/test_layer_attributes.py | 295 ++++++++++++++---- tests/openvino/native/test_smooth_quant.py | 71 ++++- tests/post_training/test_templates/models.py | 5 +- .../test_templates/test_channel_alignment.py | 141 +++++++-- .../test_templates/test_smooth_quant.py | 6 +- 14 files changed, 545 insertions(+), 322 deletions(-) diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index ce934c23b8c..a29a1dca42a 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -13,7 +13,7 @@ from abc import abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Any, List, Tuple, Union +from typing import Any, List, Optional, Tuple, Union class Dtype(Enum): @@ -21,6 +21,13 @@ class Dtype(Enum): INTEGER = "int" +class LayoutElem(Enum): + C_IN = "channels_in" + C_OUT = "channels_out" + SPATIAL = "spatial" + GROUPS = "groups" + + class BaseLayerAttributes(ABC): """ This class stores base useful for some algorithms attributes @@ -30,6 +37,9 @@ class BaseLayerAttributes(ABC): def __eq__(self, __o: object) -> bool: return isinstance(__o, self.__class__) and self.__dict__ == __o.__dict__ + def get_backend_agnostic_attributes(self) -> "BaseLayerAttributes": + return self + class MultipleInputLayerAttributes(BaseLayerAttributes): def __init__(self, axis: int): @@ -109,7 +119,14 @@ def get_target_dim_for_compression(self) -> int: class LinearLayerAttributes(WeightedLayerAttributes): - def __init__(self, weight_requires_grad: bool, in_features: int, out_features: int, with_bias: bool = True): + def __init__( + self, + weight_requires_grad: bool, + in_features: int, + out_features: int, + with_bias: bool = True, + weights_layout: Optional[Tuple[LayoutElem, ...]] = None, + ): """ :param weight_requires_grad: Is True if gradients need to be computed for the corresponding Tensor, @@ -120,6 +137,7 @@ def __init__(self, weight_requires_grad: bool, in_features: int, out_features: i super().__init__(weight_requires_grad, with_bias=with_bias) self.in_features = in_features self.out_features = out_features + self.weights_layout = weights_layout def get_weight_shape(self) -> List[int]: return [self.out_features, self.in_features] @@ -144,6 +162,7 @@ def __init__( transpose: bool, padding_values: Tuple[int, ...], with_bias: bool = False, + weights_layout: Optional[Tuple[LayoutElem, ...]] = None, ): """ @@ -167,6 +186,7 @@ def __init__( self.groups = groups self.transpose = transpose self.padding_values = padding_values + self.weights_layout = weights_layout def get_weight_shape(self) -> List[int]: if not self.transpose: diff --git a/nncf/openvino/graph/layer_attributes.py b/nncf/openvino/graph/layer_attributes.py index 588ddd4cd0b..488dc67c02b 100644 --- a/nncf/openvino/graph/layer_attributes.py +++ b/nncf/openvino/graph/layer_attributes.py @@ -16,12 +16,15 @@ from nncf.common.graph.layer_attributes import BaseLayerAttributes from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes from nncf.common.graph.layer_attributes import GenericWeightedLayerAttributes +from nncf.common.graph.layer_attributes import LayoutElem +from nncf.common.graph.layer_attributes import LinearLayerAttributes from nncf.common.graph.layer_attributes import WeightedLayerAttributes from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionBackpropDataMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionBackpropDataMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionMetatype +from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVOpMetatype @@ -33,7 +36,7 @@ class OVLayerAttributes(BaseLayerAttributes): def __init__( self, constant_attributes: Dict[int, Any], - layer_attributes: Optional[Dict[int, BaseLayerAttributes]] = None, + layer_attributes: Optional[BaseLayerAttributes] = None, inputs_attributes: Optional[Dict[Any, Any]] = None, ): """ @@ -49,10 +52,6 @@ def __init__( def constant_attributes(self) -> Dict[int, Any]: return self._constant_attributes - @property - def layer_attributes(self) -> Optional[Dict[int, BaseLayerAttributes]]: - return self._layer_attributes - @property def input_attributes(self) -> Optional[Dict[Any, Any]]: return self._inputs_attributes @@ -67,6 +66,9 @@ def get_const_port_ids(self) -> List[int]: return list(self._constant_attributes.keys()) return [] + def get_backend_agnostic_attributes(self): + return self._layer_attributes + def get_weighted_layer_attributes( ov_node: ov.Node, ov_metatype: OVOpMetatype, constant_attributes: Dict[str, Any] @@ -79,51 +81,71 @@ def get_weighted_layer_attributes( :param constant_attributes: Constant attributes collected for the given node. :return: Weighted layer attributes for the given node. """ - retval = {} - for port_id, attrs in constant_attributes.items(): - if ov_metatype in [ - OVConvolutionMetatype, - OVDepthwiseConvolutionMetatype, - OVGroupConvolutionMetatype, - OVConvolutionBackpropDataMetatype, - OVGroupConvolutionBackpropDataMetatype, - ]: - node_attrs = ov_node.get_attributes() - kwargs = { - "weight_requires_grad": False, - "stride": tuple(node_attrs["strides"]), - "dilations": node_attrs["dilations"], - "transpose": ov_metatype in [OVConvolutionBackpropDataMetatype, OVGroupConvolutionBackpropDataMetatype], - # TODO: ticket 114378: unify pad attribute - "padding_values": tuple(node_attrs["pads_begin"] + node_attrs["pads_end"]), + if len(constant_attributes) != 1: + return None + + port_id, attrs = constant_attributes.copy().popitem() + if ov_metatype in [ + OVConvolutionMetatype, + OVDepthwiseConvolutionMetatype, + OVGroupConvolutionMetatype, + OVConvolutionBackpropDataMetatype, + OVGroupConvolutionBackpropDataMetatype, + ]: + node_attrs = ov_node.get_attributes() + kwargs = { + "weight_requires_grad": False, + "stride": tuple(node_attrs["strides"]), + "dilations": node_attrs["dilations"], + "transpose": ov_metatype in [OVConvolutionBackpropDataMetatype, OVGroupConvolutionBackpropDataMetatype], + # TODO: ticket 114378: unify pad attribute + "padding_values": tuple(node_attrs["pads_begin"] + node_attrs["pads_end"]), + } + + weights_layout_map = { + OVConvolutionMetatype: [LayoutElem.C_OUT, LayoutElem.C_IN], + OVGroupConvolutionMetatype: [LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN], + OVDepthwiseConvolutionMetatype: [LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN], + OVConvolutionBackpropDataMetatype: [LayoutElem.C_IN, LayoutElem.C_OUT], + OVGroupConvolutionBackpropDataMetatype: [LayoutElem.GROUPS, LayoutElem.C_IN, LayoutElem.C_OUT], + } + + weights_layout = weights_layout_map[ov_metatype] + weights_shape = attrs["shape"] + kwargs.update( + { + "in_channels": weights_shape[weights_layout.index(LayoutElem.C_IN)], + "out_channels": weights_shape[weights_layout.index(LayoutElem.C_OUT)], + "kernel_size": tuple(weights_shape[len(weights_layout) :]), + "groups": weights_shape[weights_layout.index(LayoutElem.GROUPS)] + if LayoutElem.GROUPS in weights_layout + else 1, } - - const_shape = attrs["shape"] - if ov_metatype in [OVConvolutionMetatype, OVConvolutionBackpropDataMetatype]: - kwargs.update( - { - "in_channels": const_shape[1], - "out_channels": const_shape[0], - "kernel_size": tuple(const_shape[2:]), - "groups": 1, - } - ) + ) + kwargs.update({"weights_layout": tuple(weights_layout + len(kwargs["kernel_size"]) * [LayoutElem.SPATIAL])}) + + return ConvolutionLayerAttributes(**kwargs) + if ov_metatype == OVMatMulMetatype: + weights_shape = attrs["shape"] + + weights_layout = [LayoutElem.SPATIAL] * (len(weights_shape) - 2) + if len(weights_shape) > 1: + transpose = attrs.get("transpose", False) + if (transpose and port_id == 0) or (not transpose and port_id == 1): + weights_layout += [LayoutElem.C_IN, LayoutElem.C_OUT] else: - kwargs.update( - { - "in_channels": const_shape[2], - "out_channels": const_shape[1], - "kernel_size": tuple(const_shape[3:]), - "groups": const_shape[0], - } - ) - if kwargs["transpose"]: - kwargs["in_channels"], kwargs["out_channels"] = kwargs["out_channels"], kwargs["in_channels"] - - common_layer_attr = ConvolutionLayerAttributes(**kwargs) + weights_layout += [LayoutElem.C_OUT, LayoutElem.C_IN] else: - common_layer_attr = GenericWeightedLayerAttributes( - weight_requires_grad=False, weight_shape=attrs.get("shape", None) - ) - retval[port_id] = common_layer_attr - return retval + weights_layout += [LayoutElem.C_IN] + + kwargs = { + "weight_requires_grad": False, + "in_features": weights_shape[weights_layout.index(LayoutElem.C_IN)], + "out_features": weights_shape[weights_layout.index(LayoutElem.C_OUT)] + if LayoutElem.C_OUT in weights_layout + else None, + "with_bias": False, + "weights_layout": weights_layout, + } + return LinearLayerAttributes(**kwargs) + return GenericWeightedLayerAttributes(weight_requires_grad=False, weight_shape=attrs.get("shape", None)) diff --git a/nncf/quantization/algorithms/channel_alignment/algorithm.py b/nncf/quantization/algorithms/channel_alignment/algorithm.py index dcec35d910b..6455024ca6a 100644 --- a/nncf/quantization/algorithms/channel_alignment/algorithm.py +++ b/nncf/quantization/algorithms/channel_alignment/algorithm.py @@ -19,14 +19,17 @@ from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.layer_attributes import LayoutElem from nncf.common.graph.patterns import GraphPattern from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.logging import nncf_logger from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend +from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend @@ -112,10 +115,23 @@ def filter_func(point: StatisticPoint) -> bool: assert len(tensor_collectors) == 1 stat = tensor_collectors[0].get_statistics() if stat.min_values is None or stat.max_values is None: + nncf_logger.debug( + f"Skipping channel alignment for pairs {conv_in.node_name}, {conv_out.node_name} " + "because statistics were not collected for this pair." + ) continue conv_in_cont = ConvParamsContainer(conv_in, model, graph, self._backend_entity) conv_out_cont = ConvParamsContainer(conv_out, model, graph, self._backend_entity) + if ( + conv_in_cont.dims.conv_weight_out_channels_dim is None + or conv_out_cont.dims.conv_weight_out_channels_dim is None + ): + nncf_logger.debug( + f"Skipping channel alignment for pairs {conv_in.node_name}, {conv_out.node_name} " + " because one of the node is 1D MatMul, 1D Matmuls are not supported by CA algortihm yet." + ) + continue amean = (stat.max_values + stat.min_values) * 0.5 conv_in_cont.bias, conv_out_cont.bias = self._align_means( @@ -247,7 +263,7 @@ def _align_scales( return updated_conv_in_value, updated_conv_out_value, updated_bias_in_value def _check_consumer_conv_node(self, conv_node: NNCFNode) -> bool: - attrs = self._backend_entity.get_conv_layer_attributes(conv_node) + attrs = conv_node.layer_attributes.get_backend_agnostic_attributes() if attrs is None: return False # Check groups amount == 1 @@ -373,9 +389,10 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin statistic_container = StatisticPointsContainer() for conv_in, add_in, _ in self._get_node_pairs(graph): target_point, node_in = self._get_target_point_and_node_in(conv_in, add_in) + channel_axis = conv_in.metatype.output_channel_axis - reduction_shape = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape))) - reduction_shape.remove(channel_axis) + activation_shape = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape))) + reduction_shape = get_channel_agnostic_reduction_shape([channel_axis], activation_shape) statistic_collector = self._backend_entity.get_statistic_collector( tuple(reduction_shape), self._quantile, self.subset_size, self.inplace_statistics @@ -447,7 +464,21 @@ def __init__( bias = backend_entity.create_bias_tensor(conv_op, nncf_graph, 0) self.stated_bias = StatedTensor(bias) self._op = conv_op - self._dims = backend_entity.get_dims_descriptor(conv_op) + weights_layout = conv_op.layer_attributes.get_backend_agnostic_attributes().weights_layout + if LayoutElem.GROUPS in weights_layout: + # Using groups dim as output channels dim for ChannelAlignment algorithm + # TODO(dlyakhov) support group convolutions with groups number not in [1, out_channels] + self._dims = LayoutDescriptor( + weights_layout.index(LayoutElem.GROUPS), + weights_layout.index(LayoutElem.C_IN), + conv_op.metatype.output_channel_axis, + ) + else: + self._dims = LayoutDescriptor( + weights_layout.index(LayoutElem.C_OUT) if LayoutElem.C_OUT in weights_layout else None, + weights_layout.index(LayoutElem.C_IN), + conv_op.metatype.output_channel_axis, + ) @property def weight(self): diff --git a/nncf/quantization/algorithms/channel_alignment/backend.py b/nncf/quantization/algorithms/channel_alignment/backend.py index f10ad307d05..986f0e54d79 100644 --- a/nncf/quantization/algorithms/channel_alignment/backend.py +++ b/nncf/quantization/algorithms/channel_alignment/backend.py @@ -122,27 +122,6 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: (bias is added to the output tensor of that operation), `False` otherwise. """ - @staticmethod - @abstractmethod - def get_dims_descriptor(node: NNCFNode) -> LayoutDescriptor: - """ - Return weights layout descriptor of the given node if it is possible and None otherwise. - Only convolutional and linear nodes are supported. - - :param node: NNCFNode to get layout descriptor from. - :return: Weights layout descriptor of the given node if it is possible and None otherwise. - """ - - @staticmethod - @abstractmethod - def get_conv_layer_attributes(node: NNCFNode) -> Optional[ConvolutionLayerAttributes]: - """ - Returns convolutional layer attributes of given node if they are present and None otherwise. - - :param node: NNCFNode to take convolutional layer attributes from. - :return: Convolutional layer attributes of given node if they are present and None otherwise - """ - @staticmethod @abstractmethod def create_bias_tensor(node: NNCFNode, nncf_graph: NNCFGraph, value: Any): diff --git a/nncf/quantization/algorithms/channel_alignment/openvino_backend.py b/nncf/quantization/algorithms/channel_alignment/openvino_backend.py index 0301d239b0c..46e84f80db7 100644 --- a/nncf/quantization/algorithms/channel_alignment/openvino_backend.py +++ b/nncf/quantization/algorithms/channel_alignment/openvino_backend.py @@ -101,53 +101,6 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: bias_constant = get_node_with_bias_value(add_node, nncf_graph) return bias_constant is not None - @staticmethod - def get_dims_descriptor(node: NNCFNode): - if node.metatype == OVConvolutionMetatype: - return LayoutDescriptor( - conv_weight_out_channels_dim=0, - conv_weight_in_channels_dim=1, - bias_channels_dim=node.metatype.output_channel_axis, - ) - if node.metatype in [OVGroupConvolutionMetatype, OVDepthwiseConvolutionMetatype]: - # Using groups dim as output channels dim for ChannelAlignment algorithm - # TODO(dlyakhov) support group convolutions with groups number not in [1, out_channels] - return LayoutDescriptor( - conv_weight_out_channels_dim=0, - conv_weight_in_channels_dim=2, - bias_channels_dim=node.metatype.output_channel_axis, - ) - if node.metatype == OVMatMulMetatype: - if node.layer_attributes is None: - raise RuntimeError(f"Attempt to align matmul node {node.node_name} that have no any constant inputs") - layer_attributes: OVLayerAttributes = node.layer_attributes - key = layer_attributes.get_const_port_ids() - assert len(key) == 1 - key = key[0] - const_attr = layer_attributes.constant_attributes[key] - a, b = list(range(len(const_attr["shape"])))[-2:] - assert key in [a, b] - if key == a: - out_ch_dim = a - in_ch_dim = b - else: - out_ch_dim = b - in_ch_dim = a - if const_attr.get("transpose", False): - out_ch_dim, in_ch_dim = in_ch_dim, out_ch_dim - return LayoutDescriptor( - conv_weight_in_channels_dim=in_ch_dim, - conv_weight_out_channels_dim=out_ch_dim, - bias_channels_dim=node.metatype.output_channel_axis, - ) - raise RuntimeError(f"Could not retrieve dims description for node {node} with metatype {node.metatype}") - - @staticmethod - def get_conv_layer_attributes(node: NNCFNode) -> Optional[ConvolutionLayerAttributes]: - if node.layer_attributes is None: - return None - return node.layer_attributes.layer_attributes[1] - @staticmethod def create_bias_tensor(node: NNCFNode, nncf_graph: NNCFGraph, value: Any): return create_bias_tensor(node, nncf_graph, value) diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index bd7fa7785c7..1858e3adae8 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -30,6 +30,7 @@ from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.layer_attributes import LayoutElem from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger @@ -319,7 +320,7 @@ def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode) -> TTens port_id = self._backend_entity.get_weight_tensor_port_id(node) weights_size = len(node.layer_attributes.constant_attributes[port_id]["shape"]) if weights_size > 1: - channel_axis = self._backend_entity.get_weight_channel_axis(node, port_id) + channel_axis = self._get_weight_channel_axis(node) return self._backend_entity.calculate_weight_scale(scale_value, weights_size, channel_axis) return scale_value @@ -350,7 +351,7 @@ def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, port_id: """ channel_axis = 0 if len(weights.shape) > 1: - channel_axis = self._backend_entity.get_weight_channel_axis(node, port_id) + channel_axis = self._get_weight_channel_axis(node) return self._backend_entity.process_weight_statistics(weights, channel_axis) def _create_scale_node_name(self, source_name: str, source_port_id: int) -> str: @@ -365,3 +366,10 @@ def _create_scale_node_name(self, source_name: str, source_port_id: int) -> str: unique_index = self._cached_multiply_names[scale_node_name] self._cached_multiply_names[scale_node_name] += 1 return f"{scale_node_name}_{unique_index}/sq_multiply" + + @staticmethod + def _get_weight_channel_axis(node: NNCFNode) -> int: + layer_attributes = node.layer_attributes.get_backend_agnostic_attributes() + if layer_attributes is None or layer_attributes.weights_layout is None: + return 1 + return layer_attributes.weights_layout.index(LayoutElem.C_IN) diff --git a/nncf/quantization/algorithms/smooth_quant/backend.py b/nncf/quantization/algorithms/smooth_quant/backend.py index dda7fc44d2d..379a9aee508 100644 --- a/nncf/quantization/algorithms/smooth_quant/backend.py +++ b/nncf/quantization/algorithms/smooth_quant/backend.py @@ -215,17 +215,6 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: :return: Channel axis number. """ - @staticmethod - @abstractmethod - def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int: - """ - Returns axis number of the weight tensor which correspond to it channel. - - :param node: NNCFNode instance. - :param port_id: Specified input port id. - :return: Channel axis number. - """ - @staticmethod @abstractmethod def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: diff --git a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py index 415ea5a626a..93a4d81ad49 100644 --- a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py @@ -160,23 +160,6 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: return channel_axis - @staticmethod - def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int: - channel_axis = 1 if node.metatype.const_channel_axis is None else node.metatype.const_channel_axis[0] - - if port_id not in node.layer_attributes.constant_attributes: - raise RuntimeError(f"{node.node_name} should contain {port_id} in the attributes map.") - - if node.metatype == OVMatMulMetatype: - if port_id > 1: - raise RuntimeError(f"{OVMatMulMetatype.name} can not take more than 2 input tensors.") - - if "transpose" in node.layer_attributes.constant_attributes[port_id]: - transpose = node.layer_attributes.constant_attributes[port_id]["transpose"] - channel_axis = OVSmoothQuantAlgoBackend.calculate_port_based_channel_axis(port_id, transpose) - - return channel_axis - @staticmethod def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: return -2 + port_id if transpose else -1 - port_id diff --git a/tests/openvino/native/quantization/test_channel_alignment.py b/tests/openvino/native/quantization/test_channel_alignment.py index 432aa89a536..ff81d656a7d 100644 --- a/tests/openvino/native/quantization/test_channel_alignment.py +++ b/tests/openvino/native/quantization/test_channel_alignment.py @@ -11,37 +11,20 @@ from typing import Type -import pytest - -from nncf.common.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetType from nncf.openvino.graph.layer_attributes import OVLayerAttributes from nncf.openvino.graph.metatypes.openvino_metatypes import OVAddMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConstantMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype -from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionMetatype -from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype from nncf.openvino.graph.transformations.command_creation import OVCommandCreator from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand from nncf.openvino.graph.transformations.commands import OVTargetPoint from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand -from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor from nncf.quantization.algorithms.channel_alignment.openvino_backend import OVChannelAlignmentAlgoBackend from tests.post_training.test_templates.test_channel_alignment import TemplateTestChannelAlignment -def _get_nncf_node(metatype, layer_attrs): - return NNCFNode( - { - NNCFNode.ID_NODE_ATTR: 0, - NNCFNode.NODE_NAME_ATTR: "test", - NNCFNode.METATYPE_ATTR: metatype, - NNCFNode.LAYER_ATTRIBUTES: layer_attrs, - } - ) - - class TestOVChannelAlignment(TemplateTestChannelAlignment): def get_backend_cls(self) -> Type[OVChannelAlignmentAlgoBackend]: return OVChannelAlignmentAlgoBackend @@ -50,7 +33,7 @@ def target_point(self, target_type: TargetType, target_node_name: str, port_id: return OVTargetPoint(target_type, target_node_name, port_id) def convert_conv_layer_attrs(self, layer_attributes): - return OVLayerAttributes({}, {1: layer_attributes}) + return OVLayerAttributes({}, layer_attributes) def get_conv_metatype(self): return OVConvolutionMetatype @@ -69,40 +52,3 @@ def get_transformation_commands(self): def mock_command_creation_factory(self, mocker) -> None: mocker.patch("nncf.common.factory.CommandCreatorFactory.create", return_value=OVCommandCreator) - - @pytest.mark.parametrize("transpose", [False, True]) - @pytest.mark.parametrize("shape", [[3, 4], [1, 2, 3, 4]]) - @pytest.mark.parametrize("port_id", [-1, -2]) - def test_get_dims_descriptor_matmul(self, transpose, shape, port_id): - _port_id = len(shape) + port_id - node = _get_nncf_node(OVMatMulMetatype, OVLayerAttributes({_port_id: {"transpose": transpose, "shape": shape}})) - dims_descr = OVChannelAlignmentAlgoBackend.get_dims_descriptor(node) - - in_dims, out_dims = (0, 1) if port_id == -1 else (1, 0) - if len(shape) > 2: - in_dims += 2 - out_dims += 2 - if transpose: - in_dims, out_dims = out_dims, in_dims - - assert dims_descr.conv_weight_in_channels_dim == in_dims - assert dims_descr.conv_weight_out_channels_dim == out_dims - assert dims_descr.bias_channels_dim == OVMatMulMetatype.output_channel_axis - - def test_get_dims_descriptor_mm_no_layer_attrs(self): - node = _get_nncf_node(OVMatMulMetatype, None) - with pytest.raises(RuntimeError): - OVChannelAlignmentAlgoBackend.get_dims_descriptor(node) - - @pytest.mark.parametrize( - "metatype,ref_desc", - [ - (OVConvolutionMetatype, LayoutDescriptor(0, 1, 1)), - (OVGroupConvolutionMetatype, LayoutDescriptor(0, 2, 1)), - (OVGroupConvolutionMetatype, LayoutDescriptor(0, 2, 1)), - ], - ) - def test_get_dims_descriptor_convs(self, metatype, ref_desc): - node = _get_nncf_node(metatype, None) - dims_descr = OVChannelAlignmentAlgoBackend.get_dims_descriptor(node) - assert dims_descr.__dict__ == ref_desc.__dict__ diff --git a/tests/openvino/native/test_layer_attributes.py b/tests/openvino/native/test_layer_attributes.py index 6e48b437b81..075eddda6f5 100644 --- a/tests/openvino/native/test_layer_attributes.py +++ b/tests/openvino/native/test_layer_attributes.py @@ -16,6 +16,8 @@ from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes from nncf.common.graph.layer_attributes import GenericWeightedLayerAttributes +from nncf.common.graph.layer_attributes import LayoutElem +from nncf.common.graph.layer_attributes import LinearLayerAttributes from nncf.openvino.graph.layer_attributes import OVLayerAttributes from nncf.openvino.graph.nncf_graph_builder import GraphConverter @@ -30,7 +32,13 @@ def get_conv(input_1, node_name, input_shape, kernel=None): return opset.convolution(input_1, kernel, strides, pads, pads, dilations, name=node_name) -def get_group_conv(input_1, node_name, input_shape, kernel=None): +def get_group_conv(input_1, node_name, input_shape): + shape = (input_shape[1] // 2, input_shape[1], 2, 1, 1) + kernel = opset.constant(np.ones(shape), dtype=np.float32, name="Const") + return get_depthwise_conv(input_1, node_name, input_shape, kernel) + + +def get_depthwise_conv(input_1, node_name, input_shape, kernel=None): strides = [1, 2] pads = [0, 1] dilations = [3, 1] @@ -79,13 +87,49 @@ def get_matmul_a(input_1, node_name, input_shape): return get_matmul(input_1, node_name, input_shape, transpose_a=True) -def get_matmul(input_1, node_name, input_shape, transpose_a=False, transpose_b=False): +def get_matmul_b_swapped(input_1, node_name, input_shape): + return get_matmul(input_1, node_name, input_shape, transpose_b=True, swap_inputs=True) + + +def get_matmul_a_swapped(input_1, node_name, input_shape): + return get_matmul(input_1, node_name, input_shape, transpose_a=True, swap_inputs=True) + + +def get_matmul(input_1, node_name, input_shape, transpose_a=False, transpose_b=False, swap_inputs=False): channel_position = 1 if transpose_a else -1 data_shape = [input_shape[channel_position], 1] if transpose_b: data_shape = data_shape[::-1] data = opset.constant(np.ones(tuple(data_shape)), dtype=np.float32, name="Const") - return opset.matmul(input_1, data, transpose_a=transpose_a, transpose_b=transpose_b, name=node_name) + a, b = (data, input_1) if swap_inputs else (input_1, data) + return opset.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=node_name) + + +def get_1d_matmul(input_1, node_name, input_shape): + data_shape = (input_shape[-1],) + data = opset.constant(np.ones(tuple(data_shape)), dtype=np.float32, name="Const") + return opset.matmul(input_1, data, transpose_a=False, transpose_b=False, name=node_name) + + +def get_add(input_1, node_name, input_shape): + data_shape = [1] * len(input_shape) + data = opset.constant(np.ones(tuple(data_shape)), dtype=np.float32, name="Const") + return opset.add(input_1, data, name=node_name) + + +def get_lstm(input_1, node_name, input_shape): + batch_size, _, input_size = input_shape + hidden_size = 4 + num_directions = 1 + hs = opset.constant(np.ones((batch_size, num_directions, hidden_size)), dtype=np.float32, name="hs") + cs = opset.constant(np.ones((batch_size, num_directions, hidden_size)), dtype=np.float32, name="cs") + seq_len_const = opset.constant(np.ones((batch_size)), dtype=np.int32, name="seq_len_const") + w = opset.constant(np.ones((num_directions, 4 * hidden_size, input_size)), dtype=np.float32, name="w") + r = opset.constant(np.ones((num_directions, 4 * hidden_size, hidden_size)), dtype=np.float32, name="r") + b = opset.constant(np.ones((num_directions, 4 * hidden_size)), dtype=np.float32, name="b") + return opset.lstm_sequence( + input_1, hs, cs, seq_len_const, w, r, b, hidden_size, "forward", name=node_name + ).outputs()[0] def get_shape_node(input_, op_name, input_shape): @@ -108,19 +152,18 @@ def get_one_layer_model(op_name: str, node_creator, input_shape): (1, 3, 3, 3), OVLayerAttributes( {1: {"name": "Const", "shape": (4, 3, 2, 1)}}, - { - 1: ConvolutionLayerAttributes( - weight_requires_grad=False, - in_channels=3, - out_channels=4, - kernel_size=(2, 1), - stride=(1, 1), - dilations=[1, 1], - groups=1, - transpose=False, - padding_values=(0, 0, 0, 0), - ), - }, + ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=3, + out_channels=4, + kernel_size=(2, 1), + stride=(1, 1), + dilations=[1, 1], + groups=1, + transpose=False, + padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), + ), {}, ), ), @@ -129,40 +172,70 @@ def get_one_layer_model(op_name: str, node_creator, input_shape): (1, 3, 3, 3), OVLayerAttributes( {1: {"name": "Const", "shape": (4, 3, 1, 1)}}, - { - 1: ConvolutionLayerAttributes( - weight_requires_grad=False, - in_channels=3, - out_channels=4, - kernel_size=(1, 1), - stride=(1, 1), - dilations=[1, 1], - groups=1, - transpose=False, - padding_values=(0, 0, 0, 0), - ), - }, + ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=3, + out_channels=4, + kernel_size=(1, 1), + stride=(1, 1), + dilations=[1, 1], + groups=1, + transpose=False, + padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), + ), {}, ), ), ( - get_group_conv, + get_depthwise_conv, (1, 3, 3, 3), OVLayerAttributes( {1: {"name": "Const", "shape": (3, 3, 1, 1, 1)}}, - { - 1: ConvolutionLayerAttributes( - weight_requires_grad=False, - in_channels=1, - out_channels=3, - kernel_size=(1, 1), - stride=(1, 2), - dilations=[3, 1], - groups=3, - transpose=False, - padding_values=(0, 1, 0, 1), + ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=1, + out_channels=3, + kernel_size=(1, 1), + stride=(1, 2), + dilations=[3, 1], + groups=3, + transpose=False, + padding_values=(0, 1, 0, 1), + weights_layout=( + LayoutElem.GROUPS, + LayoutElem.C_OUT, + LayoutElem.C_IN, + LayoutElem.SPATIAL, + LayoutElem.SPATIAL, ), - }, + ), + {}, + ), + ), + ( + get_group_conv, + (1, 10, 3, 3), + OVLayerAttributes( + {1: {"name": "Const", "shape": (5, 10, 2, 1, 1)}}, + ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=2, + out_channels=10, + kernel_size=(1, 1), + stride=(1, 2), + dilations=[3, 1], + groups=5, + transpose=False, + padding_values=(0, 1, 0, 1), + weights_layout=( + LayoutElem.GROUPS, + LayoutElem.C_OUT, + LayoutElem.C_IN, + LayoutElem.SPATIAL, + LayoutElem.SPATIAL, + ), + ), {}, ), ), @@ -171,19 +244,18 @@ def get_one_layer_model(op_name: str, node_creator, input_shape): (1, 3, 3, 3), OVLayerAttributes( {1: {"name": "Const", "shape": (3, 4, 2, 1)}}, - { - 1: ConvolutionLayerAttributes( - weight_requires_grad=False, - in_channels=3, - out_channels=4, - kernel_size=(2, 1), - stride=(1, 1), - dilations=[1, 1], - groups=1, - transpose=True, - padding_values=(0, 0, 0, 0), - ), - }, + ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=3, + out_channels=4, + kernel_size=(2, 1), + stride=(1, 1), + dilations=[1, 1], + groups=1, + transpose=True, + padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_IN, LayoutElem.C_OUT, LayoutElem.SPATIAL, LayoutElem.SPATIAL), + ), {}, ), ), @@ -192,19 +264,24 @@ def get_one_layer_model(op_name: str, node_creator, input_shape): (1, 3, 3, 3), OVLayerAttributes( {1: {"name": "Const", "shape": (3, 1, 3, 1, 1)}}, - { - 1: ConvolutionLayerAttributes( - weight_requires_grad=False, - in_channels=1, - out_channels=3, - kernel_size=(1, 1), - stride=(1, 2), - dilations=[3, 1], - groups=3, - transpose=True, - padding_values=(0, 1, 0, 1), + ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=1, + out_channels=3, + kernel_size=(1, 1), + stride=(1, 2), + dilations=[3, 1], + groups=3, + transpose=True, + padding_values=(0, 1, 0, 1), + weights_layout=( + LayoutElem.GROUPS, + LayoutElem.C_IN, + LayoutElem.C_OUT, + LayoutElem.SPATIAL, + LayoutElem.SPATIAL, ), - }, + ), {}, ), ), @@ -214,7 +291,13 @@ def get_one_layer_model(op_name: str, node_creator, input_shape): (1, 3, 4), OVLayerAttributes( {1: {"name": "Const", "shape": (1, 4), "transpose": True}}, - {1: GenericWeightedLayerAttributes(False, (1, 4))}, + LinearLayerAttributes( + weight_requires_grad=False, + in_features=4, + out_features=1, + with_bias=False, + weights_layout=[LayoutElem.C_OUT, LayoutElem.C_IN], + ), {"transpose": False}, ), ), @@ -223,10 +306,84 @@ def get_one_layer_model(op_name: str, node_creator, input_shape): (1, 3, 4), OVLayerAttributes( {1: {"name": "Const", "shape": (3, 1), "transpose": False}}, - {1: GenericWeightedLayerAttributes(False, (3, 1))}, + LinearLayerAttributes( + weight_requires_grad=False, + in_features=3, + out_features=1, + with_bias=False, + weights_layout=[LayoutElem.C_IN, LayoutElem.C_OUT], + ), + {"transpose": True}, + ), + ), + ( + get_matmul_a_swapped, + (1, 3, 4), + OVLayerAttributes( + {0: {"name": "Const", "shape": (3, 1), "transpose": True}}, + LinearLayerAttributes( + weight_requires_grad=False, + in_features=3, + out_features=1, + with_bias=False, + weights_layout=[LayoutElem.C_IN, LayoutElem.C_OUT], + ), + {"transpose": False}, + ), + ), + ( + get_matmul_b_swapped, + (1, 3, 4), + OVLayerAttributes( + {0: {"name": "Const", "shape": (1, 4), "transpose": False}}, + LinearLayerAttributes( + weight_requires_grad=False, + in_features=4, + out_features=1, + with_bias=False, + weights_layout=[LayoutElem.C_OUT, LayoutElem.C_IN], + ), {"transpose": True}, ), ), + ( + get_1d_matmul, + (1, 3, 4), + OVLayerAttributes( + {1: {"name": "Const", "shape": (4,), "transpose": False}}, + LinearLayerAttributes( + weight_requires_grad=False, + in_features=4, + out_features=None, + with_bias=False, + weights_layout=[LayoutElem.C_IN], + ), + {"transpose": False}, + ), + ), + ( + get_add, + (1, 3, 4, 5), + OVLayerAttributes( + {1: {"name": "Const", "shape": (1, 1, 1, 1)}}, + GenericWeightedLayerAttributes(False, weight_shape=(1, 1, 1, 1)), + {}, + ), + ), + ( + get_lstm, + (2, 3, 4), + OVLayerAttributes( + { + 1: {"name": "hs", "shape": (2, 1, 4)}, + 2: {"name": "cs", "shape": (2, 1, 4)}, + 4: {"name": "w", "shape": (1, 16, 4)}, + 5: {"name": "r", "shape": (1, 16, 4)}, + }, + None, + {}, + ), + ), ], ) def test_layer_attributes(node_creator, input_shape, ref_layer_attrs): diff --git a/tests/openvino/native/test_smooth_quant.py b/tests/openvino/native/test_smooth_quant.py index 32e51f34079..71bca07983e 100644 --- a/tests/openvino/native/test_smooth_quant.py +++ b/tests/openvino/native/test_smooth_quant.py @@ -17,6 +17,9 @@ import pytest import torch +from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.graph.layer_attributes import LayoutElem +from nncf.common.graph.layer_attributes import LinearLayerAttributes from nncf.openvino.graph.layer_attributes import OVLayerAttributes from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype @@ -82,18 +85,68 @@ def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port return super().test_get_activation_channel_axis(node_metatype, layer_attributes, port_id, reference_value) @pytest.mark.parametrize( - "node_metatype, layer_attributes, port_id, reference_value", + "node_metatype,layer_attributes,reference_value", ( - (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), 1, -2), - (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": True}}), 1, -1), - (OVMatMulMetatype, OVLayerAttributes({0: {"transpose": False}}), 0, -1), - (OVMatMulMetatype, OVLayerAttributes({0: {"transpose": True}}), 0, -2), - (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), 2, RuntimeError), - (OVConvolutionMetatype, OVLayerAttributes({1: {}}), 1, 0), + ( + OVMatMulMetatype, + OVLayerAttributes( + {}, + LinearLayerAttributes( + weight_requires_grad=False, + in_features=5, + out_features=10, + with_bias=False, + weights_layout=[LayoutElem.C_OUT, LayoutElem.C_IN], + ), + ), + 1, + ), + ( + OVMatMulMetatype, + OVLayerAttributes( + {}, + LinearLayerAttributes( + weight_requires_grad=False, + in_features=5, + out_features=None, + with_bias=False, + weights_layout=[LayoutElem.C_IN], + ), + ), + 0, + ), + ( + OVConvolutionMetatype, + OVLayerAttributes( + {}, + ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=5, + out_channels=10, + kernel_size=(5, 5), + stride=(1, 1), + dilations=(1, 1), + groups=1, + transpose=False, + padding_values=[1, 1, 1, 1], + with_bias=False, + weights_layout=[LayoutElem.SPATIAL, LayoutElem.SPATIAL, LayoutElem.C_IN, LayoutElem.C_OUT], + ), + ), + 2, + ), + ( + OVMatMulMetatype, + OVLayerAttributes( + {}, + None, + ), + 1, + ), ), ) - def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value): - return super().test_get_weight_channel_axis(node_metatype, layer_attributes, port_id, reference_value) + def test_get_weight_channel_axis(self, node_metatype, layer_attributes, reference_value): + return super().test_get_weight_channel_axis(node_metatype, layer_attributes, reference_value) @staticmethod def get_matmul_metatype(): diff --git a/tests/post_training/test_templates/models.py b/tests/post_training/test_templates/models.py index 546a4104318..43246691048 100644 --- a/tests/post_training/test_templates/models.py +++ b/tests/post_training/test_templates/models.py @@ -171,6 +171,7 @@ def __init__( conv_metatype, add_metatype, conv_layer_attrs=None, + conv_layer_attrs_1=None, both_biases=True, add_layer_attrs=None, constant_metatype=ConstantTestMetatype, @@ -187,6 +188,8 @@ def __init__( # | # Add_2 # Output_1 + if conv_layer_attrs_1 is None: + conv_layer_attrs_1 = conv_layer_attrs nodes = [ NodeWithType("Input_1", InputNoopMetatype), NodeWithType("Conv_1_W", constant_metatype), @@ -194,7 +197,7 @@ def __init__( NodeWithType("Add_1_W", constant_metatype), NodeWithType("Add_1", add_metatype, layer_attributes=add_layer_attrs), NodeWithType("Conv_2_W", constant_metatype), - NodeWithType("Conv_2", conv_metatype, layer_attributes=conv_layer_attrs), + NodeWithType("Conv_2", conv_metatype, layer_attributes=conv_layer_attrs_1), NodeWithType("Output_1", OutputNoopMetatype), ] if both_biases: diff --git a/tests/post_training/test_templates/test_channel_alignment.py b/tests/post_training/test_templates/test_channel_alignment.py index d3b6dd045e5..9c2a24cd781 100644 --- a/tests/post_training/test_templates/test_channel_alignment.py +++ b/tests/post_training/test_templates/test_channel_alignment.py @@ -17,6 +17,8 @@ from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.graph.layer_attributes import LayoutElem +from nncf.common.graph.layer_attributes import LinearLayerAttributes from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationType @@ -27,6 +29,7 @@ from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment +from nncf.quantization.algorithms.channel_alignment.algorithm import ConvParamsContainer from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor from tests.post_training.test_templates.models import NNCFGraphCA @@ -46,9 +49,47 @@ groups=1, transpose=False, padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), ) +DEPTHWISE_CONV_LAYER_ATTR = ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=5, + out_channels=1, + kernel_size=(5, 5), + stride=(1, 1), + dilations=(1, 1), + groups=5, + transpose=False, + padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), +) + +MATMUL_LAYER_METATYPES = [ + # 2D + LinearLayerAttributes( + weight_requires_grad=False, + in_features=5, + out_features=10, + with_bias=False, + weights_layout=[LayoutElem.C_IN, LayoutElem.C_OUT], + ), + # 1D + LinearLayerAttributes( + weight_requires_grad=False, in_features=5, out_features=None, with_bias=False, weights_layout=[LayoutElem.C_IN] + ), + # 5D + LinearLayerAttributes( + weight_requires_grad=False, + in_features=5, + out_features=None, + with_bias=False, + weights_layout=[LayoutElem.SPATIAL, LayoutElem.SPATIAL, LayoutElem.SPATIAL, LayoutElem.C_IN, LayoutElem.C_OUT], + ), +] + + INVALID_CONSUMER_CONV_LAYER_ATTRS = [ ConvolutionLayerAttributes( weight_requires_grad=False, @@ -60,6 +101,7 @@ groups=1, transpose=False, padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), ), ConvolutionLayerAttributes( weight_requires_grad=False, @@ -71,6 +113,7 @@ groups=1, transpose=False, padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), ), ConvolutionLayerAttributes( weight_requires_grad=False, @@ -82,6 +125,7 @@ groups=1, transpose=False, padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), ), ConvolutionLayerAttributes( weight_requires_grad=False, @@ -93,6 +137,7 @@ groups=1, transpose=False, padding_values=(1, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), ), ] @@ -107,6 +152,7 @@ groups=5, transpose=False, padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), ) @@ -232,9 +278,8 @@ def check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in): (INVALID_CONV_LAYER_ATTR, INVALID_CONV_LAYER_ATTR, False), ] ) - GET_NODES_TEST_CASES.extend( - [(VALID_CONV_LAYER_ATTR, None, False), (None, VALID_CONV_LAYER_ATTR, False), (None, None, False)] - ) + GET_NODES_TEST_CASES.extend([(attr, VALID_CONV_LAYER_ATTR, True) for attr in MATMUL_LAYER_METATYPES]) + GET_NODES_TEST_CASES.append((None, VALID_CONV_LAYER_ATTR, False)) @pytest.mark.parametrize("first_conv_attrs,second_conv_attrs,ref_match", GET_NODES_TEST_CASES) def test_get_node_pairs(self, first_conv_attrs, second_conv_attrs, ref_match): @@ -260,16 +305,21 @@ def test_get_node_pairs(self, first_conv_attrs, second_conv_attrs, ref_match): else: assert len(pairs) == 0 - def _get_nncf_graph(self, num_biases: int) -> NNCFGraph: - cla = self.convert_conv_layer_attrs(VALID_CONV_LAYER_ATTR) + def _get_nncf_graph( + self, num_biases: int, conv_layer_attrs=DEPTHWISE_CONV_LAYER_ATTR, conv_layer_attrs_1=VALID_CONV_LAYER_ATTR + ) -> NNCFGraph: + cla = self.convert_conv_layer_attrs(conv_layer_attrs) + cla_1 = self.convert_conv_layer_attrs(conv_layer_attrs_1) + if num_biases == 0: - return NNCFGraphCA(self.get_conv_metatype(), cla).nncf_graph + return NNCFGraphCA(self.get_conv_metatype(), cla, cla_1).nncf_graph bla = self.get_add_layer_attrs() if num_biases == 1: return NNCFGraphCAWithBias( self.get_conv_metatype(), self.get_add_metatype(), cla, + cla_1, both_biases=False, constant_metatype=self.get_constant_metatype(), add_layer_attrs=bla, @@ -278,20 +328,37 @@ def _get_nncf_graph(self, num_biases: int) -> NNCFGraph: self.get_conv_metatype(), self.get_add_metatype(), cla, + cla_1, both_biases=True, add_layer_attrs=bla, constant_metatype=self.get_constant_metatype(), ).nncf_graph + @staticmethod + def _get_constant_lambda(value, counter=False): + if counter: + _state = 0 + + def f(*args, **kwargs): + if not counter: + return value + nonlocal _state + _state += 1 + return value + str(_state) + + return f + + @pytest.mark.parametrize("one_dim_mm", [False, True]) @pytest.mark.parametrize("empty_statistics", [False, True]) @pytest.mark.parametrize("num_biases", [0, 1, 2]) # pylint: disable=too-many-statements # pylint: disable=too-many-branches - def test_transformation_layout(self, empty_statistics, num_biases, mocker): + def test_transformation_layout(self, one_dim_mm, empty_statistics, num_biases, mocker): mocked_transformer = mocker.MagicMock() self.mock_model_transformer_factory(mocker, mocked_transformer) - nncf_graph = self._get_nncf_graph(num_biases) + first_conv_layer_attrs = DEPTHWISE_CONV_LAYER_ATTR if not one_dim_mm else MATMUL_LAYER_METATYPES[1] + nncf_graph = self._get_nncf_graph(num_biases, first_conv_layer_attrs) self.mock_nncf_graph_factory(mocker, nncf_graph) self.mock_command_creation_factory(mocker) @@ -308,19 +375,6 @@ class TestTensorStats(MinMaxTensorStatistic): def tensor_eq(*args, **kwargs): return True - def get_constant_lambda(value, counter=False): - if counter: - _state = 0 - - def f(*args, **kwargs): - if not counter: - return value - nonlocal _state - _state += 1 - return value + str(_state) - - return f - algorithm = ChannelAlignment() tensor_collector = TensorCollector() if empty_statistics: @@ -328,18 +382,16 @@ def f(*args, **kwargs): else: stat_value = (np.array([-1], dtype=np.int32), np.array([2], dtype=np.int32)) - tensor_collector.get_statistics = get_constant_lambda(TestTensorStats(*stat_value)) + tensor_collector.get_statistics = self._get_constant_lambda(TestTensorStats(*stat_value)) statistic_points.add_statistic_point(StatisticPoint(target_point, tensor_collector, algorithm._algorithm_key)) class MockBackend(backend_cls): pass ref_weights_val = "ref_weights_val" - MockBackend.get_weight_value = get_constant_lambda(ref_weights_val, True) + MockBackend.get_weight_value = self._get_constant_lambda(ref_weights_val, True) ref_bias_val = "ref_bias_val" - MockBackend.get_bias_value = get_constant_lambda(ref_bias_val, True) - ref_dims_descr = "ref_dims_descr" - MockBackend.get_dims_descriptor = get_constant_lambda(ref_dims_descr, True) + MockBackend.get_bias_value = self._get_constant_lambda(ref_bias_val, True) algorithm._backend_entity = MockBackend algorithm._set_backend_entity = mocker.MagicMock() @@ -358,7 +410,7 @@ class MockBackend(backend_cls): ) algorithm.apply(None, nncf_graph, statistic_points) - if empty_statistics: + if empty_statistics or one_dim_mm: assert algorithm._align_means.call_count == 0 assert algorithm._align_scales.call_count == 0 mocked_transformer.transform.assert_called_once() @@ -367,12 +419,15 @@ class MockBackend(backend_cls): return assert algorithm._align_means.call_count == 1 + + ref_dims = LayoutDescriptor(0, 2, 1) + ref_dims_1 = LayoutDescriptor(0, 1, 1) args = [ np.zeros((1, 1, 1, 1)), np.zeros((1, 1, 1, 1)), ref_weights_val + "2", np.array(0.5, dtype=np.float32), - ref_dims_descr + "2", + ref_dims_1, ] for i in range(num_biases): args[i] = f"ref_bias_val{i + 1}" @@ -385,8 +440,8 @@ class MockBackend(backend_cls): assert args[1] == ref_weights_val + "2" assert args[2] == ref_bias_in_after_align assert ((args[3] - 3) < EPS).all() - assert args[4] == ref_dims_descr + "1" - assert args[5] == ref_dims_descr + "2" + assert args[4] == ref_dims + assert args[5] == ref_dims_1 assert args[6] < EPS mocked_transformer.transform.assert_called_once() @@ -497,3 +552,29 @@ def test_statistic_collectors(self, inplace_ref, q_ref): assert isinstance(aggr, MedianAggregator) assert aggr.num_samples == num_samples_ref assert not aggr._use_per_sample_stats + + @pytest.mark.parametrize( + "layer_attributes,ref_layout_desc", + [ + (VALID_CONV_LAYER_ATTR, LayoutDescriptor(0, 1, 1)), + (DEPTHWISE_CONV_LAYER_ATTR, LayoutDescriptor(0, 2, 1)), + (MATMUL_LAYER_METATYPES[0], LayoutDescriptor(1, 0, 1)), + (MATMUL_LAYER_METATYPES[1], LayoutDescriptor(None, 0, 1)), + (MATMUL_LAYER_METATYPES[2], LayoutDescriptor(4, 3, 1)), + ], + ) + def test_conv_params_dims(self, layer_attributes, ref_layout_desc): + backend_cls = self.get_backend_cls() + + class MockBackend(backend_cls): + pass + + ref_weights_val = "ref_weights_val" + MockBackend.get_weight_value = self._get_constant_lambda(ref_weights_val) + ref_bias_val = "ref_bias_val" + MockBackend.get_bias_value = self._get_constant_lambda(ref_bias_val) + nncf_graph = NNCFGraphCAWithBias( + self.get_conv_metatype(), self.get_add_metatype(), self.convert_conv_layer_attrs(layer_attributes) + ).nncf_graph + cont = ConvParamsContainer(nncf_graph.get_node_by_name("/Conv_1_0"), None, nncf_graph, MockBackend) + assert cont.dims == ref_layout_desc diff --git a/tests/post_training/test_templates/test_smooth_quant.py b/tests/post_training/test_templates/test_smooth_quant.py index 42fe17e01b0..ab20510e67d 100644 --- a/tests/post_training/test_templates/test_smooth_quant.py +++ b/tests/post_training/test_templates/test_smooth_quant.py @@ -226,9 +226,7 @@ def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port assert activation_channel_axis == reference_value - def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value): - backend = self.get_backend() - + def test_get_weight_channel_axis(self, node_metatype, layer_attributes, reference_value): attributes = { NNCFNode.METATYPE_ATTR: node_metatype, NNCFNode.LAYER_ATTRIBUTES: layer_attributes, @@ -239,7 +237,7 @@ def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, try: # pylint: disable=protected-access - activation_channel_axis = backend.get_weight_channel_axis(node, port_id) + activation_channel_axis = SmoothQuant._get_weight_channel_axis(node) except RuntimeError as e: if isinstance(e, reference_value): pytest.xfail("Expected exception")