From 922170d14b517af4ff5369cb1b04aa7aff803f2c Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 28 Feb 2024 15:15:04 +0100 Subject: [PATCH] WIP transformations command serialization --- nncf/__init__.py | 2 + nncf/quantization/__init__.py | 2 + nncf/quantization/algorithms/algorithm.py | 25 +++++- .../algorithms/bias_correction/algorithm.py | 7 +- .../algorithms/channel_alignment/algorithm.py | 9 +- .../fast_bias_correction/algorithm.py | 8 +- .../algorithms/min_max/algorithm.py | 11 +-- .../algorithms/min_max/torch_backend.py | 10 +-- nncf/quantization/algorithms/pipeline.py | 19 +++- .../algorithms/post_training/algorithm.py | 15 ++++ .../algorithms/smooth_quant/algorithm.py | 9 +- .../algorithms/smooth_quant/torch_backend.py | 12 ++- .../weight_compression/algorithm.py | 14 ++- .../algorithms/weight_compression/awq.py | 17 ++-- nncf/quantization/quantize_model.py | 53 +++++++++++ .../graph/transformations/serialization.py | 90 +++++++++++++++++++ nncf/torch/quantization/layers.py | 5 ++ nncf/torch/quantization/quantize_model.py | 57 ++++++++++++ tests/post_training/pipelines/base.py | 16 ++++ .../test_quantize_conformance.py | 63 +++++++++++++ tests/torch/qat/test_qat_classification.py | 76 ++++++++++++++++ 21 files changed, 472 insertions(+), 48 deletions(-) create mode 100644 nncf/torch/graph/transformations/serialization.py diff --git a/nncf/__init__.py b/nncf/__init__.py index d14ff86960e..0dec70c0a59 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -39,7 +39,9 @@ from nncf.parameters import SensitivityMetric as SensitivityMetric from nncf.parameters import TargetDevice as TargetDevice from nncf.quantization import QuantizationPreset as QuantizationPreset +from nncf.quantization import apply_transformations as apply_transformations from nncf.quantization import compress_weights as compress_weights +from nncf.quantization import get_quantization_transformations as get_quantization_transformations from nncf.quantization import quantize as quantize from nncf.quantization import quantize_with_accuracy_control as quantize_with_accuracy_control from nncf.quantization.advanced_parameters import ( diff --git a/nncf/quantization/__init__.py b/nncf/quantization/__init__.py index a1b78c774e1..1c49408735e 100644 --- a/nncf/quantization/__init__.py +++ b/nncf/quantization/__init__.py @@ -10,6 +10,8 @@ # limitations under the License. 
"""Post-training quantization APIs.""" from nncf.common.quantization.structs import QuantizationPreset as QuantizationPreset +from nncf.quantization.quantize_model import apply_transformations as apply_transformations from nncf.quantization.quantize_model import compress_weights as compress_weights +from nncf.quantization.quantize_model import get_quantization_transformations as get_quantization_transformations from nncf.quantization.quantize_model import quantize as quantize from nncf.quantization.quantize_model import quantize_with_accuracy_control as quantize_with_accuracy_control diff --git a/nncf/quantization/algorithms/algorithm.py b/nncf/quantization/algorithms/algorithm.py index befe0a82f9d..7f3f1ea24e5 100644 --- a/nncf/quantization/algorithms/algorithm.py +++ b/nncf/quantization/algorithms/algorithm.py @@ -14,7 +14,9 @@ from typing import List, Optional, TypeVar from nncf import Dataset +from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -35,7 +37,6 @@ def available_backends(self) -> List[BackendType]: :return: List of backends supported by the algorithm. """ - @abstractmethod def apply( self, model: TModel, @@ -52,6 +53,28 @@ def apply( :param dataset: A representative dataset for the calibration process. :return: A resulting model. """ + transformation_layout = self.get_transformation_layout(model, graph, statistic_points, dataset) + return self.apply_transformation_layout(transformation_layout) + + @abstractmethod + def get_transformation_layout( + self, + model: TModel, + graph: NNCFGraph, + statistic_points: Optional[StatisticPointsContainer] = None, + dataset: Optional[Dataset] = None, + ) -> TransformationLayout: + """ + get_transformation_layout + """ + + def apply_transformation_layout(self, model: TModel, transformation_layout: TransformationLayout) -> TModel: + """ + apply_transformation_layout + """ + model_transformer = ModelTransformerFactory.create(model) + transformed_model = model_transformer.transform(transformation_layout) + return transformed_model @abstractmethod def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: diff --git a/nncf/quantization/algorithms/bias_correction/algorithm.py b/nncf/quantization/algorithms/bias_correction/algorithm.py index 4dc605c85e6..75a35d95e56 100644 --- a/nncf/quantization/algorithms/bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/bias_correction/algorithm.py @@ -128,16 +128,15 @@ def _set_backend_entity(self, model: TModel) -> None: "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) ) - def apply( + def get_transformation_layout( self, model: TModel, graph: NNCFGraph, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> TransformationLayout: self._set_backend_entity(model) main_transformations_layout = TransformationLayout() - main_model_transformer = ModelTransformerFactory.create(model) model_copy = copy_model(model) graph_copy = NNCFGraphFactory.create(model_copy) @@ -202,7 +201,7 @@ def apply( # to reduce memory usage during the algorithm's pipeline. 
self._remove_unnecessary_stats(position, subgraphs_data) - return main_model_transformer.transform(main_transformations_layout) + return main_transformations_layout def _is_node_correctable(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: """ diff --git a/nncf/quantization/algorithms/channel_alignment/algorithm.py b/nncf/quantization/algorithms/channel_alignment/algorithm.py index c8582cef551..36edfa02b43 100644 --- a/nncf/quantization/algorithms/channel_alignment/algorithm.py +++ b/nncf/quantization/algorithms/channel_alignment/algorithm.py @@ -15,7 +15,6 @@ from nncf import Dataset from nncf.common.factory import CommandCreatorFactory -from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.patterns import GraphPattern @@ -90,15 +89,14 @@ def _set_backend_entity(self, model: TModel) -> None: self._backend_entity = OVChannelAlignmentAlgoBackend() - def apply( + def get_transformation_layout( self, model: TModel, graph: NNCFGraph, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> TransformationLayout: self._set_backend_entity(model) - model_transformer = ModelTransformerFactory.create(model) transformation_layout = TransformationLayout() def filter_func(point: StatisticPoint) -> bool: @@ -168,8 +166,7 @@ def filter_func(point: StatisticPoint) -> bool: command = command_creator.create_command_to_insert_bias(container.op, container.bias) transformation_layout.register(command) - transformed_model = model_transformer.transform(transformation_layout) - return transformed_model + return transformation_layout @staticmethod def _align_means( diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 15f9a2f05b1..fffa68e0348 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -121,13 +121,13 @@ def _set_backend_entity(self, model: TModel) -> None: "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) ) - def apply( + def get_transformation_layout( self, model: TModel, graph: NNCFGraph, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> TransformationLayout: self._set_backend_entity(model) model_transformer = ModelTransformerFactory.create(model) @@ -185,9 +185,7 @@ def apply( transformation_layout = TransformationLayout() for node, bias_value in node_and_new_bias_value: transformation_layout.register(self._backend_entity.create_bias_correction_command(node, bias_value, graph)) - transformed_model = model_transformer.transform(transformation_layout) - - return transformed_model + return transformation_layout @staticmethod def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Tensor) -> float: diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index f728d0ee99c..bb789941e56 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -18,7 +18,6 @@ import nncf from nncf import Dataset -from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.operator_metatypes import 
OperatorMetatype @@ -803,15 +802,14 @@ def _get_quantization_points_overflow_fix( output.update(nodes) return output - def apply( + def get_transformation_layout( self, model: TModel, graph: NNCFGraph, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> TransformationLayout: transformation_layout = TransformationLayout() - model_transformer = ModelTransformerFactory.create(model) quantization_target_points, unified_scale_groups = self._get_quantization_target_points(model, graph) quantization_points_overflow_fix = self._get_quantization_points_overflow_fix( self._overflow_fix, quantization_target_points, graph @@ -895,10 +893,7 @@ def filter_func(point: StatisticPoint) -> bool: graph, quantization_target_point, qconfig, parameters ) transformation_layout.register(command) - if not transformation_layout.transformations: - nncf_logger.info("The model has no operations to apply quantization.") - quantized_model = model_transformer.transform(transformation_layout) - return quantized_model + return transformation_layout def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: self._set_backend_entity(model) diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 06d1c91e6f1..d0ed3ccc3a1 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -283,22 +283,22 @@ def _create_quantizer( quantizer = quantizer_cls(quantizer_spec) # Fill it with minmax - PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters) + PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) return quantizer @staticmethod - def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None: + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: if isinstance(quantizer, AsymmetricQuantizer): - quantizer.input_low = torch.nn.Parameter(parameters.input_low.data) + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) input_range = parameters.input_high - parameters.input_low # Subtract eps from the input_range to make quantizer parameters equal to # original parameters on the forward call. - quantizer.input_range = torch.nn.Parameter(input_range.data - quantizer.eps) + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) else: quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) # Subtract eps from the scale to make quantizer parameters equal to # original parameters on the forward call. - quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps) + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) @staticmethod def _create_quantizer_insertion_command( diff --git a/nncf/quantization/algorithms/pipeline.py b/nncf/quantization/algorithms/pipeline.py index cd615258553..f5948b35d5d 100644 --- a/nncf/quantization/algorithms/pipeline.py +++ b/nncf/quantization/algorithms/pipeline.py @@ -71,6 +71,8 @@ def __init__(self, pipeline_steps: List[PipelineStep]): :param pipeline_steps: A sequence of pipeline steps to be executed in order. 
""" self._pipeline_steps = pipeline_steps + self._algorithms_applied = False + self._transformation_layout_list = [] @property def pipeline_steps(self) -> List[PipelineStep]: @@ -114,12 +116,22 @@ def run_step( pipeline_steps = self._remove_unsupported_algorithms(get_backend(model)) pipeline_step = pipeline_steps[step_index] for algorithm in pipeline_step[:-1]: - current_model = algorithm.apply(current_model, current_graph, step_statistics) + current_model = self._apply_algoritm(algorithm, current_model, step_statistics, current_graph) current_graph = NNCFGraphFactory.create(current_model) - current_model = pipeline_step[-1].apply(current_model, current_graph, step_statistics) - + current_model = self._apply_algoritm(pipeline_step[-1], current_model, step_statistics, current_graph) return current_model + def _apply_algoritm( + self, + algorithm: Algorithm, + current_model: TModel, + step_statistics: StatisticPointsContainer, + current_graph: NNCFGraph, + ) -> TModel: + transformation_layout = algorithm.get_transformation_layout(current_model, current_graph, step_statistics) + self._transformation_layout_list.extend(transformation_layout.transformations) + return algorithm.apply_transformation_layout(current_model, transformation_layout) + def run_from_step( self, model: TModel, @@ -165,6 +177,7 @@ def run_from_step( step_graph = None # We should rebuild the graph for the next pipeline step + self._algorithms_applied = True return step_model def get_statistic_points_for_step( diff --git a/nncf/quantization/algorithms/post_training/algorithm.py b/nncf/quantization/algorithms/post_training/algorithm.py index 862dc5d5037..6a57309d857 100644 --- a/nncf/quantization/algorithms/post_training/algorithm.py +++ b/nncf/quantization/algorithms/post_training/algorithm.py @@ -14,6 +14,7 @@ from nncf import Dataset from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.quantization.structs import QuantizationPreset from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -92,6 +93,20 @@ def available_backends(self) -> List[BackendType]: def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: return self._pipeline.get_statistic_points_for_step(0, model, graph) + def get_transformation_layout( + self, + model: TModel, + graph: NNCFGraph, + statistic_points: Optional[StatisticPointsContainer] = None, + dataset: Optional[Dataset] = None, + ) -> TransformationLayout: + if not self._pipeline._algorithms_applied: + self.apply(model, graph, statistic_points, dataset) + return self.get_transformation_layout(model, graph, statistic_points, dataset) + transformation_layout = TransformationLayout() + transformation_layout.transformations.extend(self._pipeline._transformation_layout_list) + return transformation_layout + def apply( self, model: TModel, diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index 3e2994e2bea..63a79bb811c 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -16,7 +16,6 @@ import nncf from nncf import Dataset -from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype @@ -91,18 +90,17 @@ def 
                 "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
             )

-    def apply(
+    def get_transformation_layout(
         self,
         model: TModel,
         graph: NNCFGraph,
         statistic_points: Optional[StatisticPointsContainer] = None,
         dataset: Optional[Dataset] = None,
-    ) -> TModel:
+    ) -> TransformationLayout:
         self._set_backend_entity(model)
         alpha_map = self._get_alpha_map()

         nodes_to_smooth_data = self._get_nodes_to_smooth_data(graph, alpha_map.keys())
-        model_transformer = ModelTransformerFactory.create(model)
         transformation_layout = TransformationLayout()

         node_groups = self._group_nodes_by_source(nodes_to_smooth_data, graph)
@@ -169,9 +167,6 @@ def apply(
             )
             transformation_layout.register(scale_insertion_command)

-        transformed_model = model_transformer.transform(transformation_layout)
-        return transformed_model
-
+        return transformation_layout
+
     @staticmethod
     def _calculate_scale_and_ratio(
         activations: Tensor, weights: Tensor, alpha: float, quantile: Optional[float] = 0.1
diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
index a486be98a4f..b4535771888 100644
--- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -31,20 +31,30 @@
 from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.layer_utils import COMPRESSION_MODULES
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
 from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
 from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor


+@COMPRESSION_MODULES.register()
 class SQMultiply(torch.nn.Module):
-    def __init__(self, scale_value):
+    def __init__(self, scale_value=None):
         super().__init__()
-        self._scale_value = scale_value
+        if scale_value is None:
+            scale_value = torch.tensor(1.0)
+        # Keep the scale in a buffer so it is captured by state_dict() and
+        # restored by load_state_dict() after the command is re-applied.
+        self.register_buffer("_scale_value", torch.as_tensor(scale_value))

     def forward(self, x):
         return torch.mul(x, self._scale_value)

+    def get_state(self):
+        # The scale values themselves travel through the model state_dict;
+        # only the shape is needed to rebuild the module.
+        return {"scale_shape": list(self._scale_value.shape)}
+
+    @classmethod
+    def from_state(cls, state):
+        return cls(torch.zeros(state["scale_shape"]))
+

 PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK

diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index a2a80fea257..6fcfb4d6042 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -18,6 +18,7 @@
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
 from nncf.common.scopes import should_consider_scope
@@ -365,7 +366,18 @@ def do_compression(
         return transformed_model

     def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
-        pass
+        return StatisticPointsContainer()
+
+    def get_transformation_layout(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        statistic_points: Optional[StatisticPointsContainer] = None,
+        dataset: Optional[Dataset] = None,
+    ) -> TransformationLayout:
+        raise NotImplementedError(
+            "get_transformation_layout is not implemented yet for the weights compression algorithm."
+        )

     def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]:
         """
diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py
index 91c8dbcce36..45ea656609f 100644
--- a/nncf/quantization/algorithms/weight_compression/awq.py
+++ b/nncf/quantization/algorithms/weight_compression/awq.py
@@ -17,6 +17,7 @@
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.graph_matching import find_subgraphs_matching_pattern
+from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.common.utils.backend import BackendType
@@ -266,11 +267,13 @@ def apply(
         return model

     def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
-        """
-        Returns statistic points, for which StatisticsCollector should collect statistics.
-
-        :param model: Model for statistics collection.
-        :param graph: Model graph.
-        :return: Statistic points, for which StatisticsCollector should collect statistics.
-        """
         return StatisticPointsContainer()
+
+    def get_transformation_layout(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        statistic_points: Optional[StatisticPointsContainer] = None,
+        dataset: Optional[Dataset] = None,
+    ) -> TransformationLayout:
+        raise NotImplementedError("get_transformation_layout is not implemented yet for the AWQ compression algorithm.")
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index 6982b347600..d314701fb69 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -14,6 +14,7 @@
 import nncf
 from nncf.api.compression import TModel
 from nncf.common.deprecation import warning_deprecated
+from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.common.utils.api_marker import api
 from nncf.common.utils.backend import BackendType
@@ -482,3 +483,55 @@ def quantize_with_tune_hyperparams(
     quantized_model = hyperparameter_tuner.apply(model, validation_dataset)

     return quantized_model
+
+
+@api(canonical_alias="nncf.get_quantization_transformations")
+def get_quantization_transformations(
+    model: TModel,
+    calibration_dataset: Dataset,
+    mode: Optional[QuantizationMode] = None,
+    preset: Optional[QuantizationPreset] = None,
+    target_device: TargetDevice = TargetDevice.ANY,
+    subset_size: int = 300,
+    fast_bias_correction: bool = True,
+    model_type: Optional[ModelType] = None,
+    ignored_scope: Optional[IgnoredScope] = None,
+    advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
+) -> TransformationLayout:
+    """
+    Returns the transformation layout that quantizes the model, without applying it.
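+
+    Minimal usage sketch (`model` and `calibration_dataset` stand for a
+    user-provided PyTorch model and nncf.Dataset; the names are illustrative):
+
+        layout = nncf.get_quantization_transformations(model, calibration_dataset)
+        example_input = next(iter(calibration_dataset.get_inference_data()))
+        quantized_model = nncf.apply_transformations(model, layout, example_input)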
+ """ + backend = get_backend(model) + if backend == BackendType.TORCH: + from nncf.torch.quantization.quantize_model import get_quantization_transformations + + return get_quantization_transformations( + model=model, + calibration_dataset=calibration_dataset, + mode=mode, + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) + raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") + + +@api(canonical_alias="nncf.apply_transformations") +def apply_transformations( + model: TModel, + transformation_layout: TransformationLayout, + example_input: TTensor, +) -> TModel: + """ + Applies transformation layout to the model. + """ + backend = get_backend(model) + if backend == BackendType.TORCH: + from nncf.torch.quantization.quantize_model import apply_transformations_impl + + return apply_transformations_impl(model, transformation_layout, example_input) + raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py new file mode 100644 index 00000000000..9c1fefbd190 --- /dev/null +++ b/nncf/torch/graph/transformations/serialization.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
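+"""
+Serialization helpers for PyTorch transformation commands (WIP).
+
+Sketch of the intended checkpoint round trip, assuming `quantized_model` and
+`layout` came from nncf.get_quantization_transformations() /
+nncf.apply_transformations(); the "ckpt.pt" path is illustrative:
+
+    ckpt = serialize_transformations(quantized_model, layout)
+    torch.save(ckpt, "ckpt.pt")
+    restored = load_transformations(model, torch.load("ckpt.pt"), example_input)
+"""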
+
+from enum import Enum
+from typing import Any, Dict
+
+import torch
+
+import nncf
+from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.layer_utils import COMPRESSION_MODULES
+from nncf.torch.model_transformer import PTQuantizerInsertionCommand
+from nncf.torch.model_transformer import PTSharedFnInsertionCommand
+from nncf.torch.quantization.layers import QUANTIZATION_MODULES
+from nncf.torch.quantization.layers import PTQuantizerSpec
+
+# Keys of the checkpoint dictionary produced by serialize_transformations().
+# Defined here so that library code does not depend on the examples package.
+MODEL_STATE_ATTR = "state_dict"
+COMPRESSION_STATE_ATTR = "compression_state"
+
+
+class CompressionKeys(Enum):
+    QUANTIZER_INSERTION_COMMAND = "QUANTIZER_INSERTION_COMMAND"
+    SHARED_INSERTION_COMMAND = "SHARED_INSERTION_COMMAND"
+
+
+def serialize_transformations(model: torch.nn.Module, transformations_layout: TransformationLayout) -> Dict[str, Any]:
+    """
+    Serializes the supported transformation commands together with the model state_dict.
+    """
+    transformation_commands = []
+    for transformation in transformations_layout.transformations:
+        if not isinstance(transformation, (PTQuantizerInsertionCommand, PTSharedFnInsertionCommand)):
+            continue
+
+        serialized_transformation = dict()
+        if isinstance(transformation, PTQuantizerInsertionCommand):
+            serialized_transformation["type"] = CompressionKeys.QUANTIZER_INSERTION_COMMAND.value
+            serialized_transformation["target_point"] = transformation.target_point.get_state()
+            serialized_transformation["quantizer_spec"] = transformation.quantizer.quantizer_spec.get_state()
+        elif isinstance(transformation, PTSharedFnInsertionCommand):
+            serialized_transformation["type"] = CompressionKeys.SHARED_INSERTION_COMMAND.value
+            serialized_transformation["target_points"] = [point.get_state() for point in transformation.target_points]
+            # COMPRESSION_MODULES registers classes under their class names.
+            serialized_transformation["fn_name"] = transformation.fn.__class__.__name__
+            serialized_transformation["fn_state"] = transformation.fn.get_state()
+            serialized_transformation["op_name"] = transformation.op_name
+        serialized_transformation["priority"] = transformation.priority.value
+        serialized_transformation["hooks_group_name"] = transformation.hooks_group_name
+        transformation_commands.append(serialized_transformation)
+
+    return {MODEL_STATE_ATTR: model.state_dict(), COMPRESSION_STATE_ATTR: transformation_commands}
+
+
+def load_transformations(model: torch.nn.Module, transformations: Dict[str, Any], example_input) -> torch.nn.Module:
+    """
+    Restores the serialized transformation commands, applies them to the model
+    and loads the serialized model state_dict.
+    """
+    transformation_layout = TransformationLayout()
+    for command in transformations[COMPRESSION_STATE_ATTR]:
+        if command["type"] == CompressionKeys.QUANTIZER_INSERTION_COMMAND.value:
+            qspec = PTQuantizerSpec.from_state(command["quantizer_spec"])
+            quantizer_cls = QUANTIZATION_MODULES.get(qspec.mode)
+            quantizer = quantizer_cls(qspec)
+            target_point = PTTargetPoint.from_state(command["target_point"])
+            command = PTQuantizerInsertionCommand(
+                point=target_point, quantizer=quantizer, hooks_group_name=command["hooks_group_name"]
+            )
+            transformation_layout.register(command)
+            continue
+
+        if command["type"] == CompressionKeys.SHARED_INSERTION_COMMAND.value:
+            target_points = [PTTargetPoint.from_state(state) for state in command["target_points"]]
+            module_cls = COMPRESSION_MODULES.get(command["fn_name"])
+            fn = module_cls.from_state(command["fn_state"])
+            # The priority was serialized by value, so restore it by value.
+            priority = TransformationPriority(command["priority"])
+            command = PTSharedFnInsertionCommand(
+                target_points=target_points,
+                fn=fn,
+                op_unique_name=command["op_name"],
+                priority=priority,
+                hooks_group_name=command["hooks_group_name"],
+            )
+            transformation_layout.register(command)
+            continue
+
+        raise RuntimeError(f"Command type {command['type']} is not supported.")
+
+    transformed_model = nncf.apply_transformations(model, transformation_layout, example_input)
+    transformed_model.load_state_dict(transformations[MODEL_STATE_ATTR])
+    return transformed_model
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 937d156c8fe..2c32de77ee7 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -286,6 +286,7 @@ def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationP
 class BaseQuantizer(nn.Module, ABC):
     def __init__(self, qspec: PTQuantizerSpec):
         super().__init__()
+        self._qspec = qspec
         self._narrow_range = qspec.narrow_range
         self._signedness_to_force = qspec.signedness_to_force
         self._is_using_log_scale_storage = qspec.logarithm_scale
@@ -352,6 +353,10 @@ def level_high(self, val: int):
     def levels(self):
         return get_num_levels(self.level_low, self.level_high)

+    @property
+    def quantizer_spec(self) -> PTQuantizerSpec:
+        return self._qspec
+
     @abstractmethod
     def enable_gradients(self):
         pass
diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py
index fe883dabb6e..0804d6597db 100644
--- a/nncf/torch/quantization/quantize_model.py
+++ b/nncf/torch/quantization/quantize_model.py
@@ -15,9 +15,12 @@
 import torch

 import nncf
+from nncf.common.factory import ModelTransformerFactory
 from nncf.common.factory import NNCFGraphFactory
+from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
+from nncf.experimental.tensor import Tensor
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import ModelType
 from nncf.parameters import QuantizationMode
@@ -99,3 +102,57 @@ def compress_weights_impl(
     )
     graph = NNCFGraphFactory.create(model)
     return compression_algorithm.apply(model, graph, dataset=dataset)
+
+
+def apply_transformations_impl(
+    model: torch.nn.Module, transformation_layout: TransformationLayout, example_input: Tensor
+) -> torch.nn.Module:
+    copied_model = deepcopy(model)
+    nncf_network = wrap_model(copied_model.eval(), example_input)
+    model_transformer = ModelTransformerFactory.create(nncf_network)
+    transformed_model = model_transformer.transform(transformation_layout)
+
+    transformed_model.nncf.disable_dynamic_graph_building()
+    return transformed_model
+
+
+def get_quantization_transformations(
+    model: torch.nn.Module,
+    calibration_dataset: Dataset,
+    mode: Optional[QuantizationMode] = None,
+    preset: Optional[QuantizationPreset] = None,
+    target_device: TargetDevice = TargetDevice.ANY,
+    subset_size: int = 300,
+    fast_bias_correction: bool = True,
+    model_type: Optional[ModelType] = None,
+    ignored_scope: Optional[IgnoredScope] = None,
+    advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
+) -> TransformationLayout:
+    """
+    Implementation of `get_quantization_transformations()` for the PyTorch backend.
+    """
+    if fast_bias_correction is False:
+        raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported")
+    if target_device == TargetDevice.CPU_SPR:
+        raise nncf.InternalError("target_device == CPU_SPR is not supported")
+    if mode is not None:
+        raise ValueError(f"mode={mode} is not supported")
+
+    copied_model = deepcopy(model)
+
+    example_input = next(iter(calibration_dataset.get_inference_data()))
+    nncf_network = wrap_model(copied_model.eval(), example_input)
+
+    quantization_algorithm = PostTrainingQuantization(
+        preset=preset,
+        target_device=target_device,
+        subset_size=subset_size,
+        fast_bias_correction=fast_bias_correction,
+        model_type=model_type,
+        ignored_scope=ignored_scope,
+        advanced_parameters=advanced_parameters,
+    )
+
+    return quantization_algorithm.get_transformation_layout(
+        nncf_network, nncf_network.nncf.get_graph(), dataset=calibration_dataset
+    )
diff --git a/tests/post_training/pipelines/base.py b/tests/post_training/pipelines/base.py
index 3f7dde51993..0a3aa6275a7 100644
--- a/tests/post_training/pipelines/base.py
+++ b/tests/post_training/pipelines/base.py
@@ -261,6 +261,16 @@ def prepare(self):
         self.prepare_preprocessor()
         self.prepare_calibration_dataset()

+    def get_state(self):
+        """
+        Returns the state of the compressed model.
+        """
+
+    def load_state(self, state):
+        """
+        Loads the given state into the compressed model.
+        """
+
     def validate(self) -> None:
         """
         Validate and compare result with reference.
@@ -412,3 +422,9 @@ def collect_data_from_stdout(self, stdout: str):
         stats = PTQTimeStats()
         stats.fill(stdout)
         self.run_info.stats_from_output = stats
+
+    def get_state(self):
+        return self.compressed_model.state_dict().copy()
+
+    def load_state(self, state):
+        self.compressed_model.load_state_dict(state)
diff --git a/tests/post_training/test_quantize_conformance.py b/tests/post_training/test_quantize_conformance.py
index 2b3c06fa0b7..6e5cfa470c4 100644
--- a/tests/post_training/test_quantize_conformance.py
+++ b/tests/post_training/test_quantize_conformance.py
@@ -244,6 +244,69 @@ def test_ptq_quantization(
         pytest.fail(err_msg)


+@pytest.mark.parametrize("test_case_name", PTQ_TEST_CASES.keys())
+def test_ptq_get_load_state(
+    ptq_reference_data: dict,
+    test_case_name: str,
+    data_dir: Path,
+    output_dir: Path,
+    ptq_result_data: Dict[str, RunInfo],
+    no_eval: bool,
+    run_fp32_backend: bool,
+    run_torch_cuda_backend: bool,
+    subset_size: Optional[int],
+    run_benchmark_app: bool,
+    capsys: pytest.CaptureFixture,
+    extra_columns: bool,
+):
+    pipeline = None
+    err_msg = None
+    test_model_param = None
+    start_time = time.perf_counter()
+    try:
+        if test_case_name not in ptq_reference_data:
+            raise nncf.ValidationError(f"{test_case_name} does not exist in 'reference_data.yaml'")
+        test_model_param = PTQ_TEST_CASES[test_case_name]
+        maybe_skip_test_case(test_model_param, run_fp32_backend, run_torch_cuda_backend)
+        pipeline_cls = test_model_param["pipeline_cls"]
+        pipeline_kwargs = create_pipeline_kwargs(test_model_param, subset_size, test_case_name, ptq_reference_data)
+        pipeline_kwargs.update(
+            {"output_dir": output_dir, "data_dir": data_dir, "no_eval": no_eval, "run_benchmark_app": run_benchmark_app}
+        )
+        pipeline: BaseTestPipeline = pipeline_cls(**pipeline_kwargs)
+        pipeline.prepare()
+        # Compress once, snapshot the state, re-compress, then restore the snapshot.
+        pipeline.compress()
+        state = pipeline.get_state()
+        pipeline.compress()
+        pipeline.load_state(state)
+        pipeline.save_compressed_model()
+        pipeline.get_num_compressed()
+        pipeline.validate()
+    except Exception as e:
+        err_msg = str(e)
+        traceback.print_exc()
+
+    if pipeline is not None:
+        pipeline.cleanup_cache()
+        run_info = pipeline.run_info
+        if err_msg:
+            run_info.status = f"{run_info.status} | {err_msg}" if run_info.status else err_msg
+
+        captured = capsys.readouterr()
+        write_logs(captured, pipeline)
+
+        if extra_columns:
+            pipeline.collect_data_from_stdout(captured.out)
+    else:
+        run_info = create_short_run_info(test_model_param, err_msg, test_case_name)
+
+    run_info.time_total = time.perf_counter() - start_time
+    ptq_result_data[test_case_name] = run_info
+
+    if err_msg:
+        pytest.fail(err_msg)
+
+
 @pytest.mark.parametrize("test_case_name", WC_TEST_CASES.keys())
 def test_weight_compression(
     wc_reference_data: dict,
diff --git a/tests/torch/qat/test_qat_classification.py b/tests/torch/qat/test_qat_classification.py
index 8e5a0c69287..68a748b88ab 100644
--- a/tests/torch/qat/test_qat_classification.py
+++ b/tests/torch/qat/test_qat_classification.py
@@ -45,6 +45,8 @@
 from examples.torch.common.utils import is_pretrained_model_requested
 from nncf import NNCFConfig
 from nncf.common.compression import BaseCompressionAlgorithmController
+from nncf.torch.graph.transformations.serialization import load_transformations
+from nncf.torch.graph.transformations.serialization import serialize_transformations
 from nncf.torch.initialization import default_criterion_fn
 from nncf.torch.utils import is_main_process
 from tests.shared.paths import PROJECT_ROOT
@@ -225,6 +227,16 @@ def test_compression_training(quantization_config: SampleConfig):
     start_worker(main_worker, quantization_config)


+def test_compression_training_with_save_and_load_state(quantization_config: SampleConfig):
+    if quantization_config.model == "mobilenet_v3_small":
+        # Use the default range initializer for mobilenet_v3_small: the
+        # PTQ-based initialization already works better for this model.
+        del quantization_config.nncf_config["compression"]["initializer"]["range"]
+        del quantization_config["compression"]["initializer"]["range"]
+
+    start_worker(save_load_main_worker, quantization_config)
+
+
 def main_worker(current_gpu: int, config: SampleConfig):
     configure_device(current_gpu, config)
     if is_main_process():
@@ -280,3 +292,67 @@ def main_worker(current_gpu: int, config: SampleConfig):
     assert accuracy_drop_is_acceptable(acc_drop)
     check_training_correctness(config, model, datasets, criterion, train_criterion_fn)
     logger.info("Done!")
+
+
+def save_load_main_worker(current_gpu: int, config: SampleConfig):
+    configure_device(current_gpu, config)
+    if is_main_process():
+        configure_logging(logger, config)
+    else:
+        config.tb = None
+
+    pretrained = is_pretrained_model_requested(config)
+    model_name = config["model"]
+    # create model
+    logger.info(f"\nCreating model from config: {config.config}")
+    model = load_model(
+        model_name,
+        pretrained=pretrained,
+        num_classes=config.get("num_classes", 1000),
+        model_params=config.get("model_params"),
+        weights_path=config.get("weights"),
+    )
+    model.to(config.device)
+
+    datasets = get_datasets(config)
+    criterion = nn.CrossEntropyLoss()
+    criterion = criterion.to(config.device)
+
+    logger.info("Original model validation:")
+    original_accuracy, *_ = validate(datasets.val_data_loader, model, criterion, config)
+
+    logger.info("Apply quantization to the model:")
+    config_quantization_params = config["compression"]
+
+    preset = get_quantization_preset(config_quantization_params)
+    advanced_parameters = get_advanced_ptq_parameters(config_quantization_params)
+    subset_size = get_num_samples(config_quantization_params)
+
+    transformations = nncf.get_quantization_transformations(
+        model,
+        datasets.calibration_dataset,
+        preset=preset,
+        advanced_parameters=advanced_parameters,
+        subset_size=subset_size,
+    )
+    quantized_model = nncf.apply_transformations(
+        model, transformations, next(iter(datasets.calibration_dataset.get_inference_data()))
+    )
+    ckpt = serialize_transformations(quantized_model, transformations)
+    del quantized_model
+    quantized_model = load_transformations(model, ckpt, next(iter(datasets.calibration_dataset.get_inference_data())))
+
+    train_criterion_fn = inception_criterion_fn if "inception" in model_name else default_criterion_fn
+    acc_drop = train(
+        quantized_model,
+        config,
+        criterion,
+        train_criterion_fn,
+        datasets,
+        original_accuracy,
+        get_mocked_compression_ctrl(),
+    )
+    assert accuracy_drop_is_acceptable(acc_drop)
+    check_training_correctness(config, model, datasets, criterion, train_criterion_fn)
+    logger.info("Done!")
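
Usage sketch for the APIs introduced by this patch, condensed from the QAT test
above. The PyTorch `model`, the nncf.Dataset `calibration_dataset`, and the
"ckpt.pt" path are user-supplied placeholders, not part of the API, and the
interface is still WIP:

    import torch

    import nncf
    from nncf.torch.graph.transformations.serialization import load_transformations
    from nncf.torch.graph.transformations.serialization import serialize_transformations

    example_input = next(iter(calibration_dataset.get_inference_data()))

    # Collect the quantization commands without mutating the model, then apply them.
    layout = nncf.get_quantization_transformations(model, calibration_dataset)
    quantized_model = nncf.apply_transformations(model, layout, example_input)

    # Persist the commands together with the tuned weights in one checkpoint.
    torch.save(serialize_transformations(quantized_model, layout), "ckpt.pt")

    # Rebuild the quantized model from the original model and the checkpoint.
    restored_model = load_transformations(model, torch.load("ckpt.pt"), example_input)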