From aa4408b5470be7e168eb91d434be3b11e96f12a7 Mon Sep 17 00:00:00 2001 From: Aleksei Kashapov Date: Fri, 28 Jul 2023 15:15:15 +0200 Subject: [PATCH] [PTQ] Merge quantizer insertion function for weights and activations (#2001) ### Changes Merge quantizer insertion function for weights and activations. ### Reason for changes Reduce code. ### Related tickets N/A ### Tests N/A --- nncf/common/quantization/config_assignment.py | 20 +++---- .../quantizer_propagation/solver.py | 11 +--- nncf/common/quantization/structs.py | 10 +++- .../algorithms/min_max/algorithm.py | 13 ++--- .../algorithms/min_max/backend.py | 20 +------ .../algorithms/min_max/onnx_backend.py | 25 ++------- .../algorithms/min_max/openvino_backend.py | 11 +--- .../algorithms/min_max/torch_backend.py | 13 +---- .../test_templates/test_quantizer_config.py | 56 +++++++++++-------- 9 files changed, 62 insertions(+), 117 deletions(-) diff --git a/nncf/common/quantization/config_assignment.py b/nncf/common/quantization/config_assignment.py index a2c78b95236..abaddf95427 100644 --- a/nncf/common/quantization/config_assignment.py +++ b/nncf/common/quantization/config_assignment.py @@ -94,19 +94,15 @@ def assign_qconfig_lists_to_modules( qconfig_list = [default_qconfig] elif HWConfig.is_qconf_list_corresponding_to_unspecified_op(qconfig_list): continue # The module will not have its weights quantized - try: - local_constraints = global_weight_constraints - for overridden_scope, scoped_override_dict in scope_overrides_dict.items(): - if matches_any(node.node_name, overridden_scope): - scope_constraints = QuantizationConstraints.from_config_dict(scoped_override_dict) - local_constraints = local_constraints.get_updated_constraints(scope_constraints) - qconfig_list = local_constraints.constrain_qconfig_list(qconfig_list) - except RuntimeError as e: - err_msg = "Quantization parameter constraints specified in NNCF config are incompatible with HW " - err_msg += "capabilities as specified in HW config type '{}'. ".format(hw_config.target_device) - err_msg += "First conflicting quantizer location: {}".format(str(node.node_name)) - raise RuntimeError(err_msg) from e + local_constraints = global_weight_constraints + for overridden_scope, scoped_override_dict in scope_overrides_dict.items(): + if matches_any(node.node_name, overridden_scope): + scope_constraints = QuantizationConstraints.from_config_dict(scoped_override_dict) + local_constraints = local_constraints.get_updated_constraints(scope_constraints) + qconfig_list = local_constraints.constrain_qconfig_list( + node.node_name, hw_config.target_device, qconfig_list + ) retval[node] = qconfig_list return retval diff --git a/nncf/common/quantization/quantizer_propagation/solver.py b/nncf/common/quantization/quantizer_propagation/solver.py index 435710060e0..e078ae2d019 100644 --- a/nncf/common/quantization/quantizer_propagation/solver.py +++ b/nncf/common/quantization/quantizer_propagation/solver.py @@ -1069,14 +1069,9 @@ def _filter_qconfigs_according_to_scope( local_constraints = local_constraints.get_updated_constraints(scope_constraints) if self._hw_config is not None: - try: - constrained_config_list = local_constraints.constrain_qconfig_list(qconf_list) - except RuntimeError as e: - err_msg = "Quantization parameter constraints specified in NNCF config are incompatible with HW " - err_msg += "capabilities as specified in HW config type '{}'. ".format(self._hw_config.target_device) - err_msg += "First conflicting quantizer location: " - err_msg += nncf_node_name - raise RuntimeError(err_msg) from e + constrained_config_list = local_constraints.constrain_qconfig_list( + nncf_node_name, self._hw_config.target_device, qconf_list + ) else: constrained_config_list = [local_constraints.apply_constraints_to(qconfig) for qconfig in qconf_list] diff --git a/nncf/common/quantization/structs.py b/nncf/common/quantization/structs.py index 95e5ea3e9d8..f0eda4e8f6a 100644 --- a/nncf/common/quantization/structs.py +++ b/nncf/common/quantization/structs.py @@ -18,6 +18,7 @@ from nncf.common.utils.api_marker import api from nncf.config.schemata.defaults import QUANTIZATION_BITS from nncf.config.schemata.defaults import QUANTIZATION_PER_CHANNEL +from nncf.parameters import TargetDevice @api() @@ -224,7 +225,9 @@ def from_config_dict(cls, config_dict: Dict) -> "QuantizationConstraints": signedness_to_force=config_dict.get("signed"), ) - def constrain_qconfig_list(self, quantizer_config_list: List[QuantizerConfig]) -> List[QuantizerConfig]: + def constrain_qconfig_list( + self, node_name: NNCFNodeName, target_device: TargetDevice, quantizer_config_list: List[QuantizerConfig] + ) -> List[QuantizerConfig]: assert quantizer_config_list is not None constrained_quantizer_config_list = list(filter(self.is_config_compatible, quantizer_config_list)) @@ -233,7 +236,10 @@ def constrain_qconfig_list(self, quantizer_config_list: List[QuantizerConfig]) - # It means that the qconfig from overrides must be selected as final config # even if it is not valid in hw-config. if not constrained_quantizer_config_list: - raise RuntimeError() + err_msg = f"Quantization parameter constraints specified in NNCF config are incompatible \ + with HW capabilities as specified in HW config type '{target_device}'. \ + First conflicting quantizer location: {node_name}" + raise ValueError(err_msg) return constrained_quantizer_config_list diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 68d3b4f5cf0..922e31b5ec1 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -644,7 +644,7 @@ def filter_func(point: StatisticPoint) -> bool: q_group = QuantizerGroup.ACTIVATIONS narrow_range = get_quantizer_narrow_range(qconfig, q_group) parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range) - command = self._backend_entity.create_activation_quantizer_insertion_command( + command = self._backend_entity.create_quantizer_insertion_command( nncf_graph, quantization_target_point, qconfig, parameters ) transformation_layout.register(command) @@ -670,14 +670,9 @@ def filter_func(point: StatisticPoint) -> bool: narrow_range = get_quantizer_narrow_range(qconfig, quant_group) statistics = tensor_collector.get_statistics() parameters = calculate_quantizer_parameters(statistics, qconfig, quant_group, narrow_range, half_range) - if quantization_target_point.is_weight_target_point(): - command = self._backend_entity.create_weight_quantizer_insertion_command( - nncf_graph, quantization_target_point, qconfig, parameters - ) - else: - command = self._backend_entity.create_activation_quantizer_insertion_command( - nncf_graph, quantization_target_point, qconfig, parameters - ) + command = self._backend_entity.create_quantizer_insertion_command( + nncf_graph, quantization_target_point, qconfig, parameters + ) transformation_layout.register(command) diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index f56df039460..342486323de 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -127,25 +127,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - @staticmethod @abstractmethod - def create_activation_quantizer_insertion_command( - nncf_graph: NNCFGraph, - target_point: TargetPoint, - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> TransformationCommand: - """ - Returns backend-specific quantizer insertion command. - - :param nncf_graph: NNCFGraph to get input/output shapes for the target point. - :param target_point: Target location for the correction. - :param quantizer_config: QuantizerConfig instance for the current layer. - :param parameters: FakeQuantizeParameters to calculate activation quantization parameters. - :return: Backend-specific TransformationCommand for the quantizer insertion operation. - """ - - @staticmethod - @abstractmethod - def create_weight_quantizer_insertion_command( + def create_quantizer_insertion_command( nncf_graph: NNCFGraph, target_point: TargetPoint, quantizer_config: QuantizerConfig, diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 94397fde58b..80286800181 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -96,32 +96,15 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - return ONNXTargetPoint(target_type, target_node_name, port_id) @staticmethod - def create_activation_quantizer_insertion_command( + def create_quantizer_insertion_command( nncf_graph: NNCFGraph, target_point: ONNXTargetPoint, quantizer_config: QuantizerConfig, parameters: FakeQuantizeParameters, - ) -> ONNXQuantizerInsertionCommand: - nncf_input_node_next_nodes = ONNXMinMaxAlgoBackend._get_input_edges_mapping(nncf_graph) - axis = ONNXMinMaxAlgoBackend._get_axis(nncf_graph, target_point, quantizer_config) + ): tensor_type = np.int8 if np.any(parameters.input_low < 0) else np.uint8 - onnx_parameters = convert_fq_params_to_onnx_params(parameters, quantizer_config.num_bits, tensor_type, axis) - return ONNXQuantizerInsertionCommand(target_point, nncf_input_node_next_nodes, onnx_parameters) - - @staticmethod - def create_weight_quantizer_insertion_command( - nncf_graph: NNCFGraph, - target_point: ONNXTargetPoint, - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> ONNXQuantizerInsertionCommand: - if quantizer_config.signedness_to_force is False: - raise ValueError( - "The HW expects to have signed quantization of weights, " - "while the quantizer configuration for weights contains signedness_to_force=False." - ) - - tensor_type = np.int8 # The weight is restricted to have only signed range + if target_point.is_weight_target_point(): + tensor_type = np.int8 # The weight is restricted to have only signed range nncf_input_node_next_nodes = ONNXMinMaxAlgoBackend._get_input_edges_mapping(nncf_graph) axis = ONNXMinMaxAlgoBackend._get_axis(nncf_graph, target_point, quantizer_config) onnx_parameters = convert_fq_params_to_onnx_params(parameters, quantizer_config.num_bits, tensor_type, axis) diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 894d63a5795..008bff1357d 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -103,16 +103,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - return OVTargetPoint(target_type, target_node_name, port_id) @staticmethod - def create_activation_quantizer_insertion_command( - nncf_graph: NNCFGraph, - target_point: OVTargetPoint, - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> OVQuantizerInsertionCommand: - return OVQuantizerInsertionCommand(target_point, parameters) - - @staticmethod - def create_weight_quantizer_insertion_command( + def create_quantizer_insertion_command( nncf_graph: NNCFGraph, target_point: OVTargetPoint, quantizer_config: QuantizerConfig, diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 45a256a0a26..a0357412d91 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -128,18 +128,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) @staticmethod - def create_activation_quantizer_insertion_command( - nncf_graph: NNCFGraph, - target_point: PTTargetPoint, - quantizer_config: QuantizerConfig, - parameters: FakeQuantizeParameters, - ) -> PTInsertionCommand: - return PTMinMaxAlgoBackend._create_quantizer_insertion_command( - nncf_graph, target_point, quantizer_config, parameters - ) - - @staticmethod - def create_weight_quantizer_insertion_command( + def create_quantizer_insertion_command( nncf_graph: NNCFGraph, target_point: PTTargetPoint, quantizer_config: QuantizerConfig, diff --git a/tests/post_training/test_templates/test_quantizer_config.py b/tests/post_training/test_templates/test_quantizer_config.py index 2626a432065..e614138d0a9 100644 --- a/tests/post_training/test_templates/test_quantizer_config.py +++ b/tests/post_training/test_templates/test_quantizer_config.py @@ -114,10 +114,9 @@ def test_default_quantizer_config(self, single_conv_nncf_graph): @pytest.mark.parametrize("preset", [QuantizationPreset.MIXED, QuantizationPreset.PERFORMANCE]) @pytest.mark.parametrize("weight_bits", [8]) @pytest.mark.parametrize("activation_bits", [8]) - @pytest.mark.parametrize("signed_weights", [None]) - @pytest.mark.parametrize("signed_activations", [None]) - # TODO(kshpv): add signed_activations and signed_weights which should be independent from HW config. - def test_quantizer_config_from_ptq_params( + @pytest.mark.parametrize("signed_weights", [None, True, False]) + @pytest.mark.parametrize("signed_activations", [None, True, False]) + def test_quantizer_config_from_ptq_params_for_CPU( self, weight_per_channel, activation_per_channel, @@ -147,28 +146,37 @@ def test_quantizer_config_from_ptq_params( min_max_algo._backend_entity.shapeof_metatypes, min_max_algo._backend_entity.read_variable_metatypes, ) - q_setup = min_max_algo._get_quantizer_setup( - nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern() - ) - q_g_to_quantization_mode = {} - for q_g in QuantizerGroup: - q_g_to_quantization_mode[q_g] = preset.get_params_configured_by_preset(q_g)["mode"] + if signed_weights is False or signed_activations in [True, False]: # Incompatible with HW CPU config + with pytest.raises( + ValueError, + match=".*?Quantization parameter constraints specified in NNCF config are incompatible.*?", + ): + q_setup = min_max_algo._get_quantizer_setup( + nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern() + ) + else: + q_setup = min_max_algo._get_quantizer_setup( + nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern() + ) + q_g_to_quantization_mode = {} + for q_g in QuantizerGroup: + q_g_to_quantization_mode[q_g] = preset.get_params_configured_by_preset(q_g)["mode"] - assert len(q_setup.quantization_points) == 2 + assert len(q_setup.quantization_points) == 2 - for quantization_point in q_setup.quantization_points.values(): - if quantization_point.is_weight_quantization_point(): - assert quantization_point.qconfig.mode == q_g_to_quantization_mode[QuantizerGroup.WEIGHTS] - assert quantization_point.qconfig.per_channel == weight_per_channel - assert quantization_point.qconfig.num_bits == weight_bits - if signed_weights is not None: - assert quantization_point.qconfig.signedness_to_force == signed_weights - if quantization_point.is_activation_quantization_point(): - assert quantization_point.qconfig.per_channel == activation_per_channel - assert quantization_point.qconfig.num_bits == activation_bits - assert quantization_point.qconfig.mode == q_g_to_quantization_mode[QuantizerGroup.ACTIVATIONS] - if signed_activations is not None: - assert quantization_point.qconfig.signedness_to_force == signed_activations + for quantization_point in q_setup.quantization_points.values(): + if quantization_point.is_weight_quantization_point(): + assert quantization_point.qconfig.mode == q_g_to_quantization_mode[QuantizerGroup.WEIGHTS] + assert quantization_point.qconfig.per_channel == weight_per_channel + assert quantization_point.qconfig.num_bits == weight_bits + if signed_weights is not None: + assert quantization_point.qconfig.signedness_to_force == signed_weights + if quantization_point.is_activation_quantization_point(): + assert quantization_point.qconfig.per_channel == activation_per_channel + assert quantization_point.qconfig.num_bits == activation_bits + assert quantization_point.qconfig.mode == q_g_to_quantization_mode[QuantizerGroup.ACTIVATIONS] + if signed_activations is not None: + assert quantization_point.qconfig.signedness_to_force == signed_activations def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph): algo = PostTrainingQuantization()